Refactor CSRF protection

This commit is contained in:
Tom Wiesing 2023-01-05 14:07:36 +01:00
parent 59b565ae19
commit eb17dbe33f
No known key found for this signature in database
8 changed files with 20 additions and 45 deletions

View file

@ -6,7 +6,6 @@ import (
"net/http"
"strings"
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
"github.com/gorilla/csrf"
)
@ -25,10 +24,8 @@ type Form[D any] struct {
// FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used.
FieldTemplate *template.Template
// CSRF holds an optional reference to a CSRF.Protect call.
// It must be set before any other functions on this Form are called, and may not be changed.
CSRF func(http.Handler) http.Handler
csrf lazy.Lazy[http.Handler]
// CSRF indicates if a CSRF field should be added automatically
CSRF bool
// SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely.
// If skip is true, RenderSuccess is directly called with the given values map.
@ -100,20 +97,8 @@ func (form *Form[D]) Values(r *http.Request) (v map[string]string, d D, err erro
return values, d, nil
}
// ServeHTTP implements [http.Handler].
// If the form contains a csrf reference, then this is invoked also.
// ServeHTTP implements [http.Handler] and serves the form
func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler := form.csrf.Get(func() (handler http.Handler) {
handler = http.HandlerFunc(form.serveHTTP)
if form.CSRF != nil {
handler = form.CSRF(handler)
}
return
})
handler.ServeHTTP(w, r)
}
func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) {
switch {
default:
TextInterceptor.Intercept(w, r, ErrMethodNotAllowed)
@ -139,7 +124,7 @@ func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) {
// renderForm renders the form into a request
func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) {
template := form.Template(values, err != nil)
if form.CSRF != nil {
if form.CSRF {
template += csrf.TemplateField(r)
}