Add context args to Database interface

This is a mecanical change, which just lifts up the context.TODO()
calls from inside the DB implementations to the callers.

Future work involves properly wiring up the contexts when it makes
sense.
This commit is contained in:
Simon Ser 2021-10-18 19:15:15 +02:00
parent 4be6c4b19c
commit 9ec1f1a5b0
11 changed files with 110 additions and 101 deletions

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -75,7 +76,7 @@ func main() {
Password: string(hashed), Password: string(hashed),
Admin: *admin, Admin: *admin,
} }
if err := db.StoreUser(&user); err != nil { if err := db.StoreUser(context.TODO(), &user); err != nil {
log.Fatalf("failed to create user: %v", err) log.Fatalf("failed to create user: %v", err)
} }
case "change-password": case "change-password":
@ -85,7 +86,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
user, err := db.GetUser(username) user, err := db.GetUser(context.TODO(), username)
if err != nil { if err != nil {
log.Fatalf("failed to get user: %v", err) log.Fatalf("failed to get user: %v", err)
} }
@ -101,7 +102,7 @@ func main() {
} }
user.Password = string(hashed) user.Password = string(hashed)
if err := db.StoreUser(user); err != nil { if err := db.StoreUser(context.TODO(), user); err != nil {
log.Fatalf("failed to update password: %v", err) log.Fatalf("failed to update password: %v", err)
} }
default: default:

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -79,7 +80,7 @@ func main() {
log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
} }
l, err := db.ListUsers() l, err := db.ListUsers(context.TODO())
if err != nil { if err != nil {
log.Fatalf("failed to list users in DB: %v", err) log.Fatalf("failed to list users in DB: %v", err)
} }
@ -111,12 +112,12 @@ func main() {
u.Admin = section.Values.Get("Admin") == "true" u.Admin = section.Values.Get("Admin") == "true"
if err := db.StoreUser(u); err != nil { if err := db.StoreUser(context.TODO(), u); err != nil {
log.Fatalf("failed to store user %q: %v", username, err) log.Fatalf("failed to store user %q: %v", username, err)
} }
userID := u.ID userID := u.ID
l, err := db.ListNetworks(userID) l, err := db.ListNetworks(context.TODO(), userID)
if err != nil { if err != nil {
log.Fatalf("failed to list networks for user %q: %v", username, err) log.Fatalf("failed to list networks for user %q: %v", username, err)
} }
@ -183,11 +184,11 @@ func main() {
n.Pass = pass n.Pass = pass
n.Enabled = section.Values.Get("IRCConnectEnabled") != "false" n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
if err := db.StoreNetwork(userID, n); err != nil { if err := db.StoreNetwork(context.TODO(), userID, n); err != nil {
logger.Fatalf("failed to store network: %v", err) logger.Fatalf("failed to store network: %v", err)
} }
l, err := db.ListChannels(n.ID) l, err := db.ListChannels(context.TODO(), n.ID)
if err != nil { if err != nil {
logger.Fatalf("failed to list channels: %v", err) logger.Fatalf("failed to list channels: %v", err)
} }
@ -217,7 +218,7 @@ func main() {
ch.Key = section.Values.Get("Key") ch.Key = section.Values.Get("Key")
ch.Detached = section.Values.Get("Detached") == "true" ch.Detached = section.Values.Get("Detached") == "true"
if err := db.StoreChannel(n.ID, ch); err != nil { if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil {
logger.Printf("channel %q: failed to store channel: %v", chName, err) logger.Printf("channel %q: failed to store channel: %v", chName, err)
} }
}) })

27
db.go
View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"fmt" "fmt"
"net/url" "net/url"
"strings" "strings"
@ -9,22 +10,22 @@ import (
type Database interface { type Database interface {
Close() error Close() error
Stats() (*DatabaseStats, error) Stats(ctx context.Context) (*DatabaseStats, error)
ListUsers() ([]User, error) ListUsers(ctx context.Context) ([]User, error)
GetUser(username string) (*User, error) GetUser(ctx context.Context, username string) (*User, error)
StoreUser(user *User) error StoreUser(ctx context.Context, user *User) error
DeleteUser(id int64) error DeleteUser(ctx context.Context, id int64) error
ListNetworks(userID int64) ([]Network, error) ListNetworks(ctx context.Context, userID int64) ([]Network, error)
StoreNetwork(userID int64, network *Network) error StoreNetwork(ctx context.Context, userID int64, network *Network) error
DeleteNetwork(id int64) error DeleteNetwork(ctx context.Context, id int64) error
ListChannels(networkID int64) ([]Channel, error) ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
StoreChannel(networKID int64, ch *Channel) error StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
DeleteChannel(id int64) error DeleteChannel(ctx context.Context, id int64) error
ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
} }
func OpenDB(driver, source string) (Database, error) { func OpenDB(driver, source string) (Database, error) {

View file

@ -147,8 +147,8 @@ func (db *PostgresDB) Close() error {
return db.db.Close() return db.db.Close()
} }
func (db *PostgresDB) Stats() (*DatabaseStats, error) { func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
var stats DatabaseStats var stats DatabaseStats
@ -163,8 +163,8 @@ func (db *PostgresDB) Stats() (*DatabaseStats, error) {
return &stats, nil return &stats, nil
} }
func (db *PostgresDB) ListUsers() ([]User, error) { func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, rows, err := db.db.QueryContext(ctx,
@ -192,8 +192,8 @@ func (db *PostgresDB) ListUsers() ([]User, error) {
return users, nil return users, nil
} }
func (db *PostgresDB) GetUser(username string) (*User, error) { func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
user := &User{Username: username} user := &User{Username: username}
@ -210,8 +210,8 @@ func (db *PostgresDB) GetUser(username string) (*User, error) {
return user, nil return user, nil
} }
func (db *PostgresDB) StoreUser(user *User) error { func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
password := toNullString(user.Password) password := toNullString(user.Password)
@ -234,16 +234,16 @@ func (db *PostgresDB) StoreUser(user *User) error {
return err return err
} }
func (db *PostgresDB) DeleteUser(id int64) error { func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
return err return err
} }
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) { func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
@ -286,8 +286,8 @@ func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
return networks, nil return networks, nil
} }
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error { func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
netName := toNullString(network.Name) netName := toNullString(network.Name)
@ -338,16 +338,16 @@ func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
return err return err
} }
func (db *PostgresDB) DeleteNetwork(id int64) error { func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
return err return err
} }
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) { func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
@ -380,8 +380,8 @@ func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil return channels, nil
} }
func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error { func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
key := toNullString(ch.Key) key := toNullString(ch.Key)
@ -408,16 +408,16 @@ func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
return err return err
} }
func (db *PostgresDB) DeleteChannel(id int64) error { func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
_, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
return err return err
} }
func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
@ -444,8 +444,8 @@ func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt,
return receipts, nil return receipts, nil
} }
func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
ctx, cancel := context.WithTimeout(context.TODO(), postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()
tx, err := db.db.Begin() tx, err := db.db.Begin()

View file

@ -208,11 +208,11 @@ func (db *SqliteDB) upgrade() error {
return tx.Commit() return tx.Commit()
} }
func (db *SqliteDB) Stats() (*DatabaseStats, error) { func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
var stats DatabaseStats var stats DatabaseStats
@ -234,11 +234,11 @@ func toNullString(s string) sql.NullString {
} }
} }
func (db *SqliteDB) ListUsers() ([]User, error) { func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, rows, err := db.db.QueryContext(ctx,
@ -266,11 +266,11 @@ func (db *SqliteDB) ListUsers() ([]User, error) {
return users, nil return users, nil
} }
func (db *SqliteDB) GetUser(username string) (*User, error) { func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
user := &User{Username: username} user := &User{Username: username}
@ -287,11 +287,11 @@ func (db *SqliteDB) GetUser(username string) (*User, error) {
return user, nil return user, nil
} }
func (db *SqliteDB) StoreUser(user *User) error { func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
args := []interface{}{ args := []interface{}{
@ -323,11 +323,11 @@ func (db *SqliteDB) StoreUser(user *User) error {
return err return err
} }
func (db *SqliteDB) DeleteUser(id int64) error { func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
tx, err := db.db.Begin() tx, err := db.db.Begin()
@ -371,11 +371,11 @@ func (db *SqliteDB) DeleteUser(id int64) error {
return tx.Commit() return tx.Commit()
} }
func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) { func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
@ -420,11 +420,11 @@ func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
return networks, nil return networks, nil
} }
func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error { func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
@ -490,11 +490,11 @@ func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
return err return err
} }
func (db *SqliteDB) DeleteNetwork(id int64) error { func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
tx, err := db.db.Begin() tx, err := db.db.Begin()
@ -521,11 +521,11 @@ func (db *SqliteDB) DeleteNetwork(id int64) error {
return tx.Commit() return tx.Commit()
} }
func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) { func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, `SELECT rows, err := db.db.QueryContext(ctx, `SELECT
@ -558,11 +558,11 @@ func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil return channels, nil
} }
func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error { func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
args := []interface{}{ args := []interface{}{
@ -598,22 +598,22 @@ func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
return err return err
} }
func (db *SqliteDB) DeleteChannel(id int64) error { func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
_, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
return err return err
} }
func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) { func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
@ -642,11 +642,11 @@ func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, er
return receipts, nil return receipts, nil
} }
func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error { func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
ctx, cancel := context.WithTimeout(context.TODO(), sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()
tx, err := db.db.Begin() tx, err := db.db.Begin()

View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -976,7 +977,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
func (dc *downstreamConn) authenticate(username, password string) error { func (dc *downstreamConn) authenticate(username, password string) error {
username, clientName, networkName := unmarshalUsername(username) username, clientName, networkName := unmarshalUsername(username)
u, err := dc.srv.db.GetUser(username) u, err := dc.srv.db.GetUser(context.TODO(), username)
if err != nil { if err != nil {
dc.logger.Printf("failed authentication for %q: user not found: %v", username, err) dc.logger.Printf("failed authentication for %q: user not found: %v", username, err)
return errAuthFailed return errAuthFailed
@ -1377,7 +1378,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return return
} }
n.Nick = nick n.Nick = nick
err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network) err = dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network)
}) })
if err != nil { if err != nil {
return err return err
@ -1427,7 +1428,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}) })
n.Realname = storeRealname n.Realname = storeRealname
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil { if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to store network realname: %v", err) dc.logger.Printf("failed to store network realname: %v", err)
storeErr = err storeErr = err
} }
@ -1516,7 +1517,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
} }
uc.network.channels.SetValue(upstreamName, ch) uc.network.channels.SetValue(upstreamName, ch)
} }
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
} }
} }
@ -1548,7 +1549,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
} }
uc.network.channels.SetValue(upstreamName, ch) uc.network.channels.SetValue(upstreamName, ch)
} }
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
} }
} else { } else {
@ -2445,7 +2446,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
n.SASL.Mechanism = "PLAIN" n.SASL.Mechanism = "PLAIN"
n.SASL.Plain.Username = username n.SASL.Plain.Username = username
n.SASL.Plain.Password = password n.SASL.Plain.Password = password
if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil { if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &n.Network); err != nil {
dc.logger.Printf("failed to save NickServ credentials: %v", err) dc.logger.Printf("failed to save NickServ credentials: %v", err)
} }
} }

