Cleanup and document hacky sql interaction

This commit is contained in:
Tom Wiesing 2022-09-19 23:16:56 +02:00
parent 881b538dff
commit 07409a01be
No known key found for this signature in database
17 changed files with 284 additions and 204 deletions

View file

@ -0,0 +1,139 @@
package sql
import (
"context"
"database/sql"
"fmt"
"net"
"sync/atomic"
mysqldriver "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/FAU-CDI/wisski-distillery/internal/models"
"github.com/FAU-CDI/wisski-distillery/pkg/wait"
)
//
// ========== low-level connection ==========
//
// Query performs a database query, outside a database contect
func (sql *SQL) Query(query string, args ...interface{}) error {
// connect to the server
conn, err := sql.connect("")
if err != nil {
return err
}
// do the query!
{
_, err := conn.Exec(query, args...)
if err != nil {
return err
}
return nil
}
}
// WaitQuery waits for the query interface to be able to connect to the database
func (sql *SQL) WaitQuery() error {
return wait.Wait(func() bool {
err := sql.Query("select 1;")
// log.Printf("[WaitQuery] %s\n", err) // debug
return err == nil
}, sql.PollInterval, sql.PollContext)
}
//
// ========== connection via gorm ==========
//
// QueryTable returns a gorm.DB to connect to the provided gorm database table
func (sql *SQL) QueryTable(silent bool, name string) (*gorm.DB, error) {
conn, err := sql.connect(sql.Config.DistilleryDatabase)
if err != nil {
return nil, err
}
// gorm configuration
config := &gorm.Config{}
if silent {
config.Logger = logger.Default.LogMode(logger.Silent)
}
// mysql connection
cfg := mysql.Config{
Conn: conn,
DefaultStringSize: 256,
}
// open the gorm connection!
db, err := gorm.Open(mysql.New(cfg), config)
if err != nil {
return nil, err
}
// set the table
db = db.Table(name)
// check that nothing went wrong
if db.Error != nil {
return nil, db.Error
}
return db, nil
}
// WaitQueryTable waits for a connection to succeed via QueryTable
func (sql *SQL) WaitQueryTable() error {
return wait.Wait(func() bool {
_, err := sql.QueryTable(true, models.InstanceTable)
return err == nil
}, sql.PollInterval, sql.PollContext)
}
//
// ========== low-level database connection ==========
//
func (ssql *SQL) connect(database string) (*sql.DB, error) {
conn, err := sql.Open("mysql", ssql.dsn(database))
if err != nil {
return nil, err
}
conn.SetMaxIdleConns(0)
return conn, nil
}
// dsn returns a dsn fof connecting to the database
func (sql *SQL) dsn(database string) string {
user := sql.Config.MysqlAdminUser
pass := sql.Config.MysqlAdminPassword
network := sql.network()
server := sql.ServerURL
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8&parseTime=True&loc=Local", user, pass, network, server, database)
}
var proxyNameCounter uint64
// network returns the network to use to connect to the database
func (sql *SQL) network() string {
return sql.lazyNetwork.Get(func() (name string) {
network := "tcp"
// register a new DialContext function to use the environment.
// this seems like a bit of a hack, but it works for now.
name = fmt.Sprintf("sql-network-%d", atomic.AddUint64(&proxyNameCounter, 1))
mysqldriver.RegisterDialContext(name, func(ctx context.Context, addr string) (net.Conn, error) {
return sql.Core.Environment.DialContext(ctx, network, addr)
})
return
})
}

View file

@ -0,0 +1 @@
package sql

View file

@ -1,111 +0,0 @@
package sql
import (
"context"
"errors"
"fmt"
"net"
"sync/atomic"
mysqldriver "github.com/go-sql-driver/mysql"
"github.com/FAU-CDI/wisski-distillery/pkg/sqle"
"github.com/FAU-CDI/wisski-distillery/pkg/wait"
"github.com/tkw1536/goprogram/stream"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var proxyNameCounter uint64
// network returns the network to use to connect to the database
func (sql *SQL) network() string {
return sql.lazyNetwork.Get(func() (name string) {
network := "tcp"
// register a new DialContext function to use the environment.
// this seems like a bit of a hack, but it works for now.
name = fmt.Sprintf("sql-network-%d", atomic.AddUint64(&proxyNameCounter, 1))
mysqldriver.RegisterDialContext(name, func(ctx context.Context, addr string) (net.Conn, error) {
return sql.Core.Environment.DialContext(ctx, network, addr)
})
return
})
}
// sqlOpen opens a new sql connection to the provided database using the administrative credentials
func (sql *SQL) openDatabase(database string, config *gorm.Config) (*gorm.DB, error) {
cfg := mysql.Config{
DriverName: "mysql",
DSN: fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8&parseTime=True&loc=Local", sql.Config.MysqlAdminUser, sql.Config.MysqlAdminPassword, sql.network(), sql.ServerURL, database),
DefaultStringSize: 256,
}
db, err := gorm.Open(mysql.New(cfg), config)
if err != nil {
return db, err
}
gdb, err := db.DB()
if err != nil {
return db, err
}
gdb.SetMaxIdleConns(0)
return db, nil
}
// OpenBookkeeping opens a connection to the bookkeeping database
func (sql *SQL) OpenBookkeeping(silent bool) (*gorm.DB, error) {
config := &gorm.Config{}
if silent {
config.Logger = logger.Default.LogMode(logger.Silent)
}
// open the database
db, err := sql.openDatabase(sql.Config.DistilleryBookkeepingDatabase, config)
if err != nil {
return nil, err
}
// load the table
table := db.Table(sql.Config.DistilleryBookkeepingTable)
if table.Error != nil {
return nil, err
}
return table, nil
}
// Shell runs a mysql shell command.
func (sql *SQL) Shell(io stream.IOStream, argv ...string) (int, error) {
return sql.Stack(sql.Environment).Exec(io, "sql", "mysql", argv...)
}
// WaitShell waits for the sql database to be reachable via shell
func (sql *SQL) WaitShell() error {
n := stream.FromNil()
return wait.Wait(func() bool {
code, err := sql.Shell(n, "-e", "show databases;")
return err == nil && code == 0
}, sql.PollInterval, sql.PollContext)
}
// Wait waits for a connection to the bookkeeping table to suceed
func (sql *SQL) Wait() error {
return wait.Wait(func() bool {
_, err := sql.OpenBookkeeping(true)
return err == nil
}, sql.PollInterval, sql.PollContext)
}
var errInvalidDatabaseName = errors.New("SQLProvision: Invalid database name")
// Query performs a raw database query
func (sql *SQL) Query(query string, args ...interface{}) bool {
raw := sqle.Format(query, args...)
code, err := sql.Shell(stream.FromNil(), "-e", raw)
return err == nil && code == 0
}

View file

@ -6,47 +6,81 @@ import (
"github.com/FAU-CDI/wisski-distillery/pkg/sqle"
)
// SQLProvision provisions a new sql database and user
var errProvisionInvalidDatabaseParams = errors.New("Provision: Invalid parameters")
var errProvisionInvalidGrant = errors.New("Provision: Grant failed")
// Provision provisions a new sql database and user
func (sql *SQL) Provision(name, user, password string) error {
// wait for the database
if err := sql.WaitShell(); err != nil {
// NOTE(twiesing): We shouldn't use string concat to build sql queries.
// But the driver doesn't support using query params for this particular query.
// Apparently it's a "feature", see https://github.com/go-sql-driver/mysql/issues/398#issuecomment-169951763.
// quick and dirty check to make sure that all the names won't sql inject.
if !sqle.IsSafeDatabaseLiteral(name) || !sqle.IsSafeDatabaseSingleQuote(user) || !sqle.IsSafeDatabaseSingleQuote(password) {
return errProvisionInvalidDatabaseParams
}
// We use the sql shell here, because not only can we not use query params, but the driver outright rejects queries.
// Queries of the form "CREATE USER 'test'@'%' IDENTIFIED BY 'test'; FLUSH PRIVILEGES;" return error 1064 when using driver, but are fine with the shell.
// This should be fixed eventually, but I have no idea how.
if err := sql.unsafeWaitShell(); err != nil {
return err
}
// it's not a safe database name!
if !sqle.IsSafeDatabaseName(name) {
return errInvalidDatabaseName
query := "CREATE DATABASE `" + name + "`;" +
"CREATE USER '" + user + "'@'%' IDENTIFIED BY '" + password + "';" +
"GRANT ALL PRIVILEGES ON `" + name + "`.* TO `" + user + "`@`%`; FLUSH PRIVILEGES;"
if !sql.unsafeQueryShell(query) {
return errProvisionInvalidGrant
}
// create the database and user!
if !sql.Query("CREATE DATABASE `"+name+"`; CREATE USER ?@`%` IDENTIFIED BY ?; GRANT ALL PRIVILEGES ON `"+name+"`.* TO ?@`%`; FLUSH PRIVILEGES;", user, password, user) {
return errors.New("SQLProvision: Failed to create user")
}
// and done!
return nil
}
var errSQLPurgeUser = errors.New("unable to delete user")
var errCreateSuperuserGrant = errors.New("CreateSuperUser: Grant failed")
func (sql *SQL) CreateSuperuser(user, password string, allowExisting bool) error {
// NOTE(twiesing): This function unsafely uses the shell directly to create a superuser.
// This is for two reasons:
// (1) this is used during bootstraping
// (2) The underlying driver doesn't support "GRANT ALL PRIVILEGES"
// See also [sql.Provision].
if !sqle.IsSafeDatabaseSingleQuote(user) || !sqle.IsSafeDatabaseSingleQuote(password) {
return errProvisionInvalidDatabaseParams
}
if err := sql.unsafeWaitShell(); err != nil {
return err
}
var IfNotExists string
if allowExisting {
IfNotExists = "IF NOT EXISTS"
}
query := "CREATE USER " + IfNotExists + " '" + user + "'@'%' IDENTIFIED BY '" + password + "';" +
"GRANT ALL PRIVILEGES ON *.* TO '" + user + "'@'%' WITH GRANT OPTION; FLUSH PRIVILEGES;"
if !sql.unsafeQueryShell(query) {
return errCreateSuperuserGrant
}
return nil
}
// SQLPurgeUser deletes the specified user from the database
func (sql *SQL) PurgeUser(user string) error {
if !sql.Query("DROP USER IF EXISTS ?@`%`; FLUSH PRIVILEGES; ", user) {
return errSQLPurgeUser
}
return nil
return sql.Query("DROP USER IF EXISTS ?@`%`; FLUSH PRIVILEGES; ", user)
}
var errSQLPurgeDB = errors.New("unable to drop database")
var errSQLPurgeDB = errors.New("unable to drop database: unsafe database name")
// SQLPurgeDatabase deletes the specified db from the database
func (sql *SQL) PurgeDatabase(db string) error {
if !sqle.IsSafeDatabaseName(db) {
if !sqle.IsSafeDatabaseLiteral(db) {
return errSQLPurgeDB
}
if !sql.Query("DROP DATABASE IF EXISTS `" + db + "`") {
return errSQLPurgeDB
}
return nil
return sql.Query("DROP DATABASE IF EXISTS `" + db + "`")
}

View file

@ -4,57 +4,82 @@ import (
"errors"
"fmt"
"github.com/FAU-CDI/wisski-distillery/internal/bookkeeping"
"github.com/FAU-CDI/wisski-distillery/internal/models"
"github.com/FAU-CDI/wisski-distillery/pkg/logging"
"github.com/FAU-CDI/wisski-distillery/pkg/sqle"
"github.com/FAU-CDI/wisski-distillery/pkg/wait"
"github.com/tkw1536/goprogram/stream"
)
// Shell runs a mysql shell with the provided databases.
//
// NOTE(twiesing): This command should not be used to connect to the database or execute queries except in known situations.
func (sql *SQL) Shell(io stream.IOStream, argv ...string) (int, error) {
return sql.Stack(sql.Environment).Exec(io, "sql", "mysql", argv...)
}
// unsafeWaitShell waits for a connection via the database shell to succeed
func (sql *SQL) unsafeWaitShell() error {
n := stream.FromNil()
return wait.Wait(func() bool {
code, err := sql.Shell(n, "-e", "select 1;")
// log.Printf("[unsafeWaitShell] %d %s\n", code, err) // debug
return err == nil && code == 0
}, sql.PollInterval, sql.PollContext)
}
// unsafeQuery shell executes a raw database query.
func (sql *SQL) unsafeQueryShell(query string) bool {
code, err := sql.Shell(stream.FromNil(), "-e", query)
return err == nil && code == 0
}
var errSQLUnableToCreateUser = errors.New("unable to create administrative user")
var errSQLUnsafeDatabaseName = errors.New("bookkeeping database has an unsafe name")
var errSQLUnableToCreate = errors.New("unable to create bookkeeping database")
var errSQLUnsafeDatabaseName = errors.New("distillery database has an unsafe name")
// Update initializes or updates the SQL database.
func (sql *SQL) Update(io stream.IOStream) error {
if err := sql.WaitShell(); err != nil {
return err
}
// create the admin user
logging.LogMessage(io, "Creating administrative user")
// unsafely create the admin user!
{
username := sql.Config.MysqlAdminUser
password := sql.Config.MysqlAdminPassword
if !sql.Query("CREATE USER IF NOT EXISTS ?@'%' IDENTIFIED BY ?; GRANT ALL PRIVILEGES ON *.* TO ?@`%` WITH GRANT OPTION; FLUSH PRIVILEGES;", username, password, username) {
return errSQLUnableToCreateUser
if err := sql.unsafeWaitShell(); err != nil {
return err
}
logging.LogMessage(io, "Creating administrative user")
{
username := sql.Config.MysqlAdminUser
password := sql.Config.MysqlAdminPassword
if err := sql.CreateSuperuser(username, password, true); err != nil {
return errSQLUnableToCreateUser
}
}
}
// create the admin user
logging.LogMessage(io, "Creating sql database")
{
if !sqle.IsSafeDatabaseName(sql.Config.DistilleryBookkeepingDatabase) {
if !sqle.IsSafeDatabaseLiteral(sql.Config.DistilleryDatabase) {
return errSQLUnsafeDatabaseName
}
createDBSQL := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`;", sql.Config.DistilleryBookkeepingDatabase)
if !sql.Query(createDBSQL) {
return errSQLUnableToCreate
createDBSQL := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`;", sql.Config.DistilleryDatabase)
if err := sql.Query(createDBSQL); err != nil {
return err
}
}
// wait for the database to come up
logging.LogMessage(io, "Waiting for database update to be complete")
sql.Wait()
sql.WaitQueryTable()
// open the database
logging.LogMessage(io, "Migrating bookkeeping table")
logging.LogMessage(io, "Migrating instances table")
{
db, err := sql.OpenBookkeeping(false)
db, err := sql.QueryTable(false, models.InstanceTable)
if err != nil {
return fmt.Errorf("unable to access bookkeeping table: %s", err)
}
if err := db.AutoMigrate(&bookkeeping.Instance{}); err != nil {
if err := db.AutoMigrate(&models.Instance{}); err != nil {
return fmt.Errorf("unable to migrate bookkeeping table: %s", err)
}
}