diff --git a/internal/dis/component/auth/auth.go b/internal/dis/component/auth/auth.go index 5325bb9..1ef2388 100644 --- a/internal/dis/component/auth/auth.go +++ b/internal/dis/component/auth/auth.go @@ -30,8 +30,8 @@ var ( func (auth *Auth) Routes() component.Routes { return component.Routes{ - Paths: []string{"/auth/"}, - CSRF: true, + Prefix: "/auth/", + CSRF: true, } } diff --git a/internal/dis/component/auth/next/next.go b/internal/dis/component/auth/next/next.go index 515a877..9b00040 100644 --- a/internal/dis/component/auth/next/next.go +++ b/internal/dis/component/auth/next/next.go @@ -29,7 +29,7 @@ var ( func (next *Next) Routes() component.Routes { return component.Routes{ - Paths: []string{"/next/"}, + Prefix: "/next/", Decorator: next.Dependencies.Auth.Require(auth.User), } } diff --git a/internal/dis/component/auth/panel/panel.go b/internal/dis/component/auth/panel/panel.go index 5411795..e98cd1c 100644 --- a/internal/dis/component/auth/panel/panel.go +++ b/internal/dis/component/auth/panel/panel.go @@ -32,7 +32,7 @@ var ( func (panel *UserPanel) Routes() component.Routes { return component.Routes{ - Paths: []string{"/user/"}, + Prefix: "/user/", CSRF: true, Decorator: panel.Dependencies.Auth.Require(nil), } diff --git a/internal/dis/component/control/admin/admin.go b/internal/dis/component/control/admin/admin.go index 2b81977..f10a8f4 100644 --- a/internal/dis/component/control/admin/admin.go +++ b/internal/dis/component/control/admin/admin.go @@ -44,7 +44,7 @@ var ( func (admin *Admin) Routes() component.Routes { return component.Routes{ - Paths: []string{"/admin/"}, + Prefix: "/admin/", CSRF: true, Decorator: admin.Dependencies.Auth.Require(auth.Admin), } diff --git a/internal/dis/component/control/home/home.go b/internal/dis/component/control/home/home.go index 02d7044..6713d89 100644 --- a/internal/dis/component/control/home/home.go +++ b/internal/dis/component/control/home/home.go @@ -29,8 +29,9 @@ var ( func (*Home) Routes() component.Routes { return component.Routes{ - Paths: []string{"/"}, - CSRF: false, + Prefix: "/", + MatchAllDomains: true, + CSRF: false, } } diff --git a/internal/dis/component/control/legal/legal.go b/internal/dis/component/control/legal/legal.go index 4dfb4c2..0235e47 100644 --- a/internal/dis/component/control/legal/legal.go +++ b/internal/dis/component/control/legal/legal.go @@ -32,8 +32,10 @@ var legalTemplate = static.AssetsLegal.MustParseShared("legal.html", legalTempla func (legal *Legal) Routes() component.Routes { return component.Routes{ - Paths: []string{"/legal/"}, - CSRF: false, + Prefix: "/legal/", + Exact: true, + + CSRF: false, } } diff --git a/internal/dis/component/control/news/news.go b/internal/dis/component/control/news/news.go index cf4d23f..755604d 100644 --- a/internal/dis/component/control/news/news.go +++ b/internal/dis/component/control/news/news.go @@ -33,8 +33,9 @@ var ( func (*News) Routes() component.Routes { return component.Routes{ - Paths: []string{"/news/"}, - CSRF: false, + Prefix: "/news/", + Exact: true, + CSRF: false, } } @@ -126,10 +127,6 @@ func (news *News) HandleRoute(ctx context.Context, path string) (http.Handler, e return httpx.HTMLHandler[newsContext]{ Handler: func(r *http.Request) (nc newsContext, err error) { - if strings.TrimSuffix(r.URL.Path, "/") != strings.TrimSuffix(path, "/") { - return nc, httpx.ErrNotFound - } - news.Dependencies.Custom.Update(&nc, r) nc.Items, err = items, itemsErr diff --git a/internal/dis/component/control/server.go b/internal/dis/component/control/server.go index 05eeb56..37e2d93 100644 --- a/internal/dis/component/control/server.go +++ b/internal/dis/component/control/server.go @@ -2,10 +2,14 @@ package control import ( "context" + "fmt" "io" "net/http" + "github.com/FAU-CDI/wisski-distillery/internal/dis/component" "github.com/FAU-CDI/wisski-distillery/pkg/cancel" + "github.com/FAU-CDI/wisski-distillery/pkg/httpx" + "github.com/FAU-CDI/wisski-distillery/pkg/mux" "github.com/gorilla/csrf" "github.com/rs/zerolog" ) @@ -15,8 +19,25 @@ import ( // // Logging messages are directed to progress func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Handler, error) { - // create a new mux - mux := http.NewServeMux() + logger := zerolog.Ctx(ctx) + + var mux mux.Mux[component.RouteContext] + mux.Context = func(r *http.Request) component.RouteContext { + slug, ok := control.Still.Config.SlugFromHost(r.Host) + return component.RouteContext{ + DefaultDomain: slug == "" && ok, + } + } + mux.Panic = func(panic any, w http.ResponseWriter, r *http.Request) { + // log the panic + logger.Error(). + Str("panic", fmt.Sprint(panic)). + Str("path", r.URL.Path). + Msg("panic serving handler") + + // and send an internal server error + httpx.TextInterceptor.Fallback.ServeHTTP(w, r) + } // create a csrf protector csrfProtector := control.CSRF() @@ -24,14 +45,35 @@ func (control *Control) Server(ctx context.Context, progress io.Writer) (http.Ha // iterate over all the handler for _, s := range control.Dependencies.Routeables { routes := s.Routes() - zerolog.Ctx(ctx).Info().Str("component", s.Name()).Strs("paths", routes.Paths).Bool("csrf", routes.CSRF).Bool("decorator", routes.Decorator != nil).Msg("mounting route") + zerolog.Ctx(ctx).Info(). + Str("Name", s.Name()). + Str("Prefix", routes.Prefix). + Strs("Aliases", routes.Aliases). + Bool("Exact", routes.Exact). + Bool("CSRF", routes.CSRF). + Bool("Decorator", routes.Decorator != nil). + Bool("MatchAllDomains", routes.MatchAllDomains). + Msg("mounting route") - for _, path := range routes.Paths { - handler, err := s.HandleRoute(ctx, path) - if err != nil { - return nil, err - } - mux.Handle(path, routes.Decorate(handler, csrfProtector)) + // call the handler for the route + handler, err := s.HandleRoute(ctx, routes.Prefix) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("Component", s.Name()). + Str("Prefix", routes.Prefix). + Msg("error mounting route") + continue + } + + // decorate the handler + handler = routes.Decorate(handler, csrfProtector) + + // determine the predicate + predicate := routes.Predicate(mux.ContextOf) + + // and add all the prefixes + for _, prefix := range append([]string{routes.Prefix}, routes.Aliases...) { + mux.Add(prefix, predicate, routes.Exact, handler) } } diff --git a/internal/dis/component/control/static/static.go b/internal/dis/component/control/static/static.go index d6479ed..10f9014 100644 --- a/internal/dis/component/control/static/static.go +++ b/internal/dis/component/control/static/static.go @@ -20,8 +20,9 @@ var ( func (*Static) Routes() component.Routes { return component.Routes{ - Paths: []string{"/static/"}, - CSRF: false, + Prefix: "/static/", + + CSRF: false, } } diff --git a/internal/dis/component/resolver/resolver.go b/internal/dis/component/resolver/resolver.go index e4cccc7..f07f572 100644 --- a/internal/dis/component/resolver/resolver.go +++ b/internal/dis/component/resolver/resolver.go @@ -29,8 +29,6 @@ type Resolver struct { prefixes lazy.Lazy[map[string]string] // cached prefixes (from the server) RefreshInterval time.Duration - - handler lazy.Lazy[wdresolve.ResolveHandler] // handler } var ( @@ -40,8 +38,9 @@ var ( func (resolver *Resolver) Routes() component.Routes { return component.Routes{ - Paths: []string{"/go/", "/wisski/get/"}, - CSRF: false, + Prefix: "/wisski/get/", + Aliases: []string{"/go/"}, + CSRF: false, } } @@ -59,42 +58,42 @@ func (resolver *Resolver) HandleRoute(ctx context.Context, route string) (http.H logger := zerolog.Ctx(ctx) + var p wdresolve.ResolveHandler var err error - return resolver.handler.Get(func() (p wdresolve.ResolveHandler) { - p.HandleIndex = func(context wdresolve.IndexContext, w http.ResponseWriter, r *http.Request) { - ctx := resolverContext{ - IndexContext: context, - } - resolver.Dependencies.Custom.Update(&ctx, r) - httpx.WriteHTML(ctx, nil, resolverTemplate, "", w, r) + p.HandleIndex = func(context wdresolve.IndexContext, w http.ResponseWriter, r *http.Request) { + ctx := resolverContext{ + IndexContext: context, } - p.TrustXForwardedProto = true + resolver.Dependencies.Custom.Update(&ctx, r) - fallback := &resolvers.Regexp{ - Data: map[string]string{}, - } + httpx.WriteHTML(ctx, nil, resolverTemplate, "", w, r) + } + p.TrustXForwardedProto = true - // handle the default domain name! - domainName := resolver.Config.DefaultDomain - if domainName != "" { - fallback.Data[fmt.Sprintf("^https?://(.*)\\.%s", regexp.QuoteMeta(domainName))] = fmt.Sprintf("https://$1.%s", domainName) - logger.Info().Str("name", domainName).Msg("registering default domain") - } + fallback := &resolvers.Regexp{ + Data: map[string]string{}, + } - // handle the extra domains! - for _, domain := range resolver.Config.SelfExtraDomains { - fallback.Data[fmt.Sprintf("^https?://(.*)\\.%s", regexp.QuoteMeta(domain))] = fmt.Sprintf("https://$1.%s", domainName) - logger.Info().Str("name", domainName).Msg("registering legacy domain") - } + // handle the default domain name! + domainName := resolver.Config.DefaultDomain + if domainName != "" { + fallback.Data[fmt.Sprintf("^https?://(.*)\\.%s", regexp.QuoteMeta(domainName))] = fmt.Sprintf("https://$1.%s", domainName) + logger.Info().Str("name", domainName).Msg("registering default domain") + } - // resolve the prefixes - p.Resolver = resolvers.InOrder{ - resolver, - fallback, - } - return p - }), err + // handle the extra domains! + for _, domain := range resolver.Config.SelfExtraDomains { + fallback.Data[fmt.Sprintf("^https?://(.*)\\.%s", regexp.QuoteMeta(domain))] = fmt.Sprintf("https://$1.%s", domainName) + logger.Info().Str("name", domainName).Msg("registering legacy domain") + } + + // resolve the prefixes + p.Resolver = resolvers.InOrder{ + resolver, + fallback, + } + return p, err } func (resolver *Resolver) Target(uri string) string { diff --git a/internal/dis/component/server.go b/internal/dis/component/server.go index 7726afb..5d0dab7 100644 --- a/internal/dis/component/server.go +++ b/internal/dis/component/server.go @@ -3,6 +3,8 @@ package component import ( "context" "net/http" + + "github.com/FAU-CDI/wisski-distillery/pkg/mux" ) // Routeable is a component that is servable @@ -18,9 +20,18 @@ type Routeable interface { // Routes represents information about a single Routeable type Routes struct { - // Paths are the paths handled by this routeable. - // Each path is passed to HandleRoute() individually. - Paths []string + // Prefix is the prefix this pattern handles + Prefix string + + // MatchAllDomains indicates that all domains, even the non-default domain, should be matched + MatchAllDomains bool + + // Exact indicates that only the exact prefix, as opposed to any sub-paths, are matched. + // Trailing '/'s are automatically trimmed, even with an exact match. + Exact bool + + // Aliases are the additional prefixes this route handles. + Aliases []string // CSRF indicates if this route should be protected by CSRF. // CSRF protection is applied prior to any custom decorator being called. @@ -31,6 +42,22 @@ type Routes struct { Decorator func(http.Handler) http.Handler } +type RouteContext struct { + DefaultDomain bool +} + +// Predicate returns the predicate corresponding to the given route +func (routes Routes) Predicate(context func(*http.Request) RouteContext) mux.Predicate { + if routes.MatchAllDomains { + return nil + } + + // match only the default domain + return func(r *http.Request) bool { + return context(r).DefaultDomain + } +} + // Decorate decorates the provided handler with the options specified in this handler. func (routes Routes) Decorate(handler http.Handler, csrf func(http.Handler) http.Handler) http.Handler { if routes.CSRF && csrf != nil { diff --git a/pkg/mux/mux.go b/pkg/mux/mux.go new file mode 100644 index 0000000..12f9b1b --- /dev/null +++ b/pkg/mux/mux.go @@ -0,0 +1,154 @@ +// Package mux provides mux +package mux + +import ( + "context" + "net/http" +) + +// Mux represents a mux that can handle different requests +type Mux[C any] struct { + prefixes map[string][]handler + exacts map[string][]handler + + Context func(r *http.Request) C // called to set context on the given request + + Panic func(panic any, w http.ResponseWriter, r *http.Request) // called on panic + NotFound http.Handler // optional handler to be called in case of a not found +} + +type contextKey struct{} + +var theContextKey = contextKey{} + +type handler struct { + Predicate Predicate + http.Handler +} + +func (mux *Mux[T]) Prepare(r *http.Request) *http.Request { + if mux == nil || mux.Context == nil { + return r + } + + ctx := context.WithValue(r.Context(), theContextKey, mux.Context(r)) + return r.WithContext(ctx) +} + +func (mux *Mux[T]) ContextOf(r *http.Request) (t T) { + value, ok := r.Context().Value(theContextKey).(T) + if !ok { + return t + } + return value +} + +// Add adds a handler for the given path +func (mux *Mux[T]) Add(path string, predicate Predicate, exact bool, h http.Handler) { + if mux.exacts == nil { + mux.exacts = make(map[string][]handler) + } + if mux.prefixes == nil { + mux.prefixes = make(map[string][]handler) + } + + mPath := normalizePath(path) + mHandler := handler{Predicate: predicate, Handler: h} + if exact { + mux.exacts[mPath] = append(mux.exacts[mPath], mHandler) + } else { + mux.prefixes[mPath] = append(mux.prefixes[mPath], mHandler) + } +} + +// Match returns the handler to be applied for the given request. +func (mux *Mux[T]) Match(r *http.Request, prepare bool) (http.Handler, bool) { + if mux == nil { + return nil, false + } + + if prepare { + r = mux.Prepare(r) + } + + candidate := normalizePath(r.URL.Path) + + // match the exact path first + for _, h := range mux.exacts[candidate] { + if h.Predicate.Call(r) { + return h.Handler, true + } + } + + // iterate over path segment candidates + for { + // check the current candidate + for _, h := range mux.prefixes[candidate] { + if h.Predicate.Call(r) { + return h.Handler, true + } + } + + // if the candidate is the root url, we can bail out now + if len(candidate) == 0 || candidate == "/" { + return nil, false + } + + // move to the parent segment + candidate = parentSegment(candidate) + } + +} + +func (mux *Mux[T]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // handle panics with the panic handler + defer func() { + caught := recover() + if caught == nil { + return + } + + if mux == nil || mux.Panic == nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // silently ignore any panic()s in the panic handler + defer func() { + recover() + }() + + // call the panic handler + mux.Panic(caught, w, r) + }() + + // prepare the request + r = mux.Prepare(r) + + // find the right handler + // or go into 404 mode + handler, ok := mux.Match(r, false) + if !ok { + if mux == nil || mux.NotFound == nil { + http.NotFound(w, r) + return + } + mux.NotFound.ServeHTTP(w, r) + return + } + + // call the actual handling + handler.ServeHTTP(w, r) +} + +// Predicate represents a matching predicate for a given request. +// The nil predicate always matches +type Predicate func(r *http.Request) bool + +// Call checks if this predicate matches the given request. +func (p Predicate) Call(r *http.Request) bool { + if p == nil { + return true + } + return p(r) +} diff --git a/pkg/mux/path.go b/pkg/mux/path.go new file mode 100644 index 0000000..942e8e0 --- /dev/null +++ b/pkg/mux/path.go @@ -0,0 +1,28 @@ +package mux + +import ( + "path" +) + +// normalizePath normalizes the provided path. +// It ensures that there is both a leading and trailing slash. +func normalizePath(value string) string { + value = path.Clean(value) + if value != "/" { + value = value + "/" + } + return value +} + +// parentSegment returns the parent segment of the provided path +// it assumes that normalizePath has been called on value. +func parentSegment(value string) string { + if value == "" || value == "/" { + return "/" + } + parent := path.Dir(value[:len(value)-1]) + if parent != "/" { + parent = parent + "/" + } + return parent +}