diff --git a/go.mod b/go.mod index b57dcdf..90b3a99 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/feiin/sqlstring v0.3.0 github.com/gliderlabs/ssh v0.3.5 github.com/go-sql-driver/mysql v1.6.0 + github.com/gorilla/csrf v1.7.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.5.0 diff --git a/go.sum b/go.sum index d23e904..777db6e 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4x github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gorilla/csrf v1.7.1 h1:Ir3o2c1/Uzj6FBxMlAUB6SivgVMy1ONXwYgXn+/aHPE= +github.com/gorilla/csrf v1.7.1/go.mod h1:+a/4tCmqhG6/w4oafeAZ9pEa3/NZOWYVbD9fV0FwIQA= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= diff --git a/internal/config/config.go b/internal/config/config.go index e01a023..595d548 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "fmt" + "hash/fnv" + "math/rand" "net/url" "reflect" "strings" @@ -100,6 +102,21 @@ type Config struct { ConfigPath string } +// CSRFSecret return the csrfSecret derived from the session secret +func (config *Config) CSRFSecret() []byte { + // take the hash of the secret + h := fnv.New32a() + h.Write([]byte(config.SessionSecret)) + + // seed a random number generator + rand := rand.New(rand.NewSource(int64(h.Sum32()))) + + // take a bunch of bytes from it + secret := make([]byte, 32) + rand.Read(secret) + return secret +} + // String serializes this configuration into a string func (config Config) String() string { values := &strings.Builder{} diff --git a/internal/dis/component/auth/auth.go b/internal/dis/component/auth/auth.go index af7e0cb..669c2e2 100644 --- a/internal/dis/component/auth/auth.go +++ b/internal/dis/component/auth/auth.go @@ -1,10 +1,11 @@ package auth import ( - "sync" + "net/http" "github.com/FAU-CDI/wisski-distillery/internal/dis/component" "github.com/FAU-CDI/wisski-distillery/internal/dis/component/sql" + "github.com/FAU-CDI/wisski-distillery/pkg/lazy" "github.com/gorilla/sessions" ) @@ -14,8 +15,8 @@ type Auth struct { SQL *sql.SQL } - storeOnce sync.Once - store sessions.Store + store lazy.Lazy[sessions.Store] + csrf lazy.Lazy[func(http.Handler) http.Handler] } var ( diff --git a/internal/dis/component/auth/templates/login.html b/internal/dis/component/auth/templates/login.html index a93260f..83945d0 100644 --- a/internal/dis/component/auth/templates/login.html +++ b/internal/dis/component/auth/templates/login.html @@ -23,6 +23,7 @@ + {{ .CSRF }} diff --git a/internal/dis/component/auth/web.go b/internal/dis/component/auth/web.go index 63dcf4e..fbb961b 100644 --- a/internal/dis/component/auth/web.go +++ b/internal/dis/component/auth/web.go @@ -2,11 +2,13 @@ package auth import ( "context" + "html/template" "net/http" "net/url" "github.com/FAU-CDI/wisski-distillery/internal/dis/component/control/static" "github.com/FAU-CDI/wisski-distillery/pkg/httpx" + "github.com/gorilla/csrf" "github.com/gorilla/sessions" "github.com/julienschmidt/httprouter" @@ -28,10 +30,9 @@ const ( // session returns the session belonging to a request func (auth *Auth) session(r *http.Request) (*sessions.Session, error) { - auth.storeOnce.Do(func() { - auth.store = sessions.NewCookieStore([]byte(auth.Config.SessionSecret)) - }) - return auth.store.Get(r, sessionCookieName) + return auth.store.Get(func() sessions.Store { + return sessions.NewCookieStore([]byte(auth.Config.SessionSecret)) + }).Get(r, sessionCookieName) } // UserOf returns the user logged into the given request. @@ -112,6 +113,15 @@ var loginResponse = httpx.Response{ func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, error) { router := httprouter.New() + csrf := auth.csrf.Get(func() func(http.Handler) http.Handler { + var opts []csrf.Option + if !auth.Config.HTTPSEnabled() { + opts = append(opts, csrf.Secure(false)) + } + opts = append(opts, csrf.Path(route)) + return csrf.Protect(auth.Config.CSRFSecret(), opts...) + }) + router.Handler(http.MethodGet, route, auth.Protect(loginResponse, nil)) router.HandlerFunc(http.MethodGet, route+"login", auth.loginRoute) @@ -119,11 +129,12 @@ func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, router.HandlerFunc(http.MethodGet, route+"logout", auth.logoutRoute) - return router, nil + return csrf(router), nil } type loginContext struct { Message string + CSRF template.HTML } // Protect returns a new handler which requires a user to be logged in and pass the perm function. @@ -231,6 +242,7 @@ func (auth *Auth) loginRoute(w http.ResponseWriter, r *http.Request) { form: httpx.WriteHTML(loginContext{ Message: message, + CSRF: csrf.TemplateField(r), }, nil, loginTemplate, "", w, r) return success: