Use environment DialContext and Listen everywhere
This commit is contained in:
parent
f19619ef9f
commit
b0d3c686ba
11 changed files with 64 additions and 30 deletions
|
|
@ -1,9 +1,14 @@
|
|||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
|
||||
"github.com/FAU-CDI/wisski-distillery/internal/bookkeeping"
|
||||
"github.com/FAU-CDI/wisski-distillery/pkg/logging"
|
||||
|
|
@ -15,13 +20,29 @@ import (
|
|||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var proxyNameCounter uint64
|
||||
|
||||
// registerDialingProxy registers a new custom network protocol with the underlying sql driver.
|
||||
// The new protocol will call dialer with the provided network argument.
|
||||
// The name of the new protocol is returned.
|
||||
func registerDialingProxy(network string, dialer func(context context.Context, network, address string) (net.Conn, error)) (name string) {
|
||||
name = fmt.Sprintf("sql-proxy-%d", atomic.AddUint64(&proxyNameCounter, 1))
|
||||
mysqldriver.RegisterDialContext(name, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return dialer(ctx, network, addr)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// sqlOpen opens a new sql connection to the provided database using the administrative credentials
|
||||
func (sql SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, error) {
|
||||
func (sql *SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, error) {
|
||||
network := sql.sqlNetwork.Get(func() string {
|
||||
return registerDialingProxy("tcp", sql.Core.Environment.DialContext)
|
||||
})
|
||||
cfg := mysql.Config{
|
||||
DSN: fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=True&loc=Local", sql.Config.MysqlAdminUser, sql.Config.MysqlAdminPassword, sql.ServerURL, database),
|
||||
DriverName: "mysql",
|
||||
DSN: fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8&parseTime=True&loc=Local", sql.Config.MysqlAdminUser, sql.Config.MysqlAdminPassword, network, sql.ServerURL, database),
|
||||
DefaultStringSize: 256,
|
||||
}
|
||||
// TODO: Use sql.Core.Environment.Dial
|
||||
|
||||
db, err := gorm.Open(mysql.New(cfg), config)
|
||||
if err != nil {
|
||||
|
|
@ -38,7 +59,7 @@ func (sql SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, err
|
|||
}
|
||||
|
||||
// OpenBookkeeping opens a connection to the bookkeeping database
|
||||
func (sql SQL) OpenBookkeeping(silent bool) (*gorm.DB, error) {
|
||||
func (sql *SQL) OpenBookkeeping(silent bool) (*gorm.DB, error) {
|
||||
|
||||
config := &gorm.Config{}
|
||||
if silent {
|
||||
|
|
@ -75,12 +96,12 @@ func (sql SQL) Snapshot(io stream.IOStream, dest io.Writer, database string) err
|
|||
}
|
||||
|
||||
// OpenShell executes a mysql shell command
|
||||
func (sql SQL) OpenShell(io stream.IOStream, argv ...string) (int, error) {
|
||||
func (sql *SQL) OpenShell(io stream.IOStream, argv ...string) (int, error) {
|
||||
return sql.Stack(sql.Environment).Exec(io, "sql", "mysql", argv...)
|
||||
}
|
||||
|
||||
// WaitShell waits for the sql database to be reachable via a docker-compose shell
|
||||
func (sql SQL) WaitShell() error {
|
||||
func (sql *SQL) WaitShell() error {
|
||||
n := stream.FromNil()
|
||||
return wait.Wait(func() bool {
|
||||
code, err := sql.OpenShell(n, "-e", "show databases;")
|
||||
|
|
@ -89,7 +110,7 @@ func (sql SQL) WaitShell() error {
|
|||
}
|
||||
|
||||
// Wait waits for a connection to the bookkeeping table to suceed
|
||||
func (sql SQL) Wait() error {
|
||||
func (sql *SQL) Wait() error {
|
||||
return wait.Wait(func() bool {
|
||||
_, err := sql.OpenBookkeeping(true)
|
||||
return err == nil
|
||||
|
|
@ -98,14 +119,14 @@ func (sql SQL) Wait() error {
|
|||
|
||||
var errInvalidDatabaseName = errors.New("SQLProvision: Invalid database name")
|
||||
|
||||
func (sql SQL) Query(query string, args ...interface{}) bool {
|
||||
func (sql *SQL) Query(query string, args ...interface{}) bool {
|
||||
raw := sqle.Format(query, args...)
|
||||
code, err := sql.OpenShell(stream.FromNil(), "-e", raw)
|
||||
return err == nil && code == 0
|
||||
}
|
||||
|
||||
// SQLProvision provisions a new sql database and user
|
||||
func (sql SQL) Provision(name, user, password string) error {
|
||||
func (sql *SQL) Provision(name, user, password string) error {
|
||||
// wait for the database
|
||||
if err := sql.WaitShell(); err != nil {
|
||||
return err
|
||||
|
|
@ -128,7 +149,7 @@ func (sql SQL) Provision(name, user, password string) error {
|
|||
var errSQLPurgeUser = errors.New("unable to delete user")
|
||||
|
||||
// SQLPurgeUser deletes the specified user from the database
|
||||
func (sql SQL) PurgeUser(user string) error {
|
||||
func (sql *SQL) PurgeUser(user string) error {
|
||||
if !sql.Query("DROP USER IF EXISTS ?@`%`; FLUSH PRIVILEGES; ", user) {
|
||||
return errSQLPurgeUser
|
||||
}
|
||||
|
|
@ -139,7 +160,7 @@ func (sql SQL) PurgeUser(user string) error {
|
|||
var errSQLPurgeDB = errors.New("unable to drop database")
|
||||
|
||||
// SQLPurgeDatabase deletes the specified db from the database
|
||||
func (sql SQL) PurgeDatabase(db string) error {
|
||||
func (sql *SQL) PurgeDatabase(db string) error {
|
||||
if !sqle.IsSafeDatabaseName(db) {
|
||||
return errSQLPurgeDB
|
||||
}
|
||||
|
|
@ -154,7 +175,7 @@ var errSQLUnsafeDatabaseName = errors.New("bookkeeping database has an unsafe na
|
|||
var errSQLUnableToCreate = errors.New("unable to create bookkeeping database")
|
||||
|
||||
// Bootstrap bootstraps the SQL database, and makes sure that the bookkeeping table is up-to-date
|
||||
func (sql SQL) Bootstrap(io stream.IOStream) error {
|
||||
func (sql *SQL) Bootstrap(io stream.IOStream) error {
|
||||
if err := sql.WaitShell(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue