Use environment DialContext and Listen everywhere

This commit is contained in:
Tom Wiesing 2022-09-19 12:42:33 +02:00
parent f19619ef9f
commit b0d3c686ba
No known key found for this signature in database
11 changed files with 64 additions and 30 deletions

View file

@ -1,9 +1,14 @@
package sql
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
mysqldriver "github.com/go-sql-driver/mysql"
"github.com/FAU-CDI/wisski-distillery/internal/bookkeeping"
"github.com/FAU-CDI/wisski-distillery/pkg/logging"
@ -15,13 +20,29 @@ import (
"gorm.io/gorm/logger"
)
var proxyNameCounter uint64
// registerDialingProxy registers a new custom network protocol with the underlying sql driver.
// The new protocol will call dialer with the provided network argument.
// The name of the new protocol is returned.
func registerDialingProxy(network string, dialer func(context context.Context, network, address string) (net.Conn, error)) (name string) {
name = fmt.Sprintf("sql-proxy-%d", atomic.AddUint64(&proxyNameCounter, 1))
mysqldriver.RegisterDialContext(name, func(ctx context.Context, addr string) (net.Conn, error) {
return dialer(ctx, network, addr)
})
return
}
// sqlOpen opens a new sql connection to the provided database using the administrative credentials
func (sql SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, error) {
func (sql *SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, error) {
network := sql.sqlNetwork.Get(func() string {
return registerDialingProxy("tcp", sql.Core.Environment.DialContext)
})
cfg := mysql.Config{
DSN: fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=True&loc=Local", sql.Config.MysqlAdminUser, sql.Config.MysqlAdminPassword, sql.ServerURL, database),
DriverName: "mysql",
DSN: fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8&parseTime=True&loc=Local", sql.Config.MysqlAdminUser, sql.Config.MysqlAdminPassword, network, sql.ServerURL, database),
DefaultStringSize: 256,
}
// TODO: Use sql.Core.Environment.Dial
db, err := gorm.Open(mysql.New(cfg), config)
if err != nil {
@ -38,7 +59,7 @@ func (sql SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, err
}
// OpenBookkeeping opens a connection to the bookkeeping database
func (sql SQL) OpenBookkeeping(silent bool) (*gorm.DB, error) {
func (sql *SQL) OpenBookkeeping(silent bool) (*gorm.DB, error) {
config := &gorm.Config{}
if silent {
@ -75,12 +96,12 @@ func (sql SQL) Snapshot(io stream.IOStream, dest io.Writer, database string) err
}
// OpenShell executes a mysql shell command
func (sql SQL) OpenShell(io stream.IOStream, argv ...string) (int, error) {
func (sql *SQL) OpenShell(io stream.IOStream, argv ...string) (int, error) {
return sql.Stack(sql.Environment).Exec(io, "sql", "mysql", argv...)
}
// WaitShell waits for the sql database to be reachable via a docker-compose shell
func (sql SQL) WaitShell() error {
func (sql *SQL) WaitShell() error {
n := stream.FromNil()
return wait.Wait(func() bool {
code, err := sql.OpenShell(n, "-e", "show databases;")
@ -89,7 +110,7 @@ func (sql SQL) WaitShell() error {
}
// Wait waits for a connection to the bookkeeping table to suceed
func (sql SQL) Wait() error {
func (sql *SQL) Wait() error {
return wait.Wait(func() bool {
_, err := sql.OpenBookkeeping(true)
return err == nil
@ -98,14 +119,14 @@ func (sql SQL) Wait() error {
var errInvalidDatabaseName = errors.New("SQLProvision: Invalid database name")
func (sql SQL) Query(query string, args ...interface{}) bool {
func (sql *SQL) Query(query string, args ...interface{}) bool {
raw := sqle.Format(query, args...)
code, err := sql.OpenShell(stream.FromNil(), "-e", raw)
return err == nil && code == 0
}
// SQLProvision provisions a new sql database and user
func (sql SQL) Provision(name, user, password string) error {
func (sql *SQL) Provision(name, user, password string) error {
// wait for the database
if err := sql.WaitShell(); err != nil {
return err
@ -128,7 +149,7 @@ func (sql SQL) Provision(name, user, password string) error {
var errSQLPurgeUser = errors.New("unable to delete user")
// SQLPurgeUser deletes the specified user from the database
func (sql SQL) PurgeUser(user string) error {
func (sql *SQL) PurgeUser(user string) error {
if !sql.Query("DROP USER IF EXISTS ?@`%`; FLUSH PRIVILEGES; ", user) {
return errSQLPurgeUser
}
@ -139,7 +160,7 @@ func (sql SQL) PurgeUser(user string) error {
var errSQLPurgeDB = errors.New("unable to drop database")
// SQLPurgeDatabase deletes the specified db from the database
func (sql SQL) PurgeDatabase(db string) error {
func (sql *SQL) PurgeDatabase(db string) error {
if !sqle.IsSafeDatabaseName(db) {
return errSQLPurgeDB
}
@ -154,7 +175,7 @@ var errSQLUnsafeDatabaseName = errors.New("bookkeeping database has an unsafe na
var errSQLUnableToCreate = errors.New("unable to create bookkeeping database")
// Bootstrap bootstraps the SQL database, and makes sure that the bookkeeping table is up-to-date
func (sql SQL) Bootstrap(io stream.IOStream) error {
func (sql *SQL) Bootstrap(io stream.IOStream) error {
if err := sql.WaitShell(); err != nil {
return err
}

View file

@ -7,6 +7,7 @@ import (
"github.com/FAU-CDI/wisski-distillery/internal/component"
"github.com/FAU-CDI/wisski-distillery/pkg/environment"
"github.com/FAU-CDI/wisski-distillery/pkg/lazy"
)
type SQL struct {
@ -16,6 +17,8 @@ type SQL struct {
PollContext context.Context // context to abort polling with
PollInterval time.Duration // duration to wait for during wait
sqlNetwork lazy.Lazy[string]
}
func (SQL) Name() string {
@ -25,7 +28,7 @@ func (SQL) Name() string {
//go:embed all:sql
var resources embed.FS
func (ssh SQL) Stack(env environment.Environment) component.StackWithResources {
func (ssh *SQL) Stack(env environment.Environment) component.StackWithResources {
return ssh.ComponentBase.MakeStack(env, component.StackWithResources{
Resources: resources,
ContextPath: "sql",