diff --git a/internal/dis/component/auth/auth.go b/internal/dis/component/auth/auth.go index 236bc64..776a076 100644 --- a/internal/dis/component/auth/auth.go +++ b/internal/dis/component/auth/auth.go @@ -19,14 +19,18 @@ type Auth struct { } store lazy.Lazy[sessions.Store] - csrf lazy.Lazy[func(http.Handler) http.Handler] } var ( _ component.Routeable = (*Auth)(nil) ) -func (auth *Auth) Routes() []string { return []string{"/auth/"} } +func (auth *Auth) Routes() component.Routes { + return component.Routes{ + Paths: []string{"/auth/"}, + CSRF: true, + } +} func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, error) { router := httprouter.New() diff --git a/internal/dis/component/auth/panel/panel.go b/internal/dis/component/auth/panel/panel.go index 06191ba..6405e27 100644 --- a/internal/dis/component/auth/panel/panel.go +++ b/internal/dis/component/auth/panel/panel.go @@ -22,7 +22,13 @@ var ( _ component.Routeable = (*UserPanel)(nil) ) -func (panel *UserPanel) Routes() []string { return []string{"/user/"} } +func (panel *UserPanel) Routes() component.Routes { + return component.Routes{ + Paths: []string{"/user/"}, + CSRF: true, + Decorator: panel.Dependencies.Auth.Require(nil), + } +} func (panel *UserPanel) HandleRoute(ctx context.Context, route string) (http.Handler, error) { router := httprouter.New() diff --git a/internal/dis/component/auth/panel/password.go b/internal/dis/component/auth/panel/password.go index 37fc37a..8671239 100644 --- a/internal/dis/component/auth/panel/password.go +++ b/internal/dis/component/auth/panel/password.go @@ -33,7 +33,6 @@ func (panel *UserPanel) routePassword(ctx context.Context) http.Handler { {Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: true, RenderTemplate: passwordTemplate, RenderTemplateContext: panel.UserFormContext, diff --git a/internal/dis/component/auth/panel/totp.go b/internal/dis/component/auth/panel/totp.go index e114bab..5e1ad19 100644 --- a/internal/dis/component/auth/panel/totp.go +++ b/internal/dis/component/auth/panel/totp.go @@ -22,7 +22,6 @@ func (panel *UserPanel) routeTOTPEnable(ctx context.Context) http.Handler { {Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: true, SkipForm: func(r *http.Request) (data struct{}, skip bool) { user, err := panel.Dependencies.Auth.UserOf(r) @@ -80,7 +79,6 @@ func (panel *UserPanel) routeTOTPEnroll(ctx context.Context) http.Handler { {Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: true, SkipForm: func(r *http.Request) (data struct{}, skip bool) { user, err := panel.Dependencies.Auth.UserOf(r) @@ -150,7 +148,6 @@ func (panel *UserPanel) routeTOTPDisable(ctx context.Context) http.Handler { {Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: true, SkipForm: func(r *http.Request) (data struct{}, skip bool) { user, err := panel.Dependencies.Auth.UserOf(r) diff --git a/internal/dis/component/auth/protect.go b/internal/dis/component/auth/protect.go index ed399af..85387c1 100644 --- a/internal/dis/component/auth/protect.go +++ b/internal/dis/component/auth/protect.go @@ -72,6 +72,14 @@ func (auth *Auth) Protect(handler http.Handler, perm Permission) http.Handler { }) } +// Require returns a slice containing one decorator that acts like Protect(perm) on every request. +// It returns +func (auth *Auth) Require(perm Permission) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return auth.Protect(h, perm) + } +} + // Admin represents a permission that checks if a user is an administrator and has totp enabled. var Admin Permission = func(user *AuthUser, r *http.Request) (ok Grant, err error) { return Bool2Grant(user != nil && user.IsAdmin() && user.IsTOTPEnabled(), "user needs to have admin permissions and passcode enabled"), nil diff --git a/internal/dis/component/auth/session.go b/internal/dis/component/auth/session.go index ba5bf94..4ba10c3 100644 --- a/internal/dis/component/auth/session.go +++ b/internal/dis/component/auth/session.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" + "github.com/FAU-CDI/wisski-distillery/internal/dis/component/control" "github.com/FAU-CDI/wisski-distillery/internal/dis/component/control/static" "github.com/FAU-CDI/wisski-distillery/pkg/httpx" "github.com/gorilla/sessions" @@ -30,7 +31,7 @@ func (auth *Auth) UserOf(r *http.Request) (user *AuthUser, err error) { } // try to read the name from the session - name, ok := sess.Values[sessionUserKey] + name, ok := sess.Values[control.SessionUserKey] if !ok { return nil, nil } @@ -57,18 +58,14 @@ func (auth *Auth) UserOf(r *http.Request) (user *AuthUser, err error) { return user, nil } -const sessionCookieName = "distillery-session" - // session returns the session that belongs to a given request. // If the session is not set, creates a new session. func (auth *Auth) session(r *http.Request) (*sessions.Session, error) { return auth.store.Get(func() sessions.Store { return sessions.NewCookieStore([]byte(auth.Config.SessionSecret)) - }).Get(r, sessionCookieName) + }).Get(r, control.SessionCookie) } -const sessionUserKey = "user" - type contextUserKey struct{} var ctxUserKey = contextUserKey{} @@ -84,7 +81,7 @@ func (auth *Auth) Login(w http.ResponseWriter, r *http.Request, user *AuthUser) if err != nil { return err } - sess.Values[sessionUserKey] = user.User.User + sess.Values[control.SessionUserKey] = user.User.User return sess.Save(r, w) } @@ -121,7 +118,6 @@ func (auth *Auth) authLogin(ctx context.Context) http.Handler { {Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"}, }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: true, RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) { if context.Err != nil { diff --git a/internal/dis/component/control/admin/admin.go b/internal/dis/component/control/admin/admin.go index 0ed0d36..7e90695 100644 --- a/internal/dis/component/control/admin/admin.go +++ b/internal/dis/component/control/admin/admin.go @@ -36,19 +36,24 @@ var ( _ component.Routeable = (*Admin)(nil) ) -func (*Admin) Routes() []string { return []string{"/admin/"} } +func (admin *Admin) Routes() component.Routes { + return component.Routes{ + Paths: []string{"/admin/"}, + CSRF: true, + Decorator: admin.Dependencies.Auth.Require(auth.Admin), + } +} func (admin *Admin) HandleRoute(ctx context.Context, route string) (handler http.Handler, err error) { router := httprouter.New() { - socket := &httpx.WebSocket{ + handler = &httpx.WebSocket{ Context: ctx, Fallback: router, Handler: admin.serveSocket, } - handler = admin.Dependencies.Auth.Protect(socket, auth.Admin) } // handle everything diff --git a/internal/dis/component/control/admin/users.go b/internal/dis/component/control/admin/users.go index bfd15a8..a493355 100644 --- a/internal/dis/component/control/admin/users.go +++ b/internal/dis/component/control/admin/users.go @@ -64,8 +64,6 @@ func (admin *Admin) createUser(ctx context.Context) http.Handler { }, FieldTemplate: httpx.PureCSSFieldTemplate, - CSRF: true, - RenderTemplate: userCreateTemplate, Validate: func(r *http.Request, values map[string]string) (cu createUserResult, err error) { diff --git a/internal/dis/component/control/cookies.go b/internal/dis/component/control/cookies.go new file mode 100644 index 0000000..70d68a1 --- /dev/null +++ b/internal/dis/component/control/cookies.go @@ -0,0 +1,14 @@ +package control + +// CSRFCookie, CSRFCookieField, SessionCookie and SessionUserKey +// hold the names of the cookies and fields used for specific cookies. +// +// These are intentionally kept short to conserve bandwidth. +const ( + CSRFCookie = "F" // CSRF cookie sent on a lot of requests + CSRFCookieField = "@" // form field name __should not be used by anything else__ + // to pay respect + + SessionCookie = "x" // name of the cookie to use ; to doubt + SessionUserKey = "@" // key within the session data to hold the username +) diff --git a/internal/dis/component/control/home/home.go b/internal/dis/component/control/home/home.go index 81be12f..007b494 100644 --- a/internal/dis/component/control/home/home.go +++ b/internal/dis/component/control/home/home.go @@ -25,7 +25,12 @@ var ( _ component.Routeable = (*Home)(nil) ) -func (*Home) Routes() []string { return []string{"/"} } +func (*Home) Routes() component.Routes { + return component.Routes{ + Paths: []string{"/"}, + CSRF: false, + } +} func (home *Home) HandleRoute(ctx context.Context, route string) (http.Handler, error) { return home, nil diff --git a/internal/dis/component/control/server.go b/internal/dis/component/control/server.go index d03872e..05eeb56 100644 --- a/internal/dis/component/control/server.go +++ b/internal/dis/component/control/server.go @@ -18,28 +18,38 @@ func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Ha // create a new mux mux := http.NewServeMux() - // add all the servable routes! + // create a csrf protector + csrfProtector := control.CSRF() + + // iterate over all the handler for _, s := range control.Dependencies.Routeables { - for _, route := range s.Routes() { - zerolog.Ctx(ctx).Info().Str("component", s.Name()).Str("route", route).Msg("mounting route") - handler, err := s.HandleRoute(ctx, route) + routes := s.Routes() + zerolog.Ctx(ctx).Info().Str("component", s.Name()).Strs("paths", routes.Paths).Bool("csrf", routes.CSRF).Bool("decorator", routes.Decorator != nil).Msg("mounting route") + + for _, path := range routes.Paths { + handler, err := s.HandleRoute(ctx, path) if err != nil { return nil, err } - mux.Handle(route, handler) + mux.Handle(path, routes.Decorate(handler, csrfProtector)) } } - return func(handler http.HandlerFunc) http.Handler { - // setup a csrf protector for everything with POST - var opts []csrf.Option - if !control.Config.HTTPSEnabled() { - opts = append(opts, csrf.Secure(false)) - } - opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode)) - return csrf.Protect(control.Config.CSRFSecret(), opts...)(handler) - }(func(w http.ResponseWriter, r *http.Request) { + // apply the given context function + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r = r.WithContext(cancel.ValuesOf(r.Context(), ctx)) mux.ServeHTTP(w, r) }), nil } + +// CSRF returns a CSRF handler for the given function +func (control *Control) CSRF() func(http.Handler) http.Handler { + var opts []csrf.Option + if !control.Config.HTTPSEnabled() { + opts = append(opts, csrf.Secure(false)) + } + opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode)) + opts = append(opts, csrf.CookieName(CSRFCookie)) + opts = append(opts, csrf.FieldName(CSRFCookieField)) + return csrf.Protect(control.Config.CSRFSecret(), opts...) +} diff --git a/internal/dis/component/control/static/static.go b/internal/dis/component/control/static/static.go index ab05c5a..d6479ed 100644 --- a/internal/dis/component/control/static/static.go +++ b/internal/dis/component/control/static/static.go @@ -18,7 +18,12 @@ var ( _ component.Routeable = (*Static)(nil) ) -func (*Static) Routes() []string { return []string{"/static/"} } +func (*Static) Routes() component.Routes { + return component.Routes{ + Paths: []string{"/static/"}, + CSRF: false, + } +} //go:embed dist var staticFS embed.FS diff --git a/internal/dis/component/resolver/resolver.go b/internal/dis/component/resolver/resolver.go index 259b3bb..879e423 100644 --- a/internal/dis/component/resolver/resolver.go +++ b/internal/dis/component/resolver/resolver.go @@ -32,7 +32,12 @@ var ( _ component.Cronable = (*Resolver)(nil) ) -func (resolver *Resolver) Routes() []string { return []string{"/go/", "/wisski/get/"} } +func (resolver *Resolver) Routes() component.Routes { + return component.Routes{ + Paths: []string{"/go/", "/wisski/get/"}, + CSRF: false, + } +} func (resolver *Resolver) HandleRoute(ctx context.Context, route string) (http.Handler, error) { logger := zerolog.Ctx(ctx) diff --git a/internal/dis/component/server.go b/internal/dis/component/server.go index 0f09e80..7726afb 100644 --- a/internal/dis/component/server.go +++ b/internal/dis/component/server.go @@ -9,9 +9,35 @@ import ( type Routeable interface { Component - // Routes returns the routes served by this servable - Routes() []string + // Routes returns information about the routes to be handled by this Routeable + Routes() Routes - // HandleRoute returns the handler for the requested route - HandleRoute(ctx context.Context, route string) (http.Handler, error) + // HandleRoute returns the handler for the requested path + HandleRoute(ctx context.Context, path string) (http.Handler, error) +} + +// Routes represents information about a single Routeable +type Routes struct { + // Paths are the paths handled by this routeable. + // Each path is passed to HandleRoute() individually. + Paths []string + + // CSRF indicates if this route should be protected by CSRF. + // CSRF protection is applied prior to any custom decorator being called. + CSRF bool + + // Decorators is a function applied to the handler returned by HandleRoute. + // When nil, it is not applied. + Decorator func(http.Handler) http.Handler +} + +// Decorate decorates the provided handler with the options specified in this handler. +func (routes Routes) Decorate(handler http.Handler, csrf func(http.Handler) http.Handler) http.Handler { + if routes.CSRF && csrf != nil { + handler = csrf(handler) + } + if routes.Decorator != nil { + handler = routes.Decorator(handler) + } + return handler } diff --git a/pkg/httpx/form.go b/pkg/httpx/form.go index c5d058d..56a6993 100644 --- a/pkg/httpx/form.go +++ b/pkg/httpx/form.go @@ -24,8 +24,8 @@ type Form[D any] struct { // FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used. FieldTemplate *template.Template - // CSRF indicates if a CSRF field should be added automatically - CSRF bool + // SkipCSRF if CSRF should be explicitly omitted + SkipCSRF bool // SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely. // If skip is true, RenderSuccess is directly called with the given values map. @@ -124,7 +124,7 @@ func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) { // renderForm renders the form into a request func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) { template := form.Template(values, err != nil) - if form.CSRF { + if !form.SkipCSRF { template += csrf.TemplateField(r) }