diff --git a/internal/dis/component/auth/auth.go b/internal/dis/component/auth/auth.go index d77264a..236bc64 100644 --- a/internal/dis/component/auth/auth.go +++ b/internal/dis/component/auth/auth.go @@ -7,7 +7,6 @@ import ( "github.com/FAU-CDI/wisski-distillery/internal/dis/component" "github.com/FAU-CDI/wisski-distillery/internal/dis/component/sql" "github.com/FAU-CDI/wisski-distillery/pkg/lazy" - "github.com/gorilla/csrf" "github.com/gorilla/sessions" "github.com/julienschmidt/httprouter" ) @@ -41,15 +40,3 @@ func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, return router, nil } - -func (auth *Auth) CSRF() func(http.Handler) http.Handler { - // setup the csrf handler (if needed) - // TOOD: This should move to the server handler - return auth.csrf.Get(func() func(http.Handler) http.Handler { - var opts []csrf.Option - if !auth.Config.HTTPSEnabled() { - opts = append(opts, csrf.Secure(false)) - } - return csrf.Protect(auth.Config.CSRFSecret(), opts...) - }) -} diff --git a/internal/dis/component/auth/panel/password.go b/internal/dis/component/auth/panel/password.go index ced8267..37fc37a 100644 --- a/internal/dis/component/auth/panel/password.go +++ b/internal/dis/component/auth/panel/password.go @@ -33,8 +33,7 @@ func (panel *UserPanel) routePassword(ctx context.Context) http.Handler { {Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - - CSRF: panel.Dependencies.Auth.CSRF(), + CSRF: true, RenderTemplate: passwordTemplate, RenderTemplateContext: panel.UserFormContext, diff --git a/internal/dis/component/auth/panel/totp.go b/internal/dis/component/auth/panel/totp.go index e0e22cd..e114bab 100644 --- a/internal/dis/component/auth/panel/totp.go +++ b/internal/dis/component/auth/panel/totp.go @@ -22,8 +22,7 @@ func (panel *UserPanel) routeTOTPEnable(ctx context.Context) http.Handler { {Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - - CSRF: panel.Dependencies.Auth.CSRF(), + CSRF: true, SkipForm: func(r *http.Request) (data struct{}, skip bool) { user, err := panel.Dependencies.Auth.UserOf(r) @@ -81,8 +80,7 @@ func (panel *UserPanel) routeTOTPEnroll(ctx context.Context) http.Handler { {Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - - CSRF: panel.Dependencies.Auth.CSRF(), + CSRF: true, SkipForm: func(r *http.Request) (data struct{}, skip bool) { user, err := panel.Dependencies.Auth.UserOf(r) @@ -152,8 +150,7 @@ func (panel *UserPanel) routeTOTPDisable(ctx context.Context) http.Handler { {Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - - CSRF: panel.Dependencies.Auth.CSRF(), + CSRF: true, SkipForm: func(r *http.Request) (data struct{}, skip bool) { user, err := panel.Dependencies.Auth.UserOf(r) diff --git a/internal/dis/component/auth/session.go b/internal/dis/component/auth/session.go index 2339bcf..ba5bf94 100644 --- a/internal/dis/component/auth/session.go +++ b/internal/dis/component/auth/session.go @@ -121,8 +121,7 @@ func (auth *Auth) authLogin(ctx context.Context) http.Handler { {Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - - CSRF: auth.CSRF(), + CSRF: true, RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) { if context.Err != nil { diff --git a/internal/dis/component/control/admin/admin.go b/internal/dis/component/control/admin/admin.go index 35c8bb2..0ed0d36 100644 --- a/internal/dis/component/control/admin/admin.go +++ b/internal/dis/component/control/admin/admin.go @@ -49,7 +49,6 @@ func (admin *Admin) HandleRoute(ctx context.Context, route string) (handler http Handler: admin.serveSocket, } handler = admin.Dependencies.Auth.Protect(socket, auth.Admin) - handler = admin.Dependencies.Auth.CSRF()(handler) } // handle everything diff --git a/internal/dis/component/control/admin/users.go b/internal/dis/component/control/admin/users.go index 68ce422..bfd15a8 100644 --- a/internal/dis/component/control/admin/users.go +++ b/internal/dis/component/control/admin/users.go @@ -64,7 +64,7 @@ func (admin *Admin) createUser(ctx context.Context) http.Handler { }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: admin.Dependencies.Auth.CSRF(), + CSRF: true, RenderTemplate: userCreateTemplate, diff --git a/internal/dis/component/control/server.go b/internal/dis/component/control/server.go index bc38038..d03872e 100644 --- a/internal/dis/component/control/server.go +++ b/internal/dis/component/control/server.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/FAU-CDI/wisski-distillery/pkg/cancel" + "github.com/gorilla/csrf" "github.com/rs/zerolog" ) @@ -29,7 +30,15 @@ func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Ha } } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return func(handler http.HandlerFunc) http.Handler { + // setup a csrf protector for everything with POST + var opts []csrf.Option + if !control.Config.HTTPSEnabled() { + opts = append(opts, csrf.Secure(false)) + } + opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode)) + return csrf.Protect(control.Config.CSRFSecret(), opts...)(handler) + }(func(w http.ResponseWriter, r *http.Request) { r = r.WithContext(cancel.ValuesOf(r.Context(), ctx)) mux.ServeHTTP(w, r) }), nil diff --git a/pkg/httpx/form.go b/pkg/httpx/form.go index a90ff4a..c5d058d 100644 --- a/pkg/httpx/form.go +++ b/pkg/httpx/form.go @@ -6,7 +6,6 @@ import ( "net/http" "strings" - "github.com/FAU-CDI/wisski-distillery/pkg/lazy" "github.com/gorilla/csrf" ) @@ -25,10 +24,8 @@ type Form[D any] struct { // FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used. FieldTemplate *template.Template - // CSRF holds an optional reference to a CSRF.Protect call. - // It must be set before any other functions on this Form are called, and may not be changed. - CSRF func(http.Handler) http.Handler - csrf lazy.Lazy[http.Handler] + // CSRF indicates if a CSRF field should be added automatically + CSRF bool // SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely. // If skip is true, RenderSuccess is directly called with the given values map. @@ -100,20 +97,8 @@ func (form *Form[D]) Values(r *http.Request) (v map[string]string, d D, err erro return values, d, nil } -// ServeHTTP implements [http.Handler]. -// If the form contains a csrf reference, then this is invoked also. +// ServeHTTP implements [http.Handler] and serves the form func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) { - handler := form.csrf.Get(func() (handler http.Handler) { - handler = http.HandlerFunc(form.serveHTTP) - if form.CSRF != nil { - handler = form.CSRF(handler) - } - return - }) - handler.ServeHTTP(w, r) -} - -func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) { switch { default: TextInterceptor.Intercept(w, r, ErrMethodNotAllowed) @@ -139,7 +124,7 @@ func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) { // renderForm renders the form into a request func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) { template := form.Template(values, err != nil) - if form.CSRF != nil { + if form.CSRF { template += csrf.TemplateField(r) }