Refactor CSRF protection
This commit is contained in:
parent
59b565ae19
commit
eb17dbe33f
8 changed files with 20 additions and 45 deletions
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
|
||||
"github.com/gorilla/csrf"
|
||||
)
|
||||
|
||||
|
|
@ -25,10 +24,8 @@ type Form[D any] struct {
|
|||
// FieldTemplate may be nil; in which case [DefaultFieldTemplate] is used.
|
||||
FieldTemplate *template.Template
|
||||
|
||||
// CSRF holds an optional reference to a CSRF.Protect call.
|
||||
// It must be set before any other functions on this Form are called, and may not be changed.
|
||||
CSRF func(http.Handler) http.Handler
|
||||
csrf lazy.Lazy[http.Handler]
|
||||
// CSRF indicates if a CSRF field should be added automatically
|
||||
CSRF bool
|
||||
|
||||
// SkipForm, if non-nil, is called on every get request to determine if form parsing should be skipped entirely.
|
||||
// If skip is true, RenderSuccess is directly called with the given values map.
|
||||
|
|
@ -100,20 +97,8 @@ func (form *Form[D]) Values(r *http.Request) (v map[string]string, d D, err erro
|
|||
return values, d, nil
|
||||
}
|
||||
|
||||
// ServeHTTP implements [http.Handler].
|
||||
// If the form contains a csrf reference, then this is invoked also.
|
||||
// ServeHTTP implements [http.Handler] and serves the form
|
||||
func (form *Form[D]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
handler := form.csrf.Get(func() (handler http.Handler) {
|
||||
handler = http.HandlerFunc(form.serveHTTP)
|
||||
if form.CSRF != nil {
|
||||
handler = form.CSRF(handler)
|
||||
}
|
||||
return
|
||||
})
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
default:
|
||||
TextInterceptor.Intercept(w, r, ErrMethodNotAllowed)
|
||||
|
|
@ -139,7 +124,7 @@ func (form *Form[D]) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
// renderForm renders the form into a request
|
||||
func (form *Form[D]) renderForm(err error, values map[string]string, w http.ResponseWriter, r *http.Request) {
|
||||
template := form.Template(values, err != nil)
|
||||
if form.CSRF != nil {
|
||||
if form.CSRF {
|
||||
template += csrf.TemplateField(r)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue