Switch DB API to user IDs

This commit changes the Network schema to use user IDs instead of
usernames. While at it, a new UNIQUE(user, name) constraint ensures
there is no conflict with custom network names.

Closes: https://todo.sr.ht/~emersion/soju/86
References: https://todo.sr.ht/~emersion/soju/29
This commit is contained in:
Simon Ser 2020-10-24 15:14:23 +02:00
parent b3e136e3b7
commit fa16337d97
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
5 changed files with 54 additions and 22 deletions

View file

@ -114,8 +114,9 @@ func main() {
if err := db.StoreUser(u); err != nil { if err := db.StoreUser(u); err != nil {
log.Fatalf("failed to store user %q: %v", username, err) log.Fatalf("failed to store user %q: %v", username, err)
} }
userID := u.ID
l, err := db.ListNetworks(username) l, err := db.ListNetworks(userID)
if err != nil { if err != nil {
log.Fatalf("failed to list networks for user %q: %v", username, err) log.Fatalf("failed to list networks for user %q: %v", username, err)
} }
@ -181,7 +182,7 @@ func main() {
n.Realname = netRealname n.Realname = netRealname
n.Pass = pass n.Pass = pass
if err := db.StoreNetwork(username, n); err != nil { if err := db.StoreNetwork(userID, n); err != nil {
logger.Fatalf("failed to store network: %v", err) logger.Fatalf("failed to store network: %v", err)
} }

53
db.go
View file

@ -70,7 +70,7 @@ CREATE TABLE User (
CREATE TABLE Network ( CREATE TABLE Network (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name VARCHAR(255), name VARCHAR(255),
user VARCHAR(255) NOT NULL, user INTEGER NOT NULL,
addr VARCHAR(255) NOT NULL, addr VARCHAR(255) NOT NULL,
nick VARCHAR(255) NOT NULL, nick VARCHAR(255) NOT NULL,
username VARCHAR(255), username VARCHAR(255),
@ -82,8 +82,9 @@ CREATE TABLE Network (
sasl_plain_password VARCHAR(255), sasl_plain_password VARCHAR(255),
sasl_external_cert BLOB DEFAULT NULL, sasl_external_cert BLOB DEFAULT NULL,
sasl_external_key BLOB DEFAULT NULL, sasl_external_key BLOB DEFAULT NULL,
FOREIGN KEY(user) REFERENCES User(username), FOREIGN KEY(user) REFERENCES User(id),
UNIQUE(user, addr, nick) UNIQUE(user, addr, nick),
UNIQUE(user, name)
); );
CREATE TABLE Channel ( CREATE TABLE Channel (
@ -115,6 +116,36 @@ var migrations = []string{
DROP TABLE User; DROP TABLE User;
ALTER TABLE UserNew RENAME TO User; ALTER TABLE UserNew RENAME TO User;
`, `,
`
CREATE TABLE NetworkNew (
id INTEGER PRIMARY KEY,
name VARCHAR(255),
user INTEGER NOT NULL,
addr VARCHAR(255) NOT NULL,
nick VARCHAR(255) NOT NULL,
username VARCHAR(255),
realname VARCHAR(255),
pass VARCHAR(255),
connect_commands VARCHAR(1023),
sasl_mechanism VARCHAR(255),
sasl_plain_username VARCHAR(255),
sasl_plain_password VARCHAR(255),
sasl_external_cert BLOB DEFAULT NULL,
sasl_external_key BLOB DEFAULT NULL,
FOREIGN KEY(user) REFERENCES User(id),
UNIQUE(user, addr, nick),
UNIQUE(user, name)
);
INSERT INTO NetworkNew
SELECT Network.id, name, User.id as user, addr, nick,
Network.username, realname, pass, connect_commands,
sasl_mechanism, sasl_plain_username, sasl_plain_password,
sasl_external_cert, sasl_external_key
FROM Network
JOIN User ON Network.user = User.username;
DROP TABLE Network;
ALTER TABLE NetworkNew RENAME TO Network;
`,
} }
type DB struct { type DB struct {
@ -263,7 +294,7 @@ func (db *DB) StoreUser(user *User) error {
return err return err
} }
func (db *DB) DeleteUser(username string) error { func (db *DB) DeleteUser(id int64) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
@ -279,17 +310,17 @@ func (db *DB) DeleteUser(username string) error {
FROM Channel FROM Channel
JOIN Network ON Channel.network = Network.id JOIN Network ON Channel.network = Network.id
WHERE Network.user = ? WHERE Network.user = ?
)`, username) )`, id)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec("DELETE FROM Network WHERE user = ?", username) _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec("DELETE FROM User WHERE username = ?", username) _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
if err != nil { if err != nil {
return err return err
} }
@ -297,7 +328,7 @@ func (db *DB) DeleteUser(username string) error {
return tx.Commit() return tx.Commit()
} }
func (db *DB) ListNetworks(username string) ([]Network, error) { func (db *DB) ListNetworks(userID int64) ([]Network, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
@ -306,7 +337,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
sasl_external_cert, sasl_external_key sasl_external_cert, sasl_external_key
FROM Network FROM Network
WHERE user = ?`, WHERE user = ?`,
username) userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -342,7 +373,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
return networks, nil return networks, nil
} }
func (db *DB) StoreNetwork(username string, network *Network) error { func (db *DB) StoreNetwork(userID int64, network *Network) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
@ -385,7 +416,7 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
sasl_plain_password, sasl_external_cert, sasl_external_key) sasl_plain_password, sasl_external_cert, sasl_external_key)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
network.SASL.External.PrivKeyBlob) network.SASL.External.PrivKeyBlob)
if err != nil { if err != nil {

View file

@ -1016,7 +1016,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return return
} }
n.Nick = nick n.Nick = nick
err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network) err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network)
}) })
if err != nil { if err != nil {
return err return err
@ -1697,7 +1697,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
n.SASL.Mechanism = "PLAIN" n.SASL.Mechanism = "PLAIN"
n.SASL.Plain.Username = username n.SASL.Plain.Username = username
n.SASL.Plain.Password = password n.SASL.Plain.Password = password
if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil { if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to save NickServ credentials: %v", err) dc.logger.Printf("failed to save NickServ credentials: %v", err)
} }
} }

View file

@ -548,7 +548,7 @@ func handleServiceCertfpGenerate(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = privKeyBytes net.SASL.External.PrivKeyBlob = privKeyBytes
net.SASL.Mechanism = "EXTERNAL" net.SASL.Mechanism = "EXTERNAL"
if err := dc.srv.db.StoreNetwork(net.Username, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -593,7 +593,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
net.SASL.Plain.Password = params[2] net.SASL.Plain.Password = params[2]
net.SASL.Mechanism = "PLAIN" net.SASL.Mechanism = "PLAIN"
if err := dc.srv.db.StoreNetwork(net.Username, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -617,7 +617,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = nil net.SASL.External.PrivKeyBlob = nil
net.SASL.Mechanism = "" net.SASL.Mechanism = ""
if err := dc.srv.db.StoreNetwork(dc.user.Username, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -689,7 +689,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
u.stop() u.stop()
if err := dc.srv.db.DeleteUser(username); err != nil { if err := dc.srv.db.DeleteUser(dc.user.ID); err != nil {
return fmt.Errorf("failed to delete user: %v", err) return fmt.Errorf("failed to delete user: %v", err)
} }

View file

@ -314,7 +314,7 @@ func (u *user) getNetworkByID(id int64) *network {
func (u *user) run() { func (u *user) run() {
defer close(u.done) defer close(u.done)
networks, err := u.srv.db.ListNetworks(u.Username) networks, err := u.srv.db.ListNetworks(u.ID)
if err != nil { if err != nil {
u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err) u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
return return
@ -508,7 +508,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
} }
network := newNetwork(u, record, nil) network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(u.Username, &network.Network) err := u.srv.db.StoreNetwork(u.ID, &network.Network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -528,7 +528,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
panic("tried updating a non-existing network") panic("tried updating a non-existing network")
} }
if err := u.srv.db.StoreNetwork(u.Username, record); err != nil { if err := u.srv.db.StoreNetwork(u.ID, record); err != nil {
return nil, err return nil, err
} }