diff --git a/go.mod b/go.mod
index b57dcdf..90b3a99 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/feiin/sqlstring v0.3.0
github.com/gliderlabs/ssh v0.3.5
github.com/go-sql-driver/mysql v1.6.0
+ github.com/gorilla/csrf v1.7.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.5.0
diff --git a/go.sum b/go.sum
index d23e904..777db6e 100644
--- a/go.sum
+++ b/go.sum
@@ -14,6 +14,8 @@ github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4x
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
+github.com/gorilla/csrf v1.7.1 h1:Ir3o2c1/Uzj6FBxMlAUB6SivgVMy1ONXwYgXn+/aHPE=
+github.com/gorilla/csrf v1.7.1/go.mod h1:+a/4tCmqhG6/w4oafeAZ9pEa3/NZOWYVbD9fV0FwIQA=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
diff --git a/internal/config/config.go b/internal/config/config.go
index e01a023..595d548 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -3,6 +3,8 @@ package config
import (
"fmt"
+ "hash/fnv"
+ "math/rand"
"net/url"
"reflect"
"strings"
@@ -100,6 +102,21 @@ type Config struct {
ConfigPath string
}
+// CSRFSecret return the csrfSecret derived from the session secret
+func (config *Config) CSRFSecret() []byte {
+ // take the hash of the secret
+ h := fnv.New32a()
+ h.Write([]byte(config.SessionSecret))
+
+ // seed a random number generator
+ rand := rand.New(rand.NewSource(int64(h.Sum32())))
+
+ // take a bunch of bytes from it
+ secret := make([]byte, 32)
+ rand.Read(secret)
+ return secret
+}
+
// String serializes this configuration into a string
func (config Config) String() string {
values := &strings.Builder{}
diff --git a/internal/dis/component/auth/auth.go b/internal/dis/component/auth/auth.go
index af7e0cb..669c2e2 100644
--- a/internal/dis/component/auth/auth.go
+++ b/internal/dis/component/auth/auth.go
@@ -1,10 +1,11 @@
package auth
import (
- "sync"
+ "net/http"
"github.com/FAU-CDI/wisski-distillery/internal/dis/component"
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/sql"
+ "github.com/FAU-CDI/wisski-distillery/pkg/lazy"
"github.com/gorilla/sessions"
)
@@ -14,8 +15,8 @@ type Auth struct {
SQL *sql.SQL
}
- storeOnce sync.Once
- store sessions.Store
+ store lazy.Lazy[sessions.Store]
+ csrf lazy.Lazy[func(http.Handler) http.Handler]
}
var (
diff --git a/internal/dis/component/auth/templates/login.html b/internal/dis/component/auth/templates/login.html
index a93260f..83945d0 100644
--- a/internal/dis/component/auth/templates/login.html
+++ b/internal/dis/component/auth/templates/login.html
@@ -23,6 +23,7 @@
+ {{ .CSRF }}
diff --git a/internal/dis/component/auth/web.go b/internal/dis/component/auth/web.go
index 63dcf4e..fbb961b 100644
--- a/internal/dis/component/auth/web.go
+++ b/internal/dis/component/auth/web.go
@@ -2,11 +2,13 @@ package auth
import (
"context"
+ "html/template"
"net/http"
"net/url"
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/control/static"
"github.com/FAU-CDI/wisski-distillery/pkg/httpx"
+ "github.com/gorilla/csrf"
"github.com/gorilla/sessions"
"github.com/julienschmidt/httprouter"
@@ -28,10 +30,9 @@ const (
// session returns the session belonging to a request
func (auth *Auth) session(r *http.Request) (*sessions.Session, error) {
- auth.storeOnce.Do(func() {
- auth.store = sessions.NewCookieStore([]byte(auth.Config.SessionSecret))
- })
- return auth.store.Get(r, sessionCookieName)
+ return auth.store.Get(func() sessions.Store {
+ return sessions.NewCookieStore([]byte(auth.Config.SessionSecret))
+ }).Get(r, sessionCookieName)
}
// UserOf returns the user logged into the given request.
@@ -112,6 +113,15 @@ var loginResponse = httpx.Response{
func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
router := httprouter.New()
+ csrf := auth.csrf.Get(func() func(http.Handler) http.Handler {
+ var opts []csrf.Option
+ if !auth.Config.HTTPSEnabled() {
+ opts = append(opts, csrf.Secure(false))
+ }
+ opts = append(opts, csrf.Path(route))
+ return csrf.Protect(auth.Config.CSRFSecret(), opts...)
+ })
+
router.Handler(http.MethodGet, route, auth.Protect(loginResponse, nil))
router.HandlerFunc(http.MethodGet, route+"login", auth.loginRoute)
@@ -119,11 +129,12 @@ func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler,
router.HandlerFunc(http.MethodGet, route+"logout", auth.logoutRoute)
- return router, nil
+ return csrf(router), nil
}
type loginContext struct {
Message string
+ CSRF template.HTML
}
// Protect returns a new handler which requires a user to be logged in and pass the perm function.
@@ -231,6 +242,7 @@ func (auth *Auth) loginRoute(w http.ResponseWriter, r *http.Request) {
form:
httpx.WriteHTML(loginContext{
Message: message,
+ CSRF: csrf.TemplateField(r),
}, nil, loginTemplate, "", w, r)
return
success: