diff --git a/go.mod b/go.mod index 3417efa..b89da52 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/pkg/errors v0.9.1 github.com/tkw1536/goprogram v0.1.1 golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 - golang.org/x/net v0.0.0-20221012135044-0b7e1fb9d458 golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0 gorm.io/driver/mysql v1.3.6 gorm.io/gorm v1.23.10 diff --git a/go.sum b/go.sum index 031cb79..97801d3 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,6 @@ github.com/tkw1536/goprogram v0.1.1 h1:gamK9OuRqoX2yQlA/nkgfVHHZWd/u2uUj6vJMYrYa github.com/tkw1536/goprogram v0.1.1/go.mod h1:Jqs0sTMzhrAGCX3JQrlEwQ0WRWQACCvuQQkaBDp65pE= golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 h1:fGZugkZk2UgYBxtpKmvub51Yno1LJDeEsRp2xGD+0gY= golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/net v0.0.0-20221012135044-0b7e1fb9d458 h1:MgJ6t2zo8v0tbmLCueaCbF1RM+TtB0rs3Lv8DGtOIpY= -golang.org/x/net v0.0.0-20221012135044-0b7e1fb9d458/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0 h1:cu5kTvlzcw1Q5S9f5ip1/cpiB4nXvw1XYzFPGgzLUOY= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/component/home/public.go b/internal/component/home/public.go index fce9157..50af427 100644 --- a/internal/component/home/public.go +++ b/internal/component/home/public.go @@ -15,12 +15,14 @@ import ( ) func (home *Home) updateInstances(ctx context.Context, io stream.IOStream) { - timex.SetInterval(ctx, home.RefreshInterval, func(t time.Time) { - io.Printf("[%s]: reloading instance list\n", t.Format(time.Stamp)) + go func() { + for t := range timex.TickContext(ctx, home.RefreshInterval) { + io.Printf("[%s]: reloading instance list\n", t.Format(time.Stamp)) - names, _ := home.instanceMap() - home.instanceNames.Set(names) - }) + names, _ := home.instanceMap() + home.instanceNames.Set(names) + } + }() } func (home *Home) instanceMap() (map[string]struct{}, error) { @@ -37,12 +39,14 @@ func (home *Home) instanceMap() (map[string]struct{}, error) { } func (home *Home) updateRender(ctx context.Context, io stream.IOStream) { - timex.SetInterval(ctx, home.RefreshInterval, func(t time.Time) { - io.Printf("[%s]: reloading home render\n", t.Format(time.Stamp)) + go func() { + for t := range timex.TickContext(ctx, home.RefreshInterval) { + io.Printf("[%s]: reloading home render\n", t.Format(time.Stamp)) - bytes, _ := home.homeRender() - home.homeBytes.Set(bytes) - }) + bytes, _ := home.homeRender() + home.homeBytes.Set(bytes) + } + }() } //go:embed "home.html" diff --git a/internal/component/home/redirect.go b/internal/component/home/redirect.go index 36882d6..38d1f0f 100644 --- a/internal/component/home/redirect.go +++ b/internal/component/home/redirect.go @@ -12,11 +12,14 @@ import ( ) func (home *Home) updateRedirect(ctx context.Context, io stream.IOStream) { - timex.SetInterval(ctx, home.RefreshInterval, func(t time.Time) { - io.Printf("[%s]: reloading overrides\n", t.Format(time.Stamp)) - redirect, _ := home.loadRedirect() - home.redirect.Set(&redirect) - }) + go func() { + for t := range timex.TickContext(ctx, home.RefreshInterval) { + io.Printf("[%s]: reloading overrides\n", t.Format(time.Stamp)) + + redirect, _ := home.loadRedirect() + home.redirect.Set(&redirect) + } + }() } func (home *Home) loadRedirect() (redirect Redirect, err error) { diff --git a/internal/component/resolver/prefixes.go b/internal/component/resolver/prefixes.go index 89b2e44..99b8422 100644 --- a/internal/component/resolver/prefixes.go +++ b/internal/component/resolver/prefixes.go @@ -10,11 +10,13 @@ import ( // updatePrefixes starts updating prefixes func (resolver *Resolver) updatePrefixes(io stream.IOStream, ctx context.Context) { - timex.SetInterval(ctx, resolver.RefreshInterval, func(t time.Time) { - io.Printf("[%s]: reloading prefixes\n", t.Format(time.Stamp)) - prefixes, _ := resolver.AllPrefixes() - resolver.prefixes.Set(prefixes) - }) + go func() { + for t := range timex.TickContext(ctx, resolver.RefreshInterval) { + io.Printf("[%s]: reloading prefixes\n", t.Format(time.Stamp)) + prefixes, _ := resolver.AllPrefixes() + resolver.prefixes.Set(prefixes) + } + }() } // AllPrefixes returns a list of all prefixes from the server. diff --git a/internal/component/sql/connect.go b/internal/component/sql/connect.go index aafe19a..7f99ae1 100644 --- a/internal/component/sql/connect.go +++ b/internal/component/sql/connect.go @@ -6,16 +6,16 @@ import ( "fmt" "net" "sync/atomic" + "time" mysqldriver "github.com/go-sql-driver/mysql" - "github.com/tkw1536/goprogram/stream" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" "github.com/FAU-CDI/wisski-distillery/internal/models" - "github.com/FAU-CDI/wisski-distillery/pkg/wait" + "github.com/FAU-CDI/wisski-distillery/pkg/timex" ) // @@ -42,10 +42,10 @@ func (sql *SQL) Exec(query string, args ...interface{}) error { // WaitExec waits for the query interface to be able to connect to the database func (sql *SQL) WaitExec() error { - return wait.Wait(func() bool { + return timex.TickUntilFunc(func(time.Time) bool { err := sql.Exec("select 1;") return err == nil - }, sql.PollInterval, sql.PollContext) + }, sql.PollContext, sql.PollInterval) } // @@ -91,12 +91,10 @@ func (sql *SQL) QueryTable(silent bool, table string) (*gorm.DB, error) { // WaitQueryTable waits for a connection to succeed via QueryTable func (sql *SQL) WaitQueryTable() error { // TODO: Establish a convention on when to wait for this! - n := stream.FromNil() - return wait.Wait(func() bool { + return timex.TickUntilFunc(func(time.Time) bool { _, err := sql.QueryTable(true, models.InstanceTable) - n.EPrintf("[SQL.WaitQueryTable]: %s\n", err) return err == nil - }, sql.PollInterval, sql.PollContext) + }, sql.PollContext, sql.PollInterval) } // diff --git a/internal/component/sql/update.go b/internal/component/sql/update.go index 2258d65..63fe7c9 100644 --- a/internal/component/sql/update.go +++ b/internal/component/sql/update.go @@ -3,11 +3,12 @@ package sql import ( "errors" "fmt" + "time" "github.com/FAU-CDI/wisski-distillery/internal/models" "github.com/FAU-CDI/wisski-distillery/pkg/logging" "github.com/FAU-CDI/wisski-distillery/pkg/sqle" - "github.com/FAU-CDI/wisski-distillery/pkg/wait" + "github.com/FAU-CDI/wisski-distillery/pkg/timex" "github.com/tkw1536/goprogram/exit" "github.com/tkw1536/goprogram/stream" ) @@ -22,11 +23,10 @@ func (sql *SQL) Shell(io stream.IOStream, argv ...string) (int, error) { // unsafeWaitShell waits for a connection via the database shell to succeed func (sql *SQL) unsafeWaitShell() error { n := stream.FromNil() - return wait.Wait(func() bool { + return timex.TickUntilFunc(func(time.Time) bool { code, err := sql.Shell(n, "-e", "select 1;") - n.EPrintf("[SQL.unsafeWaitShell]: %d %s\n", code, err) return err == nil && code == 0 - }, sql.PollInterval, sql.PollContext) + }, sql.PollContext, sql.PollInterval) } // unsafeQuery shell executes a raw database query. diff --git a/internal/component/triplestore/database.go b/internal/component/triplestore/database.go index 4ee1590..9423ba1 100644 --- a/internal/component/triplestore/database.go +++ b/internal/component/triplestore/database.go @@ -6,8 +6,9 @@ import ( "io" "mime/multipart" "net/http" + "time" - "github.com/FAU-CDI/wisski-distillery/pkg/wait" + "github.com/FAU-CDI/wisski-distillery/pkg/timex" "github.com/pkg/errors" "github.com/tkw1536/goprogram/stream" ) @@ -87,7 +88,7 @@ func (ts Triplestore) OpenRaw(method, url string, body interface{}, bodyName str // This is achieved using a polling strategy. func (ts Triplestore) Wait() error { n := stream.FromNil() - return wait.Wait(func() bool { + return timex.TickUntilFunc(func(time.Time) bool { res, err := ts.OpenRaw("GET", "/rest/repositories", nil, "", "") n.EPrintf("[Triplestore.Wait]: %s\n", err) if err != nil { @@ -95,7 +96,7 @@ func (ts Triplestore) Wait() error { } defer res.Body.Close() return true - }, ts.PollInterval, ts.PollContext) + }, ts.PollContext, ts.PollInterval) } // TriplestorePurgeUser deletes the specified user from the triplestore diff --git a/pkg/timex/timex.go b/pkg/timex/timex.go index c8562c8..88d3ea1 100644 --- a/pkg/timex/timex.go +++ b/pkg/timex/timex.go @@ -1,3 +1,4 @@ +// Package timex provides Interval and Wait package timex import ( @@ -5,21 +6,49 @@ import ( "time" ) -// SetInterval invokes f with the current time and then spawns a new goroutine that runs f every d, until context is closed. -func SetInterval(ctx context.Context, d time.Duration, f func(t time.Time)) { - f(time.Now()) +// TickContext is like [time.Tick], but closes the returned channel once the context closes. +// As such it can be recovered by the garbage collector; see [time.TickContext]. +// +// Unlike [time.Tick], immediatly send the current time on the given channel. +func TickContext(c context.Context, d time.Duration) <-chan time.Time { + if d < 0 { + return nil + } + timer := make(chan time.Time, 1) + timer <- time.Now() go func() { t := time.NewTicker(d) defer t.Stop() + defer close(timer) for { select { case tick := <-t.C: - f(tick) - case <-ctx.Done(): + timer <- tick + case <-c.Done(): return } } + }() + return timer +} + +// TickUntilFunc invokes f every d until either context is closed, or f returns true. +// f is invoked once immediatly when the timer starts. +// +// TickUntilFunc blocks until f is no longer invoked. +// +// Returns the error of the context (if any). +func TickUntilFunc(f func(t time.Time) bool, c context.Context, d time.Duration) error { + context, cancel := context.WithCancel(c) + defer cancel() + + for t := range TickContext(context, d) { + if f(t) { + break + } + } + return c.Err() } diff --git a/pkg/wait/wait.go b/pkg/wait/wait.go deleted file mode 100644 index 5763703..0000000 --- a/pkg/wait/wait.go +++ /dev/null @@ -1,31 +0,0 @@ -package wait - -import ( - "context" - "time" -) - -// Wait repeatedly invokes f, until it returns true or the context is closed. -// The invocation interval is determined by interval. -func Wait(f func() bool, interval time.Duration, context context.Context) error { - // create a new timer - timer := time.NewTimer(interval) - if !timer.Stop() { - <-timer.C - } - defer timer.Stop() - - for { - if f() { - return nil - } - - // reset the timer, and wait for it again! - timer.Reset(interval) - select { - case <-timer.C: - case <-context.Done(): - return context.Err() - } - } -}