Refactor CSRF protection

This commit is contained in:
Tom Wiesing 2023-01-05 14:07:36 +01:00
parent 59b565ae19
commit eb17dbe33f
No known key found for this signature in database
8 changed files with 20 additions and 45 deletions

View file

@ -7,7 +7,6 @@ import (
"github.com/FAU-CDI/wisski-distillery/internal/dis/component"
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/sql"
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
"github.com/gorilla/csrf"
"github.com/gorilla/sessions"
"github.com/julienschmidt/httprouter"
)
@ -41,15 +40,3 @@ func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler,
return router, nil
}
func (auth *Auth) CSRF() func(http.Handler) http.Handler {
// setup the csrf handler (if needed)
// TOOD: This should move to the server handler
return auth.csrf.Get(func() func(http.Handler) http.Handler {
var opts []csrf.Option
if !auth.Config.HTTPSEnabled() {
opts = append(opts, csrf.Secure(false))
}
return csrf.Protect(auth.Config.CSRFSecret(), opts...)
})
}

View file

@ -33,8 +33,7 @@ func (panel *UserPanel) routePassword(ctx context.Context) http.Handler {
{Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"},
},
FieldTemplate: httpx.PureCSSFieldTemplate,
CSRF: panel.Dependencies.Auth.CSRF(),
CSRF: true,
RenderTemplate: passwordTemplate,
RenderTemplateContext: panel.UserFormContext,

View file

@ -22,8 +22,7 @@ func (panel *UserPanel) routeTOTPEnable(ctx context.Context) http.Handler {
{Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"},
},
FieldTemplate: httpx.PureCSSFieldTemplate,
CSRF: panel.Dependencies.Auth.CSRF(),
CSRF: true,
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
user, err := panel.Dependencies.Auth.UserOf(r)
@ -81,8 +80,7 @@ func (panel *UserPanel) routeTOTPEnroll(ctx context.Context) http.Handler {
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"},
},
FieldTemplate: httpx.PureCSSFieldTemplate,
CSRF: panel.Dependencies.Auth.CSRF(),
CSRF: true,
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
user, err := panel.Dependencies.Auth.UserOf(r)
@ -152,8 +150,7 @@ func (panel *UserPanel) routeTOTPDisable(ctx context.Context) http.Handler {
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"},
},
FieldTemplate: httpx.PureCSSFieldTemplate,
CSRF: panel.Dependencies.Auth.CSRF(),
CSRF: true,
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
user, err := panel.Dependencies.Auth.UserOf(r)

View file

@ -121,8 +121,7 @@ func (auth *Auth) authLogin(ctx context.Context) http.Handler {
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"},
},
FieldTemplate: httpx.PureCSSFieldTemplate,
CSRF: auth.CSRF(),
CSRF: true,
RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) {
if context.Err != nil {

View file

@ -49,7 +49,6 @@ func (admin *Admin) HandleRoute(ctx context.Context, route string) (handler http
Handler: admin.serveSocket,
}
handler = admin.Dependencies.Auth.Protect(socket, auth.Admin)
handler = admin.Dependencies.Auth.CSRF()(handler)
}
// handle everything

View file

@ -64,7 +64,7 @@ func (admin *Admin) createUser(ctx context.Context) http.Handler {
},
FieldTemplate: httpx.PureCSSFieldTemplate,
CSRF: admin.Dependencies.Auth.CSRF(),
CSRF: true,
RenderTemplate: userCreateTemplate,

View file

@ -6,6 +6,7 @@ import (
"net/http"
"github.com/FAU-CDI/wisski-distillery/pkg/cancel"
"github.com/gorilla/csrf"
"github.com/rs/zerolog"
)
@ -29,7 +30,15 @@ func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Ha
}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return func(handler http.HandlerFunc) http.Handler {
// setup a csrf protector for everything with POST
var opts []csrf.Option
if !control.Config.HTTPSEnabled() {
opts = append(opts, csrf.Secure(false))
}
opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode))
return csrf.Protect(control.Config.CSRFSecret(), opts...)(handler)
}(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(cancel.ValuesOf(r.Context(), ctx))
mux.ServeHTTP(w, r)
}), nil