control: Generalize cookie and csrf handling
This commit is contained in:
parent
eb17dbe33f
commit
34bdb3cf24
15 changed files with 122 additions and 44 deletions
|
|
@ -19,14 +19,18 @@ type Auth struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
store lazy.Lazy[sessions.Store]
|
store lazy.Lazy[sessions.Store]
|
||||||
csrf lazy.Lazy[func(http.Handler) http.Handler]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ component.Routeable = (*Auth)(nil)
|
_ component.Routeable = (*Auth)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (auth *Auth) Routes() []string { return []string{"/auth/"} }
|
func (auth *Auth) Routes() component.Routes {
|
||||||
|
return component.Routes{
|
||||||
|
Paths: []string{"/auth/"},
|
||||||
|
CSRF: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
||||||
router := httprouter.New()
|
router := httprouter.New()
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,13 @@ var (
|
||||||
_ component.Routeable = (*UserPanel)(nil)
|
_ component.Routeable = (*UserPanel)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (panel *UserPanel) Routes() []string { return []string{"/user/"} }
|
func (panel *UserPanel) Routes() component.Routes {
|
||||||
|
return component.Routes{
|
||||||
|
Paths: []string{"/user/"},
|
||||||
|
CSRF: true,
|
||||||
|
Decorator: panel.Dependencies.Auth.Require(nil),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (panel *UserPanel) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
func (panel *UserPanel) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
||||||
router := httprouter.New()
|
router := httprouter.New()
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,6 @@ func (panel *UserPanel) routePassword(ctx context.Context) http.Handler {
|
||||||
{Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"},
|
{Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
CSRF: true,
|
|
||||||
|
|
||||||
RenderTemplate: passwordTemplate,
|
RenderTemplate: passwordTemplate,
|
||||||
RenderTemplateContext: panel.UserFormContext,
|
RenderTemplateContext: panel.UserFormContext,
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ func (panel *UserPanel) routeTOTPEnable(ctx context.Context) http.Handler {
|
||||||
{Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"},
|
{Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
CSRF: true,
|
|
||||||
|
|
||||||
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
||||||
user, err := panel.Dependencies.Auth.UserOf(r)
|
user, err := panel.Dependencies.Auth.UserOf(r)
|
||||||
|
|
@ -80,7 +79,6 @@ func (panel *UserPanel) routeTOTPEnroll(ctx context.Context) http.Handler {
|
||||||
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"},
|
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
CSRF: true,
|
|
||||||
|
|
||||||
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
||||||
user, err := panel.Dependencies.Auth.UserOf(r)
|
user, err := panel.Dependencies.Auth.UserOf(r)
|
||||||
|
|
@ -150,7 +148,6 @@ func (panel *UserPanel) routeTOTPDisable(ctx context.Context) http.Handler {
|
||||||
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"},
|
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
CSRF: true,
|
|
||||||
|
|
||||||
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
||||||
user, err := panel.Dependencies.Auth.UserOf(r)
|
user, err := panel.Dependencies.Auth.UserOf(r)
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,14 @@ func (auth *Auth) Protect(handler http.Handler, perm Permission) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Require returns a slice containing one decorator that acts like Protect(perm) on every request.
|
||||||
|
// It returns
|
||||||
|
func (auth *Auth) Require(perm Permission) func(http.Handler) http.Handler {
|
||||||
|
return func(h http.Handler) http.Handler {
|
||||||
|
return auth.Protect(h, perm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Admin represents a permission that checks if a user is an administrator and has totp enabled.
|
// Admin represents a permission that checks if a user is an administrator and has totp enabled.
|
||||||
var Admin Permission = func(user *AuthUser, r *http.Request) (ok Grant, err error) {
|
var Admin Permission = func(user *AuthUser, r *http.Request) (ok Grant, err error) {
|
||||||
return Bool2Grant(user != nil && user.IsAdmin() && user.IsTOTPEnabled(), "user needs to have admin permissions and passcode enabled"), nil
|
return Bool2Grant(user != nil && user.IsAdmin() && user.IsTOTPEnabled(), "user needs to have admin permissions and passcode enabled"), nil
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/control"
|
||||||
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/control/static"
|
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/control/static"
|
||||||
"github.com/FAU-CDI/wisski-distillery/pkg/httpx"
|
"github.com/FAU-CDI/wisski-distillery/pkg/httpx"
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
|
|
@ -30,7 +31,7 @@ func (auth *Auth) UserOf(r *http.Request) (user *AuthUser, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to read the name from the session
|
// try to read the name from the session
|
||||||
name, ok := sess.Values[sessionUserKey]
|
name, ok := sess.Values[control.SessionUserKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
@ -57,18 +58,14 @@ func (auth *Auth) UserOf(r *http.Request) (user *AuthUser, err error) {
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const sessionCookieName = "distillery-session"
|
|
||||||
|
|
||||||
// session returns the session that belongs to a given request.
|
// session returns the session that belongs to a given request.
|
||||||
// If the session is not set, creates a new session.
|
// If the session is not set, creates a new session.
|
||||||
func (auth *Auth) session(r *http.Request) (*sessions.Session, error) {
|
func (auth *Auth) session(r *http.Request) (*sessions.Session, error) {
|
||||||
return auth.store.Get(func() sessions.Store {
|
return auth.store.Get(func() sessions.Store {
|
||||||
return sessions.NewCookieStore([]byte(auth.Config.SessionSecret))
|
return sessions.NewCookieStore([]byte(auth.Config.SessionSecret))
|
||||||
}).Get(r, sessionCookieName)
|
}).Get(r, control.SessionCookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
const sessionUserKey = "user"
|
|
||||||
|
|
||||||
type contextUserKey struct{}
|
type contextUserKey struct{}
|
||||||
|
|
||||||
var ctxUserKey = contextUserKey{}
|
var ctxUserKey = contextUserKey{}
|
||||||
|
|
@ -84,7 +81,7 @@ func (auth *Auth) Login(w http.ResponseWriter, r *http.Request, user *AuthUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sess.Values[sessionUserKey] = user.User.User
|
sess.Values[control.SessionUserKey] = user.User.User
|
||||||
return sess.Save(r, w)
|
return sess.Save(r, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -121,7 +118,6 @@ func (auth *Auth) authLogin(ctx context.Context) http.Handler {
|
||||||
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"},
|
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
CSRF: true,
|
|
||||||
|
|
||||||
RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) {
|
RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) {
|
||||||
if context.Err != nil {
|
if context.Err != nil {
|
||||||
|
|
|
||||||
|
|
@ -36,19 +36,24 @@ var (
|
||||||
_ component.Routeable = (*Admin)(nil)
|
_ component.Routeable = (*Admin)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Admin) Routes() []string { return []string{"/admin/"} }
|
func (admin *Admin) Routes() component.Routes {
|
||||||
|
return component.Routes{
|
||||||
|
Paths: []string{"/admin/"},
|
||||||
|
CSRF: true,
|
||||||
|
Decorator: admin.Dependencies.Auth.Require(auth.Admin),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (admin *Admin) HandleRoute(ctx context.Context, route string) (handler http.Handler, err error) {
|
func (admin *Admin) HandleRoute(ctx context.Context, route string) (handler http.Handler, err error) {
|
||||||
|
|
||||||
router := httprouter.New()
|
router := httprouter.New()
|
||||||
|
|
||||||
{
|
{
|
||||||
socket := &httpx.WebSocket{
|
handler = &httpx.WebSocket{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Fallback: router,
|
Fallback: router,
|
||||||
Handler: admin.serveSocket,
|
Handler: admin.serveSocket,
|
||||||
}
|
}
|
||||||
handler = admin.Dependencies.Auth.Protect(socket, auth.Admin)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle everything
|
// handle everything
|
||||||
|
|
|
||||||
|
|
@ -64,8 +64,6 @@ func (admin *Admin) createUser(ctx context.Context) http.Handler {
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
|
||||||
CSRF: true,
|
|
||||||
|
|
||||||
RenderTemplate: userCreateTemplate,
|
RenderTemplate: userCreateTemplate,
|
||||||
|
|
||||||
Validate: func(r *http.Request, values map[string]string) (cu createUserResult, err error) {
|
Validate: func(r *http.Request, values map[string]string) (cu createUserResult, err error) {
|
||||||
|
|
|
||||||
14
internal/dis/component/control/cookies.go
Normal file
14
internal/dis/component/control/cookies.go
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
package control
|
||||||
|
|
||||||
|
// CSRFCookie, CSRFCookieField, SessionCookie and SessionUserKey
|
||||||
|
// hold the names of the cookies and fields used for specific cookies.
|
||||||
|
//
|
||||||
|
// These are intentionally kept short to conserve bandwidth.
|
||||||
|
const (
|
||||||
|
CSRFCookie = "F" // CSRF cookie sent on a lot of requests
|
||||||
|
CSRFCookieField = "@" // form field name __should not be used by anything else__
|
||||||
|
// to pay respect
|
||||||
|
|
||||||
|
SessionCookie = "x" // name of the cookie to use ; to doubt
|
||||||
|
SessionUserKey = "@" // key within the session data to hold the username
|
||||||
|
)
|
||||||
|
|
@ -25,7 +25,12 @@ var (
|
||||||
_ component.Routeable = (*Home)(nil)
|
_ component.Routeable = (*Home)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Home) Routes() []string { return []string{"/"} }
|
func (*Home) Routes() component.Routes {
|
||||||
|
return component.Routes{
|
||||||
|
Paths: []string{"/"},
|
||||||
|
CSRF: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (home *Home) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
func (home *Home) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
||||||
return home, nil
|
return home, nil
|
||||||
|
|
|
||||||
|
|
@ -18,28 +18,38 @@ func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Ha
|
||||||
// create a new mux
|
// create a new mux
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// add all the servable routes!
|
// create a csrf protector
|
||||||
|
csrfProtector := control.CSRF()
|
||||||
|
|
||||||
|
// iterate over all the handler
|
||||||
for _, s := range control.Dependencies.Routeables {
|
for _, s := range control.Dependencies.Routeables {
|
||||||
for _, route := range s.Routes() {
|
routes := s.Routes()
|
||||||
zerolog.Ctx(ctx).Info().Str("component", s.Name()).Str("route", route).Msg("mounting route")
|
zerolog.Ctx(ctx).Info().Str("component", s.Name()).Strs("paths", routes.Paths).Bool("csrf", routes.CSRF).Bool("decorator", routes.Decorator != nil).Msg("mounting route")
|
||||||
handler, err := s.HandleRoute(ctx, route)
|
|
||||||
|
for _, path := range routes.Paths {
|
||||||
|
handler, err := s.HandleRoute(ctx, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mux.Handle(route, handler)
|
mux.Handle(path, routes.Decorate(handler, csrfProtector))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(handler http.HandlerFunc) http.Handler {
|
// apply the given context function
|
||||||
// setup a csrf protector for everything with POST
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r = r.WithContext(cancel.ValuesOf(r.Context(), ctx))
|
||||||
|
mux.ServeHTTP(w, r)
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRF returns a CSRF handler for the given function
|
||||||
|
func (control *Control) CSRF() func(http.Handler) http.Handler {
|
||||||
var opts []csrf.Option
|
var opts []csrf.Option
|
||||||
if !control.Config.HTTPSEnabled() {
|
if !control.Config.HTTPSEnabled() {
|
||||||
opts = append(opts, csrf.Secure(false))
|
opts = append(opts, csrf.Secure(false))
|
||||||
}
|
}
|
||||||
opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode))
|
opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode))
|
||||||
return csrf.Protect(control.Config.CSRFSecret(), opts...)(handler)
|
opts = append(opts, csrf.CookieName(CSRFCookie))
|
||||||
}(func(w http.ResponseWriter, r *http.Request) {
|
opts = append(opts, csrf.FieldName(CSRFCookieField))
|
||||||
r = r.WithContext(cancel.ValuesOf(r.Context(), ctx))
|
return csrf.Protect(control.Config.CSRFSecret(), opts...)
|
||||||
mux.ServeHTTP(w, r)
|
|
||||||
}), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,12 @@ var (
|
||||||
_ component.Routeable = (*Static)(nil)
|
_ component.Routeable = (*Static)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Static) Routes() []string { return []string{"/static/"} }
|
func (*Static) Routes() component.Routes {
|
||||||
|
return component.Routes{
|
||||||
|
Paths: []string{"/static/"},
|
||||||
|
CSRF: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//go:embed dist
|
//go:embed dist
|
||||||
var staticFS embed.FS
|
var staticFS embed.FS
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,12 @@ var (
|
||||||
_ component.Cronable = (*Resolver)(nil)
|
_ component.Cronable = (*Resolver)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (resolver *Resolver) Routes() []string { return []string{"/go/", "/wisski/get/"} }
|
func (resolver *Resolver) Routes() component.Routes {
|
||||||
|
return component.Routes{
|
||||||
|
Paths: []string{"/go/", "/wisski/get/"},
|
||||||
|
CSRF: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (resolver *Resolver) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
func (resolver *Resolver) HandleRoute(ctx context.Context, route string) (http.Handler, error) {
|
||||||
logger := zerolog.Ctx(ctx)
|
logger := zerolog.Ctx(ctx)
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,35 @@ import (
|
||||||
type Routeable interface {
|
type Routeable interface {
|
||||||
Component
|
Component
|
||||||
|
|
||||||
// Routes returns the routes served by this servable
|
// Routes returns information about the routes to be handled by this Routeable
|
||||||
Routes() []string
|
Routes() Routes
|
||||||
|
|
||||||
// HandleRoute returns the handler for the requested route
|
// HandleRoute returns the handler for the requested path
|
||||||
HandleRoute(ctx context.Context, route string) (http.Handler, error)
|
HandleRoute(ctx context.Context, path string) (http.Handler, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Routes represents information about a single Routeable
|
||||||
|
type Routes struct {
|
||||||
|
// Paths are the paths handled by this routeable.
|
||||||
|
// Each path is passed to HandleRoute() individually.
|
||||||
|
Paths []string
|
||||||
|
|
||||||
|
// CSRF indicates if this route should be protected by CSRF.
|
||||||
|
// CSRF protection is applied prior to any custom decorator being called.
|
||||||
|
CSRF bool
|
||||||
|
|
||||||
|
// Decorators is a function applied to the handler returned by HandleRoute.
|
||||||
|
// When nil, it is not applied.
|
||||||
|
Decorator func(http.Handler) http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decorate decorates the provided handler with the options specified in this handler.
|
||||||
|
func (routes Routes) Decorate(handler http.Handler, csrf func(http.Handler) http.Handler) http.Handler {
|
||||||
|
if routes.CSRF && csrf != nil {
|
||||||
|
handler = csrf(handler)
|
||||||
|
}
|
||||||
|
if routes.Decorator != nil {
|
||||||
|
handler = routes.Decorator(handler)
|
||||||
|
}
|
||||||
|
return handler
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,8 @@ type Form[D any] struct {
|
||||||
// FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used.
|
// FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used.
|
||||||
FieldTemplate *template.Template
|
FieldTemplate *template.Template
|
||||||
|
|
||||||
// CSRF indicates if a CSRF field should be added automatically
|
// SkipCSRF if CSRF should be explicitly omitted
|
||||||
CSRF bool
|
SkipCSRF bool
|
||||||
|
|
||||||
// SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely.
|
// SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely.
|
||||||
// If skip is true, RenderSuccess is directly called with the given values map.
|
// If skip is true, RenderSuccess is directly called with the given values map.
|
||||||
|
|
@ -124,7 +124,7 @@ func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// renderForm renders the form into a request
|
// renderForm renders the form into a request
|
||||||
func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) {
|
func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) {
|
||||||
template := form.Template(values, err != nil)
|
template := form.Template(values, err != nil)
|
||||||
if form.CSRF {
|
if !form.SkipCSRF {
|
||||||
template += csrf.TemplateField(r)
|
template += csrf.TemplateField(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue