From 6e066636159e508dd15c5dcbba10ad9754b25a59 Mon Sep 17 00:00:00 2001 From: Hubert Hirtz Date: Fri, 8 Oct 2021 19:15:56 +0200 Subject: [PATCH] PostgreSQL support --- cmd/soju/main.go | 2 +- cmd/sojuctl/main.go | 2 +- contrib/znc-import.go | 2 +- db.go | 11 ++ db_postgres.go | 402 ++++++++++++++++++++++++++++++++++++++++++ db_sqlite.go | 4 +- doc/soju.1.scd | 14 +- go.mod | 1 + go.sum | 2 + server_test.go | 2 +- 10 files changed, 434 insertions(+), 8 deletions(-) create mode 100644 db_postgres.go diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 67dc444..952ea78 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -61,7 +61,7 @@ func main() { cfg.Listen = []string{":6697"} } - db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource) + db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) if err != nil { log.Fatalf("failed to open database: %v", err) } diff --git a/cmd/sojuctl/main.go b/cmd/sojuctl/main.go index 48720b2..d19ccfb 100644 --- a/cmd/sojuctl/main.go +++ b/cmd/sojuctl/main.go @@ -43,7 +43,7 @@ func main() { cfg = config.Defaults() } - db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource) + db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) if err != nil { log.Fatalf("failed to open database: %v", err) } diff --git a/contrib/znc-import.go b/contrib/znc-import.go index fb8b16a..8dd02ed 100644 --- a/contrib/znc-import.go +++ b/contrib/znc-import.go @@ -61,7 +61,7 @@ func main() { cfg = config.Defaults() } - db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource) + db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) if err != nil { log.Fatalf("failed to open database: %v", err) } diff --git a/db.go b/db.go index 160645e..703d993 100644 --- a/db.go +++ b/db.go @@ -27,6 +27,17 @@ type Database interface { StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error } +func OpenDB(driver, source string) (Database, error) { + switch driver { + case "sqlite3": + return OpenSqliteDB(source) + case "postgres": + return OpenPostgresDB(source) + default: + return nil, fmt.Errorf("unsupported database driver: %q", driver) + } +} + type DatabaseStats struct { Users int64 Networks int64 diff --git a/db_postgres.go b/db_postgres.go new file mode 100644 index 0000000..aafbd95 --- /dev/null +++ b/db_postgres.go @@ -0,0 +1,402 @@ +package soju + +import ( + "database/sql" + "errors" + "fmt" + "math" + "strings" + "time" + + _ "github.com/lib/pq" +) + +const postgresConfigSchema = ` +CREATE TABLE IF NOT EXISTS "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); +` + +const postgresSchema = ` +CREATE TABLE "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TABLE "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA DEFAULT NULL, + sasl_external_key BYTEA DEFAULT NULL, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); + +CREATE TABLE "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); + +CREATE TABLE "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +` + +var postgresMigrations = []string{ + "", // migration #0 is reserved for schema initialization +} + +type PostgresDB struct { + db *sql.DB +} + +func OpenPostgresDB(source string) (Database, error) { + sqlPostgresDB, err := sql.Open("postgres", source) + if err != nil { + return nil, err + } + + db := &PostgresDB{db: sqlPostgresDB} + if err := db.upgrade(); err != nil { + sqlPostgresDB.Close() + return nil, err + } + + return db, nil +} + +func (db *PostgresDB) upgrade() error { + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec(postgresConfigSchema); err != nil { + return fmt.Errorf("failed to create Config table: %s", err) + } + + var version int + err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query schema version: %s", err) + } + + if version == len(postgresMigrations) { + return nil + } + if version > len(postgresMigrations) { + return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version) + } + + if version == 0 { + if _, err := tx.Exec(postgresSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %s", err) + } + } else { + for i := version; i < len(postgresMigrations); i++ { + if _, err := tx.Exec(postgresMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1) + ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations)) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *PostgresDB) Close() error { + return db.db.Close() +} + +func (db *PostgresDB) Stats() (*DatabaseStats, error) { + var stats DatabaseStats + row := db.db.QueryRow(`SELECT + (SELECT COUNT(*) FROM "User") AS users, + (SELECT COUNT(*) FROM "Network") AS networks, + (SELECT COUNT(*) FROM "Channel") AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func (db *PostgresDB) ListUsers() ([]User, error) { + rows, err := db.db.Query(`SELECT id, username, password, admin, realname FROM "User"`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *PostgresDB) GetUser(username string) (*User, error) { + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRow( + `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *PostgresDB) StoreUser(user *User) error { + password := toNullString(user.Password) + realname := toNullString(user.Realname) + err := db.db.QueryRow(` + INSERT INTO "User" (username, password, admin, realname) + VALUES ($1, $2, $3, $4) + ON CONFLICT (username) + DO UPDATE SET password = $2, admin = $3, realname = $4 + RETURNING id`, + user.Username, password, user.Admin, realname).Scan(&user.ID) + return err +} + +func (db *PostgresDB) DeleteUser(id int64) error { + _, err := db.db.Exec(`DELETE FROM "User" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) { + rows, err := db.db.Query(` + SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, + sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled + FROM "Network" + WHERE "user" = $1`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { + netName := toNullString(network.Name) + netUsername := toNullString(network.Username) + realname := toNullString(network.Realname) + pass := toNullString(network.Pass) + connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + err := db.db.QueryRow(` + INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, + sasl_external_key, enabled) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ON CONFLICT ("user", name) + DO UPDATE SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, + connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, + sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, + enabled = $14 + RETURNING id`, + userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID) + return err +} + +func (db *PostgresDB) DeleteNetwork(id int64) error { + _, err := db.db.Exec(`DELETE FROM "Network" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) { + rows, err := db.db.Query(` + SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, + detach_on + FROM "Channel" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { + key := toNullString(ch.Key) + detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds())) + err := db.db.QueryRow(` + INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, + detach_after, detach_on) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (network, name) + DO UPDATE SET network = $1, name = $2, key = $3, detached = $4, detached_internal_msgid = $5, + relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9 + RETURNING id`, + networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID) + return err +} + +func (db *PostgresDB) DeleteChannel(id int64) error { + _, err := db.db.Exec(`DELETE FROM "Channel" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { + rows, err := db.db.Query(` + SELECT id, target, client, internal_msgid + FROM "DeliveryReceipt" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { + stmt, err := db.db.Prepare(` + INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) + VALUES ($1, $2, $3, $4) + ON CONFLICT (network, target, client) + DO UPDATE SET internal_msgid = $4 + RETURNING id`) + if err != nil { + return err + } + defer stmt.Close() + + // No need for a transaction since all changes are atomic and don't break data coherence. + for i := range receipts { + rcpt := &receipts[i] + err := stmt.QueryRow(networkID, rcpt.Target, client, rcpt.InternalMsgID).Scan(&rcpt.ID) + if err != nil { + return err + } + } + return nil +} diff --git a/db_sqlite.go b/db_sqlite.go index 8b71102..a6dc8eb 100644 --- a/db_sqlite.go +++ b/db_sqlite.go @@ -142,8 +142,8 @@ type SqliteDB struct { db *sql.DB } -func OpenSqliteDB(driver, source string) (Database, error) { - sqlSqliteDB, err := sql.Open(driver, source) +func OpenSqliteDB(source string) (Database, error) { + sqlSqliteDB, err := sql.Open("sqlite3", source) if err != nil { return nil, err } diff --git a/doc/soju.1.scd b/doc/soju.1.scd index 2960c69..d244477 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -106,8 +106,18 @@ The following directives are supported: *tls* Enable TLS support. The certificate and the key files must be PEM-encoded. -*db* sqlite3 - Set the SQLite database path (default: "soju.db" in the current directory). +*db* + Set the database location for user, network and channel storage. By default, + a _sqlite3_ database is opened in "./soju.db". + + Supported drivers: + + - _sqlite3_ expects _source_ to be a path to the SQLite file + - _postgres_ expects _source_ to be a space-separated list of _key=value_ + parameters, e.g. _db postgres "host=/run/postgresql dbname=soju"_. Note + that _sslmode_ defaults to _require_. For more information on connection + strings, see: + . *log* fs Path to the bouncer logs root directory, or empty to disable logging. By diff --git a/go.mod b/go.mod index 3c3072e..6634d67 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 github.com/klauspost/compress v1.13.5 // indirect + github.com/lib/pq v1.10.3 github.com/mattn/go-sqlite3 v1.14.8 github.com/pires/go-proxyproto v0.6.1 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 diff --git a/go.sum b/go.sum index dfce806..0497d01 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/klauspost/compress v1.13.5 h1:9O69jUPDcsT9fEm74W92rZL9FQY7rCdaXVneq+y github.com/klauspost/compress v1.13.5/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= +github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= diff --git a/server_test.go b/server_test.go index c087d0e..4de7515 100644 --- a/server_test.go +++ b/server_test.go @@ -16,7 +16,7 @@ const ( ) func createTempDB(t *testing.T) Database { - db, err := OpenSqliteDB("sqlite3", ":memory:") + db, err := OpenDB("sqlite3", ":memory:") if err != nil { t.Fatalf("failed to create temporary SQLite database: %v", err) }