View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"mime" "mime"
@ -85,7 +86,7 @@ func (s *Server) prefix() *irc.Prefix {
} }
func (s *Server) Start() error { func (s *Server) Start() error {
users, err := s.db.ListUsers() users, err := s.db.ListUsers(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -126,7 +127,7 @@ func (s *Server) createUser(user *User) (*user, error) {
return nil, fmt.Errorf("user %q already exists", user.Username) return nil, fmt.Errorf("user %q already exists", user.Username)
} }
err := s.db.StoreUser(user) err := s.db.StoreUser(context.TODO(), user)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create user in db: %v", err) return nil, fmt.Errorf("could not create user in db: %v", err)
} }

View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"net" "net"
"testing" "testing"
@ -43,7 +44,7 @@ func createTestUser(t *testing.T, db Database) *User {
} }
record := &User{Username: testUsername, Password: string(hashed)} record := &User{Username: testUsername, Password: string(hashed)}
if err := db.StoreUser(record); err != nil { if err := db.StoreUser(context.TODO(), record); err != nil {
t.Fatalf("failed to store test user: %v", err) t.Fatalf("failed to store test user: %v", err)
} }
@ -68,7 +69,7 @@ func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Li
Nick: user.Username, Nick: user.Username,
Enabled: true, Enabled: true,
} }
if err := db.StoreNetwork(user.ID, network); err != nil { if err := db.StoreNetwork(context.TODO(), user.ID, network); err != nil {
t.Fatalf("failed to store test network: %v", err) t.Fatalf("failed to store test network: %v", err)
} }

View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
@ -657,7 +658,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = privKey net.SASL.External.PrivKeyBlob = privKey
net.SASL.Mechanism = "EXTERNAL" net.SASL.Mechanism = "EXTERNAL"
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -698,7 +699,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
net.SASL.Plain.Password = params[2] net.SASL.Plain.Password = params[2]
net.SASL.Mechanism = "PLAIN" net.SASL.Mechanism = "PLAIN"
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -722,7 +723,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = nil net.SASL.External.PrivKeyBlob = nil
net.SASL.Mechanism = "" net.SASL.Mechanism = ""
if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -860,7 +861,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
u.stop() u.stop()
if err := dc.srv.db.DeleteUser(u.ID); err != nil { if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil {
return fmt.Errorf("failed to delete user: %v", err) return fmt.Errorf("failed to delete user: %v", err)
} }
@ -1015,7 +1016,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
uc.updateChannelAutoDetach(upstreamName) uc.updateChannelAutoDetach(upstreamName)
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
return fmt.Errorf("failed to update channel: %v", err) return fmt.Errorf("failed to update channel: %v", err)
} }
@ -1024,7 +1025,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
} }
func handleServiceServerStatus(dc *downstreamConn, params []string) error { func handleServiceServerStatus(dc *downstreamConn, params []string) error {
dbStats, err := dc.user.srv.db.Stats() dbStats, err := dc.user.srv.db.Stats(context.TODO())
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"crypto" "crypto"
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
@ -1516,7 +1517,7 @@ func (uc *upstreamConn) handleDetachedMessage(ch *Channel, msg *irc.Message) {
} }
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
uc.network.attach(ch) uc.network.attach(ch)
if err := uc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil {
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
} }
} }

21
user.go
View file

@ -1,6 +1,7 @@
package soju package soju
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
@ -330,7 +331,7 @@ func (net *network) deleteChannel(name string) error {
} }
} }
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil { if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
return err return err
} }
net.channels.Delete(name) net.channels.Delete(name)
@ -367,7 +368,7 @@ func (net *network) storeClientDeliveryReceipts(clientName string) {
}) })
}) })
if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil { if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err) net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
} }
} }
@ -487,7 +488,7 @@ func (u *user) run() {
close(u.done) close(u.done)
}() }()
networks, err := u.srv.db.ListNetworks(u.ID) networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
if err != nil { if err != nil {
u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
return return
@ -495,7 +496,7 @@ func (u *user) run() {
for _, record := range networks { for _, record := range networks {
record := record record := record
channels, err := u.srv.db.ListChannels(record.ID) channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
if err != nil { if err != nil {
u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
continue continue
@ -505,7 +506,7 @@ func (u *user) run() {
u.networks = append(u.networks, network) u.networks = append(u.networks, network)
if u.hasPersistentMsgStore() { if u.hasPersistentMsgStore() {
receipts, err := u.srv.db.ListDeliveryReceipts(record.ID) receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
if err != nil { if err != nil {
u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
return return
@ -590,7 +591,7 @@ func (u *user) run() {
continue continue
} }
uc.network.detach(c) uc.network.detach(c)
if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil { if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err) u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
} }
case eventDownstreamConnected: case eventDownstreamConnected:
@ -779,7 +780,7 @@ func (u *user) createNetwork(record *Network) (*network, error) {
} }
network := newNetwork(u, record, nil) network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(u.ID, &network.Network) err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -821,7 +822,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
panic("tried updating a non-existing network") panic("tried updating a non-existing network")
} }
if err := u.srv.db.StoreNetwork(u.ID, record); err != nil { if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
return nil, err return nil, err
} }
@ -888,7 +889,7 @@ func (u *user) deleteNetwork(id int64) error {
panic("tried deleting a non-existing network") panic("tried deleting a non-existing network")
} }
if err := u.srv.db.DeleteNetwork(network.ID); err != nil { if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
return err return err
} }
@ -914,7 +915,7 @@ func (u *user) updateUser(record *User) error {
} }
realnameUpdated := u.Realname != record.Realname realnameUpdated := u.Realname != record.Realname
if err := u.srv.db.StoreUser(record); err != nil { if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err) return fmt.Errorf("failed to update user %q: %v", u.Username, err)
} }
u.User = *record u.User = *record