soju/db.go
Simon Ser 998546cdc3
Introduce User.Created
For Network and Channel, the database only needed to define one Store
operation to create/update a record. However since User is missing an ID
we couldn't have a single StoreUser function like other types. We had
CreateUser and UpdatePassword. As new User fields get added (e.g. the
upcoming Admin flag) this isn't sustainable.

We could have CreateUser and UpdateUser, but this wouldn't be consistent
with other types. Instead, introduce User.Created which indicates
whether the record is already stored in the DB. This can be used in a
new StoreUser function to decide whether we need to UPDATE or INSERT
without relying on SQL constraints and INSERT OR UPDATE.

The ListUsers and GetUser functions set User.Created to true.
2020-06-08 11:59:03 +02:00

433 lines
10 KiB
Go

package soju
import (
"database/sql"
"fmt"
"strings"
"sync"
_ "github.com/mattn/go-sqlite3"
)
type User struct {
Created bool
Username string
Password string // hashed
}
type SASL struct {
Mechanism string
Plain struct {
Username string
Password string
}
// TLS client certificate authentication.
External struct {
// X.509 certificate in DER form.
CertBlob []byte
// PKCS#8 private key in DER form.
PrivKeyBlob []byte
}
}
type Network struct {
ID int64
Name string
Addr string
Nick string
Username string
Realname string
Pass string
ConnectCommands []string
SASL SASL
}
func (net *Network) GetName() string {
if net.Name != "" {
return net.Name
}
return net.Addr
}
type Channel struct {
ID int64
Name string
Key string
Detached bool
}
const schema = `
CREATE TABLE User (
username VARCHAR(255) PRIMARY KEY,
password VARCHAR(255) NOT NULL
);
CREATE TABLE Network (
id INTEGER PRIMARY KEY,
name VARCHAR(255),
user VARCHAR(255) NOT NULL,
addr VARCHAR(255) NOT NULL,
nick VARCHAR(255) NOT NULL,
username VARCHAR(255),
realname VARCHAR(255),
pass VARCHAR(255),
connect_commands VARCHAR(1023),
sasl_mechanism VARCHAR(255),
sasl_plain_username VARCHAR(255),
sasl_plain_password VARCHAR(255),
sasl_external_cert BLOB DEFAULT NULL,
sasl_external_key BLOB DEFAULT NULL,
FOREIGN KEY(user) REFERENCES User(username),
UNIQUE(user, addr, nick)
);
CREATE TABLE Channel (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
name VARCHAR(255) NOT NULL,
key VARCHAR(255),
detached INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, name)
);
`
var migrations = []string{
"", // migration #0 is reserved for schema initialization
"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
}
type DB struct {
lock sync.RWMutex
db *sql.DB
}
func OpenSQLDB(driver, source string) (*DB, error) {
sqlDB, err := sql.Open(driver, source)
if err != nil {
return nil, err
}
db := &DB{db: sqlDB}
if err := db.upgrade(); err != nil {
return nil, err
}
return db, nil
}
func (db *DB) Close() error {
db.lock.Lock()
defer db.lock.Unlock()
return db.Close()
}
func (db *DB) upgrade() error {
db.lock.Lock()
defer db.lock.Unlock()
var version int
if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
return fmt.Errorf("failed to query schema version: %v", err)
}
if version == len(migrations) {
return nil
} else if version > len(migrations) {
return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
}
tx, err := db.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if version == 0 {
if _, err := tx.Exec(schema); err != nil {
return fmt.Errorf("failed to initialize schema: %v", err)
}
} else {
for i := version; i < len(migrations); i++ {
if _, err := tx.Exec(migrations[i]); err != nil {
return fmt.Errorf("failed to execute migration #%v: %v", i, err)
}
}
}
// For some reason prepared statements don't work here
_, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
if err != nil {
return fmt.Errorf("failed to bump schema version: %v", err)
}
return tx.Commit()
}
func fromStringPtr(ptr *string) string {
if ptr == nil {
return ""
}
return *ptr
}
func toStringPtr(s string) *string {
if s == "" {
return nil
}
return &s
}
func (db *DB) ListUsers() ([]User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
rows, err := db.db.Query("SELECT username, password FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var users []User
for rows.Next() {
var user User
var password *string
if err := rows.Scan(&user.Username, &password); err != nil {
return nil, err
}
user.Created = true
user.Password = fromStringPtr(password)
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return users, nil
}
func (db *DB) GetUser(username string) (*User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
user := &User{Created: true, Username: username}
var password *string
row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
if err := row.Scan(&password); err != nil {
return nil, err
}
user.Password = fromStringPtr(password)
return user, nil
}
func (db *DB) StoreUser(user *User) error {
db.lock.Lock()
defer db.lock.Unlock()
password := toStringPtr(user.Password)
var err error
if user.Created {
_, err = db.db.Exec("UPDATE User SET password = ? WHERE username = ?",
password, user.Username)
} else {
_, err = db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)",
user.Username, password)
if err == nil {
user.Created = true
}
}
return err
}
func (db *DB) ListNetworks(username string) ([]Network, error) {
db.lock.RLock()
defer db.lock.RUnlock()
rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
sasl_external_cert, sasl_external_key
FROM Network
WHERE user = ?`,
username)
if err != nil {
return nil, err
}
defer rows.Close()
var networks []Network
for rows.Next() {
var net Network
var name, username, realname, pass, connectCommands *string
var saslMechanism, saslPlainUsername, saslPlainPassword *string
err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
if err != nil {
return nil, err
}
net.Name = fromStringPtr(name)
net.Username = fromStringPtr(username)
net.Realname = fromStringPtr(realname)
net.Pass = fromStringPtr(pass)
if connectCommands != nil {
net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
}
net.SASL.Mechanism = fromStringPtr(saslMechanism)
net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
networks = append(networks, net)
}
if err := rows.Err(); err != nil {
return nil, err
}
return networks, nil
}
func (db *DB) StoreNetwork(username string, network *Network) error {
db.lock.Lock()
defer db.lock.Unlock()
netName := toStringPtr(network.Name)
netUsername := toStringPtr(network.Username)
realname := toStringPtr(network.Realname)
pass := toStringPtr(network.Pass)
connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
var saslMechanism, saslPlainUsername, saslPlainPassword *string
if network.SASL.Mechanism != "" {
saslMechanism = &network.SASL.Mechanism
switch network.SASL.Mechanism {
case "PLAIN":
saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
network.SASL.External.CertBlob = nil
network.SASL.External.PrivKeyBlob = nil
case "EXTERNAL":
// keep saslPlain* nil
default:
return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
}
}
var err error
if network.ID != 0 {
_, err = db.db.Exec(`UPDATE Network
SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
sasl_external_cert = ?, sasl_external_key = ?
WHERE id = ?`,
netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword,
network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
network.ID)
} else {
var res sql.Result
res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
sasl_plain_password, sasl_external_cert, sasl_external_key)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
network.SASL.External.PrivKeyBlob)
if err != nil {
return err
}
network.ID, err = res.LastInsertId()
}
return err
}
func (db *DB) DeleteNetwork(id int64) error {
db.lock.Lock()
defer db.lock.Unlock()
tx, err := db.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
if err != nil {
return err
}
return tx.Commit()
}
func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
db.lock.RLock()
defer db.lock.RUnlock()
rows, err := db.db.Query(`SELECT id, name, key, detached
FROM Channel
WHERE network = ?`, networkID)
if err != nil {
return nil, err
}
defer rows.Close()
var channels []Channel
for rows.Next() {
var ch Channel
var key *string
if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
return nil, err
}
ch.Key = fromStringPtr(key)
channels = append(channels, ch)
}
if err := rows.Err(); err != nil {
return nil, err
}
return channels, nil
}
func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
db.lock.Lock()
defer db.lock.Unlock()
key := toStringPtr(ch.Key)
var err error
if ch.ID != 0 {
_, err = db.db.Exec(`UPDATE Channel
SET network = ?, name = ?, key = ?, detached = ?
WHERE id = ?`,
networkID, ch.Name, key, ch.Detached, ch.ID)
} else {
var res sql.Result
res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
VALUES (?, ?, ?, ?)`,
networkID, ch.Name, key, ch.Detached)
if err != nil {
return err
}
ch.ID, err = res.LastInsertId()
}
return err
}
func (db *DB) DeleteChannel(networkID int64, name string) error {
db.lock.Lock()
defer db.lock.Unlock()
_, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
return err
}