Refactor CSRF protection
This commit is contained in:
parent
59b565ae19
commit
eb17dbe33f
8 changed files with 20 additions and 45 deletions
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"github.com/FAU-CDI/wisski-distillery/internal/dis/component"
|
"github.com/FAU-CDI/wisski-distillery/internal/dis/component"
|
||||||
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/sql"
|
"github.com/FAU-CDI/wisski-distillery/internal/dis/component/sql"
|
||||||
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
|
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
|
||||||
"github.com/gorilla/csrf"
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
)
|
)
|
||||||
|
|
@ -41,15 +40,3 @@ func (auth *Auth) HandleRoute(ctx context.Context, route string) (http.Handler,
|
||||||
|
|
||||||
return router, nil
|
return router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) CSRF() func(http.Handler) http.Handler {
|
|
||||||
// setup the csrf handler (if needed)
|
|
||||||
// TOOD: This should move to the server handler
|
|
||||||
return auth.csrf.Get(func() func(http.Handler) http.Handler {
|
|
||||||
var opts []csrf.Option
|
|
||||||
if !auth.Config.HTTPSEnabled() {
|
|
||||||
opts = append(opts, csrf.Secure(false))
|
|
||||||
}
|
|
||||||
return csrf.Protect(auth.Config.CSRFSecret(), opts...)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,7 @@ func (panel *UserPanel) routePassword(ctx context.Context) http.Handler {
|
||||||
{Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"},
|
{Name: "new2", Type: httpx.PasswordField, EmptyOnError: true, Label: "New Password (again)"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
CSRF: true,
|
||||||
CSRF: panel.Dependencies.Auth.CSRF(),
|
|
||||||
|
|
||||||
RenderTemplate: passwordTemplate,
|
RenderTemplate: passwordTemplate,
|
||||||
RenderTemplateContext: panel.UserFormContext,
|
RenderTemplateContext: panel.UserFormContext,
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,7 @@ func (panel *UserPanel) routeTOTPEnable(ctx context.Context) http.Handler {
|
||||||
{Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"},
|
{Name: "password", Type: httpx.PasswordField, EmptyOnError: true, Label: "Current Password"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
CSRF: true,
|
||||||
CSRF: panel.Dependencies.Auth.CSRF(),
|
|
||||||
|
|
||||||
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
||||||
user, err := panel.Dependencies.Auth.UserOf(r)
|
user, err := panel.Dependencies.Auth.UserOf(r)
|
||||||
|
|
@ -81,8 +80,7 @@ func (panel *UserPanel) routeTOTPEnroll(ctx context.Context) http.Handler {
|
||||||
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"},
|
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
CSRF: true,
|
||||||
CSRF: panel.Dependencies.Auth.CSRF(),
|
|
||||||
|
|
||||||
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
||||||
user, err := panel.Dependencies.Auth.UserOf(r)
|
user, err := panel.Dependencies.Auth.UserOf(r)
|
||||||
|
|
@ -152,8 +150,7 @@ func (panel *UserPanel) routeTOTPDisable(ctx context.Context) http.Handler {
|
||||||
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"},
|
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Current Passcode"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
CSRF: true,
|
||||||
CSRF: panel.Dependencies.Auth.CSRF(),
|
|
||||||
|
|
||||||
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
SkipForm: func(r *http.Request) (data struct{}, skip bool) {
|
||||||
user, err := panel.Dependencies.Auth.UserOf(r)
|
user, err := panel.Dependencies.Auth.UserOf(r)
|
||||||
|
|
|
||||||
|
|
@ -121,8 +121,7 @@ func (auth *Auth) authLogin(ctx context.Context) http.Handler {
|
||||||
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"},
|
{Name: "otp", Type: httpx.TextField, EmptyOnError: true, Label: "Passcode (optional)"},
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
CSRF: true,
|
||||||
CSRF: auth.CSRF(),
|
|
||||||
|
|
||||||
RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) {
|
RenderForm: func(context httpx.FormContext, w http.ResponseWriter, r *http.Request) {
|
||||||
if context.Err != nil {
|
if context.Err != nil {
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,6 @@ func (admin *Admin) HandleRoute(ctx context.Context, route string) (handler http
|
||||||
Handler: admin.serveSocket,
|
Handler: admin.serveSocket,
|
||||||
}
|
}
|
||||||
handler = admin.Dependencies.Auth.Protect(socket, auth.Admin)
|
handler = admin.Dependencies.Auth.Protect(socket, auth.Admin)
|
||||||
handler = admin.Dependencies.Auth.CSRF()(handler)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle everything
|
// handle everything
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ func (admin *Admin) createUser(ctx context.Context) http.Handler {
|
||||||
},
|
},
|
||||||
FieldTemplate: httpx.PureCSSFieldTemplate,
|
FieldTemplate: httpx.PureCSSFieldTemplate,
|
||||||
|
|
||||||
CSRF: admin.Dependencies.Auth.CSRF(),
|
CSRF: true,
|
||||||
|
|
||||||
RenderTemplate: userCreateTemplate,
|
RenderTemplate: userCreateTemplate,
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/FAU-CDI/wisski-distillery/pkg/cancel"
|
"github.com/FAU-CDI/wisski-distillery/pkg/cancel"
|
||||||
|
"github.com/gorilla/csrf"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -29,7 +30,15 @@ func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Ha
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return func(handler http.HandlerFunc) http.Handler {
|
||||||
|
// setup a csrf protector for everything with POST
|
||||||
|
var opts []csrf.Option
|
||||||
|
if !control.Config.HTTPSEnabled() {
|
||||||
|
opts = append(opts, csrf.Secure(false))
|
||||||
|
}
|
||||||
|
opts = append(opts, csrf.SameSite(csrf.SameSiteStrictMode))
|
||||||
|
return csrf.Protect(control.Config.CSRFSecret(), opts...)(handler)
|
||||||
|
}(func(w http.ResponseWriter, r *http.Request) {
|
||||||
r = r.WithContext(cancel.ValuesOf(r.Context(), ctx))
|
r = r.WithContext(cancel.ValuesOf(r.Context(), ctx))
|
||||||
mux.ServeHTTP(w, r)
|
mux.ServeHTTP(w, r)
|
||||||
}), nil
|
}), nil
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
|
|
||||||
"github.com/gorilla/csrf"
|
"github.com/gorilla/csrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -25,10 +24,8 @@ type Form[D any] struct {
|
||||||
// FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used.
|
// FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used.
|
||||||
FieldTemplate *template.Template
|
FieldTemplate *template.Template
|
||||||
|
|
||||||
// CSRF holds an optional reference to a CSRF.Protect call.
|
// CSRF indicates if a CSRF field should be added automatically
|
||||||
// It must be set before any other functions on this Form are called, and may not be changed.
|
CSRF bool
|
||||||
CSRF func(http.Handler) http.Handler
|
|
||||||
csrf lazy.Lazy[http.Handler]
|
|
||||||
|
|
||||||
// SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely.
|
// SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely.
|
||||||
// If skip is true, RenderSuccess is directly called with the given values map.
|
// If skip is true, RenderSuccess is directly called with the given values map.
|
||||||
|
|
@ -100,20 +97,8 @@ func (form *Form[D]) Values(r *http.Request) (v map[string]string, d D, err erro
|
||||||
return values, d, nil
|
return values, d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP implements [http.Handler].
|
// ServeHTTP implements [http.Handler] and serves the form
|
||||||
// If the form contains a csrf reference, then this is invoked also.
|
|
||||||
func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
handler := form.csrf.Get(func() (handler http.Handler) {
|
|
||||||
handler = http.HandlerFunc(form.serveHTTP)
|
|
||||||
if form.CSRF != nil {
|
|
||||||
handler = form.CSRF(handler)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
})
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch {
|
switch {
|
||||||
default:
|
default:
|
||||||
TextInterceptor.Intercept(w, r, ErrMethodNotAllowed)
|
TextInterceptor.Intercept(w, r, ErrMethodNotAllowed)
|
||||||
|
|
@ -139,7 +124,7 @@ func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// renderForm renders the form into a request
|
// renderForm renders the form into a request
|
||||||
func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) {
|
func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) {
|
||||||
template := form.Template(values, err != nil)
|
template := form.Template(values, err != nil)
|
||||||
if form.CSRF != nil {
|
if form.CSRF {
|
||||||
template += csrf.TemplateField(r)
|
template += csrf.TemplateField(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue