diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 5e784c9..1230cf1 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -24,6 +24,7 @@ import ( "git.sr.ht/~emersion/soju" "git.sr.ht/~emersion/soju/config" + "git.sr.ht/~emersion/soju/database" ) // TCP keep-alive interval for downstream TCP connections @@ -116,7 +117,7 @@ func main() { log.Printf("failed to bump max number of opened files: %v", err) } - db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) + db, err := database.Open(cfg.SQLDriver, cfg.SQLSource) if err != nil { log.Fatalf("failed to open database: %v", err) } @@ -308,7 +309,7 @@ func main() { log.Printf("server listening on %q", listen) } - if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil { + if db, ok := db.(database.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil { if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil { log.Fatalf("failed to register database metrics: %v", err) } diff --git a/cmd/sojuctl/main.go b/cmd/sojuctl/main.go index 3c4b916..24ed1ec 100644 --- a/cmd/sojuctl/main.go +++ b/cmd/sojuctl/main.go @@ -9,10 +9,11 @@ import ( "log" "os" - "git.sr.ht/~emersion/soju" - "git.sr.ht/~emersion/soju/config" "golang.org/x/crypto/bcrypt" "golang.org/x/crypto/ssh/terminal" + + "git.sr.ht/~emersion/soju/config" + "git.sr.ht/~emersion/soju/database" ) const usage = `usage: sojuctl [-config path] [options...] @@ -44,7 +45,7 @@ func main() { cfg = config.Defaults() } - db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) + db, err := database.Open(cfg.SQLDriver, cfg.SQLSource) if err != nil { log.Fatalf("failed to open database: %v", err) } @@ -73,7 +74,7 @@ func main() { log.Fatalf("failed to hash password: %v", err) } - user := soju.User{ + user := database.User{ Username: username, Password: string(hashed), Admin: *admin, diff --git a/contrib/znc-import.go b/contrib/znc-import.go index 9aa9dd9..afb6a9e 100644 --- a/contrib/znc-import.go +++ b/contrib/znc-import.go @@ -12,8 +12,8 @@ import ( "strings" "unicode" - "git.sr.ht/~emersion/soju" "git.sr.ht/~emersion/soju/config" + "git.sr.ht/~emersion/soju/database" ) const usage = `usage: znc-import [options...] @@ -64,7 +64,7 @@ func main() { ctx := context.Background() - db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) + db, err := database.Open(cfg.SQLDriver, cfg.SQLSource) if err != nil { log.Fatalf("failed to open database: %v", err) } @@ -86,7 +86,7 @@ func main() { if err != nil { log.Fatalf("failed to list users in DB: %v", err) } - existingUsers := make(map[string]*soju.User, len(l)) + existingUsers := make(map[string]*database.User, len(l)) for i, u := range l { existingUsers[u.Username] = &l[i] } @@ -107,7 +107,7 @@ func main() { log.Printf("user %q: updating existing user", username) } else { // "!!" is an invalid crypt format, thus disables password auth - u = &soju.User{Username: username, Password: "!!"} + u = &database.User{Username: username, Password: "!!"} usersCreated++ log.Printf("user %q: creating new user", username) } @@ -123,7 +123,7 @@ func main() { if err != nil { log.Fatalf("failed to list networks for user %q: %v", username, err) } - existingNetworks := make(map[string]*soju.Network, len(l)) + existingNetworks := make(map[string]*database.Network, len(l)) for i, n := range l { existingNetworks[n.GetName()] = &l[i] } @@ -175,7 +175,7 @@ func main() { if ok { logger.Printf("updating existing network") } else { - n = &soju.Network{Name: netName} + n = &database.Network{Name: netName} logger.Printf("creating new network") } @@ -194,7 +194,7 @@ func main() { if err != nil { logger.Fatalf("failed to list channels: %v", err) } - existingChannels := make(map[string]*soju.Channel, len(l)) + existingChannels := make(map[string]*database.Channel, len(l)) for i, ch := range l { existingChannels[ch.Name] = &l[i] } @@ -213,7 +213,7 @@ func main() { if ok { logger.Printf("channel %q: updating existing channel", chName) } else { - ch = &soju.Channel{Name: chName} + ch = &database.Channel{Name: chName} logger.Printf("channel %q: creating new channel", chName) } diff --git a/db.go b/database/database.go similarity index 90% rename from db.go rename to database/database.go index 1bb1576..f98a6d1 100644 --- a/db.go +++ b/database/database.go @@ -1,4 +1,4 @@ -package soju +package database import ( "context" @@ -38,7 +38,7 @@ type MetricsCollectorDatabase interface { RegisterMetrics(r prometheus.Registerer) error } -func OpenDB(driver, source string) (Database, error) { +func Open(driver, source string) (Database, error) { switch driver { case "sqlite3": return OpenSqliteDB(source) @@ -149,20 +149,6 @@ const ( FilterMessage ) -func parseFilter(filter string) (MessageFilter, error) { - switch filter { - case "default": - return FilterDefault, nil - case "none": - return FilterNone, nil - case "highlight": - return FilterHighlight, nil - case "message": - return FilterMessage, nil - } - return 0, fmt.Errorf("unknown filter: %q", filter) -} - type Channel struct { ID int64 Name string diff --git a/db_postgres.go b/database/postgres.go similarity index 95% rename from db_postgres.go rename to database/postgres.go index 8833adf..36eb91e 100644 --- a/db_postgres.go +++ b/database/postgres.go @@ -1,4 +1,4 @@ -package soju +package database import ( "context" @@ -127,6 +127,37 @@ func OpenPostgresDB(source string) (Database, error) { return db, nil } +func openTempPostgresDB(source string) (*sql.DB, error) { + db, err := sql.Open("postgres", source) + if err != nil { + return nil, fmt.Errorf("failed to connect to PostgreSQL: %v", err) + } + + // Store all tables in a temporary schema which will be dropped when the + // connection to PostgreSQL is closed. + db.SetMaxOpenConns(1) + if _, err := db.Exec("SET search_path TO pg_temp"); err != nil { + return nil, fmt.Errorf("failed to set PostgreSQL search_path: %v", err) + } + + return db, nil +} + +func OpenTempPostgresDB(source string) (Database, error) { + sqlPostgresDB, err := openTempPostgresDB(source) + if err != nil { + return nil, err + } + + db := &PostgresDB{db: sqlPostgresDB} + if err := db.upgrade(); err != nil { + sqlPostgresDB.Close() + return nil, err + } + + return db, nil +} + func (db *PostgresDB) upgrade() error { tx, err := db.db.Begin() if err != nil { diff --git a/db_postgres_test.go b/database/postgres_test.go similarity index 81% rename from db_postgres_test.go rename to database/postgres_test.go index 577cc7e..4df736b 100644 --- a/db_postgres_test.go +++ b/database/postgres_test.go @@ -1,7 +1,6 @@ -package soju +package database import ( - "database/sql" "os" "testing" ) @@ -68,29 +67,17 @@ CREATE TABLE "DeliveryReceipt" ( ); ` -func openTempPostgresDB(t *testing.T) *sql.DB { +func TestPostgresMigrations(t *testing.T) { source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") if !ok { t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") } - db, err := sql.Open("postgres", source) + sqlDB, err := openTempPostgresDB(source) if err != nil { - t.Fatalf("failed to connect to PostgreSQL: %v", err) + t.Fatalf("openTempPostgresDB() failed: %v", err) } - // Store all tables in a temporary schema which will be dropped when the - // connection to PostgreSQL is closed. - db.SetMaxOpenConns(1) - if _, err := db.Exec("SET search_path TO pg_temp"); err != nil { - t.Fatalf("failed to set PostgreSQL search_path: %v", err) - } - - return db -} - -func TestPostgresMigrations(t *testing.T) { - sqlDB := openTempPostgresDB(t) if _, err := sqlDB.Exec(postgresV0Schema); err != nil { t.Fatalf("DB.Exec() failed for v0 schema: %v", err) } diff --git a/db_sqlite.go b/database/sqlite.go similarity index 97% rename from db_sqlite.go rename to database/sqlite.go index 0ba2d60..b7a68e1 100644 --- a/db_sqlite.go +++ b/database/sqlite.go @@ -1,4 +1,4 @@ -package soju +package database import ( "context" @@ -15,6 +15,12 @@ import ( const sqliteQueryTimeout = 5 * time.Second +const sqliteTimeLayout = "2006-01-02T15:04:05.000Z" + +func formatSqliteTime(t time.Time) string { + return t.UTC().Format(sqliteTimeLayout) +} + const sqliteSchema = ` CREATE TABLE User ( id INTEGER PRIMARY KEY, @@ -212,6 +218,13 @@ func OpenSqliteDB(source string) (Database, error) { return db, nil } +func OpenTempSqliteDB() (Database, error) { + // :memory: will open a separate database for each new connection. Make + // sure the sql package only uses a single connection via SetMaxOpenConns. + // An alternative solution is to use "file::memory:?cache=shared". + return OpenSqliteDB(":memory:") +} + func (db *SqliteDB) Close() error { return db.db.Close() } @@ -732,7 +745,7 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st } return nil, err } - if t, err := time.Parse(serverTimeLayout, timestamp); err != nil { + if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil { return nil, err } else { receipt.Timestamp = t @@ -746,7 +759,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei args := []interface{}{ sql.Named("id", receipt.ID), - sql.Named("timestamp", formatServerTime(receipt.Timestamp)), + sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)), sql.Named("network", networkID), sql.Named("target", receipt.Target), } diff --git a/db_sqlite_test.go b/database/sqlite_test.go similarity index 98% rename from db_sqlite_test.go rename to database/sqlite_test.go index 9524d28..eb2c7ad 100644 --- a/db_sqlite_test.go +++ b/database/sqlite_test.go @@ -1,4 +1,4 @@ -package soju +package database import ( "database/sql" diff --git a/downstream.go b/downstream.go index b1ce893..b818d4b 100644 --- a/downstream.go +++ b/downstream.go @@ -16,6 +16,8 @@ import ( "github.com/emersion/go-sasl" "golang.org/x/crypto/bcrypt" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) type ircError struct { @@ -100,7 +102,7 @@ func parseBouncerNetID(subcommand, s string) (int64, error) { return id, nil } -func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) { +func fillNetworkAddrAttrs(attrs irc.Tags, network *database.Network) { u, err := network.URL() if err != nil { return @@ -132,13 +134,13 @@ func getNetworkAttrs(network *network) irc.Tags { attrs := irc.Tags{ "name": irc.TagValue(network.GetName()), "state": irc.TagValue(state), - "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)), + "nickname": irc.TagValue(database.GetNick(&network.user.User, &network.Network)), } if network.Username != "" { attrs["username"] = irc.TagValue(network.Username) } - if realname := GetRealname(&network.user.User, &network.Network); realname != "" { + if realname := database.GetRealname(&network.user.User, &network.Network); realname != "" { attrs["realname"] = irc.TagValue(realname) } @@ -169,7 +171,7 @@ func networkAddrFromAttrs(attrs irc.Tags) string { return addr } -func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error { +func updateNetworkAttrs(record *database.Network, attrs irc.Tags, subcommand string) error { addrAttrs := irc.Tags{} fillNetworkAddrAttrs(addrAttrs, record) @@ -414,7 +416,7 @@ func isOurNick(net *network, nick string) bool { // know whether this name is our nickname. Best-effort: use the network's // configured nickname and hope it was the one being used when we were // connected. - return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network)) + return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network)) } // marshalEntity converts an upstream entity name (ie. channel or nick) into a @@ -1146,9 +1148,9 @@ func (dc *downstreamConn) updateNick() { if uc := dc.upstream(); uc != nil { nick = uc.nick } else if dc.network != nil { - nick = GetNick(&dc.user.User, &dc.network.Network) + nick = database.GetNick(&dc.user.User, &dc.network.Network) } else { - nick = GetNick(&dc.user.User, nil) + nick = database.GetNick(&dc.user.User, nil) } if nick == dc.nick { @@ -1201,9 +1203,9 @@ func (dc *downstreamConn) updateRealname() { if uc := dc.upstream(); uc != nil { realname = uc.realname } else if dc.network != nil { - realname = GetRealname(&dc.user.User, &dc.network.Network) + realname = database.GetRealname(&dc.user.User, &dc.network.Network) } else { - realname = GetRealname(&dc.user.User, nil) + realname = database.GetRealname(&dc.user.User, nil) } if realname != dc.realname { @@ -1439,7 +1441,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error { dc.logger.Printf("auto-saving network %q", dc.registration.networkName) var err error - network, err = dc.user.createNetwork(ctx, &Network{ + network, err = dc.user.createNetwork(ctx, &database.Network{ Addr: dc.registration.networkName, Nick: nick, Enabled: true, @@ -1475,7 +1477,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { if uc := dc.upstream(); uc != nil { dc.nick = uc.nick } else if dc.network != nil { - dc.nick = GetNick(&dc.user.User, &dc.network.Network) + dc.nick = database.GetNick(&dc.user.User, &dc.network.Network) } else { dc.nick = dc.user.Username } @@ -1931,7 +1933,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc.network.attach(ctx, ch) } else { - ch = &Channel{ + ch = &database.Channel{ Name: upstreamName, Key: key, } @@ -1963,7 +1965,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if ch != nil { uc.network.detach(ch) } else { - ch = &Channel{ + ch = &database.Channel{ Name: name, Detached: true, } @@ -2911,7 +2913,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, }} } else if r == nil { - r = &ReadReceipt{ + r = &database.ReadReceipt{ Target: entityCM, } } @@ -3082,7 +3084,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } attrs := irc.ParseTags(attrsStr) - record := &Network{Nick: dc.nick, Enabled: true} + record := &database.Network{Nick: dc.nick, Enabled: true} if err := updateNetworkAttrs(record, attrs, subcommand); err != nil { return err } diff --git a/irc.go b/irc.go index 4eaa9ef..ba48de5 100644 --- a/irc.go +++ b/irc.go @@ -9,6 +9,8 @@ import ( "unicode/utf8" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) const ( @@ -653,12 +655,12 @@ func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { type channelCasemapMap struct{ casemapMap } -func (cm *channelCasemapMap) Value(name string) *Channel { +func (cm *channelCasemapMap) Value(name string) *database.Channel { entry, ok := cm.innerMap[cm.casemap(name)] if !ok { return nil } - return entry.value.(*Channel) + return entry.value.(*database.Channel) } type membershipsCasemapMap struct{ casemapMap } diff --git a/msgstore.go b/msgstore.go index dced6d9..d6c6379 100644 --- a/msgstore.go +++ b/msgstore.go @@ -9,6 +9,8 @@ import ( "git.sr.ht/~sircmpwn/go-bare" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) // messageStore is a per-user store for IRC messages. @@ -17,11 +19,11 @@ type messageStore interface { // LastMsgID queries the last message ID for the given network, entity and // date. The message ID returned may not refer to a valid message, but can be // used in history queries. - LastMsgID(network *Network, entity string, t time.Time) (string, error) + LastMsgID(network *database.Network, entity string, t time.Time) (string, error) // LoadLatestID queries the latest non-event messages for the given network, // entity and date, up to a count of limit messages, sorted from oldest to newest. - LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) - Append(network *Network, entity string, msg *irc.Message) (id string, err error) + LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) + Append(network *database.Network, entity string, msg *irc.Message) (id string, err error) } type chatHistoryTarget struct { @@ -38,17 +40,17 @@ type chatHistoryMessageStore interface { // It returns up to limit targets, starting from start and ending on end, // both excluded. end may be before or after start. // If events is false, only PRIVMSG/NOTICE messages are considered. - ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) + ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) // LoadBeforeTime loads up to limit messages before start down to end. The // returned messages must be between and excluding the provided bounds. // end is before start. // If events is false, only PRIVMSG/NOTICE messages are considered. - LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) // LoadBeforeTime loads up to limit messages after start up to end. The // returned messages must be between and excluding the provided bounds. // end is after start. // If events is false, only PRIVMSG/NOTICE messages are considered. - LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) } type searchOptions struct { @@ -66,7 +68,7 @@ type searchMessageStore interface { messageStore // Search returns messages matching the specified options. - Search(ctx context.Context, network *Network, search searchOptions) ([]*irc.Message, error) + Search(ctx context.Context, network *database.Network, search searchOptions) ([]*irc.Message, error) } type msgIDType uint diff --git a/msgstore_fs.go b/msgstore_fs.go index 14a79cf..078a947 100644 --- a/msgstore_fs.go +++ b/msgstore_fs.go @@ -13,6 +13,8 @@ import ( "git.sr.ht/~sircmpwn/go-bare" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) const ( @@ -80,7 +82,7 @@ type fsMessageStoreFile struct { // https://github.com/znc/znc/blob/master/modules/log.cpp type fsMessageStore struct { root string - user *User + user *database.User // Write-only files used by Append files map[string]*fsMessageStoreFile // indexed by entity @@ -90,7 +92,7 @@ var _ messageStore = (*fsMessageStore)(nil) var _ chatHistoryMessageStore = (*fsMessageStore)(nil) var _ searchMessageStore = (*fsMessageStore)(nil) -func newFSMessageStore(root string, user *User) *fsMessageStore { +func newFSMessageStore(root string, user *database.User) *fsMessageStore { return &fsMessageStore{ root: filepath.Join(root, escapeFilename(user.Username)), user: user, @@ -98,14 +100,14 @@ func newFSMessageStore(root string, user *User) *fsMessageStore { } } -func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string { +func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string { year, month, day := t.Date() filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) } // nextMsgID queries the message ID for the next message to be written to f. -func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) { +func nextFSMsgID(network *database.Network, entity string, t time.Time, f *os.File) (string, error) { offset, err := f.Seek(0, io.SeekEnd) if err != nil { return "", fmt.Errorf("failed to query next FS message ID: %v", err) @@ -113,7 +115,7 @@ func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (stri return formatFSMsgID(network.ID, entity, t, offset), nil } -func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { +func (ms *fsMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) { p := ms.logPath(network, entity, t) fi, err := os.Stat(p) if os.IsNotExist(err) { @@ -124,7 +126,7 @@ func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil } -func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { +func (ms *fsMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) { s := formatMessage(msg) if s == "" { return "", nil @@ -253,7 +255,7 @@ func formatMessage(msg *irc.Message) string { } } -func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { +func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { var hour, minute, second int _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) if err != nil { @@ -380,7 +382,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str // our nickname in the logs, so grab it from the network settings. // Not very accurate since this may not match our nick at the time // the message was received, but we can't do a lot better. - entity = GetNick(ms.user, network) + entity = database.GetNick(ms.user, network) } params = []string{entity, text} } @@ -399,7 +401,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str return msg, t, nil } -func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { path := ms.logPath(network, entity, ref) f, err := os.Open(path) if err != nil { @@ -458,7 +460,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, r } } -func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) { path := ms.logPath(network, entity, ref) f, err := os.Open(path) if err != nil { @@ -493,7 +495,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re return history, nil } -func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { if start.IsZero() { start = time.Now() } else { @@ -526,11 +528,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, e return messages[remaining:], nil } -func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil) } -func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { start = start.In(time.Local) if end.IsZero() { end = time.Now() @@ -562,11 +564,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, en return messages, nil } -func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil) } -func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) { var afterTime time.Time var afterOffset int64 if id != "" { @@ -614,7 +616,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, en return history[remaining:], nil } -func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { +func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { start = start.In(time.Local) end = end.In(time.Local) rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) @@ -693,7 +695,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, sta return targets, nil } -func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts searchOptions) ([]*irc.Message, error) { +func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts searchOptions) ([]*irc.Message, error) { text := strings.ToLower(opts.text) selector := func(m *irc.Message) bool { if opts.from != "" && m.User != opts.from { @@ -711,7 +713,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts sea } } -func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error { +func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error { oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) // Avoid loosing data by overwriting an existing directory diff --git a/msgstore_memory.go b/msgstore_memory.go index 4bac476..02158c0 100644 --- a/msgstore_memory.go +++ b/msgstore_memory.go @@ -7,6 +7,8 @@ import ( "git.sr.ht/~sircmpwn/go-bare" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) const messageRingBufferCap = 4096 @@ -55,7 +57,7 @@ func (ms *memoryMessageStore) Close() error { return nil } -func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer { +func (ms *memoryMessageStore) get(network *database.Network, entity string) *messageRingBuffer { k := ringBufferKey{networkID: network.ID, entity: entity} if rb, ok := ms.buffers[k]; ok { return rb @@ -65,7 +67,7 @@ func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingB return rb } -func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { +func (ms *memoryMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) { var seq uint64 k := ringBufferKey{networkID: network.ID, entity: entity} if rb, ok := ms.buffers[k]; ok { @@ -74,7 +76,7 @@ func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time. return formatMemoryMsgID(network.ID, entity, seq), nil } -func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { +func (ms *memoryMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) { switch msg.Command { case "PRIVMSG", "NOTICE": // Only append these messages, because LoadLatestID shouldn't return @@ -94,7 +96,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M return formatMemoryMsgID(network.ID, entity, seq), nil } -func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) { _, _, seq, err := parseMemoryMsgID(id) if err != nil { return nil, err diff --git a/server.go b/server.go index ace4aa1..b4ecb0b 100644 --- a/server.go +++ b/server.go @@ -20,6 +20,7 @@ import ( "nhooyr.io/websocket" "git.sr.ht/~emersion/soju/config" + "git.sr.ht/~emersion/soju/database" ) // TODO: make configurable @@ -141,7 +142,7 @@ type Server struct { MetricsRegistry prometheus.Registerer // can be nil config atomic.Value // *Config - db Database + db database.Database stopWG sync.WaitGroup lock sync.Mutex @@ -161,7 +162,7 @@ type Server struct { } } -func NewServer(db Database) *Server { +func NewServer(db database.Database) *Server { srv := &Server{ Logger: NewLogger(log.Writer(), true), db: db, @@ -273,7 +274,7 @@ func (s *Server) Shutdown() { } } -func (s *Server) createUser(ctx context.Context, user *User) (*user, error) { +func (s *Server) createUser(ctx context.Context, user *database.User) (*user, error) { s.lock.Lock() defer s.lock.Unlock() @@ -304,7 +305,7 @@ func (s *Server) getUser(name string) *user { return u } -func (s *Server) addUserLocked(user *User) *user { +func (s *Server) addUserLocked(user *database.User) *user { s.Logger.Printf("starting bouncer for user %q", user.Username) u := newUser(s, user) s.users[u.Username] = u diff --git a/server_test.go b/server_test.go index b163e0e..3dbb934 100644 --- a/server_test.go +++ b/server_test.go @@ -3,10 +3,13 @@ package soju import ( "context" "net" + "os" "testing" "golang.org/x/crypto/bcrypt" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) var testServerPrefix = &irc.Prefix{Name: "soju-test-server"} @@ -16,34 +19,35 @@ const ( testPassword = testUsername ) -func createTempSqliteDB(t *testing.T) Database { - db, err := OpenDB("sqlite3", ":memory:") +func createTempSqliteDB(t *testing.T) database.Database { + db, err := database.OpenTempSqliteDB() if err != nil { t.Fatalf("failed to create temporary SQLite database: %v", err) } - // :memory: will open a separate database for each new connection. Make - // sure the sql package only uses a single connection. An alternative - // solution is to use "file::memory:?cache=shared". - db.(*SqliteDB).db.SetMaxOpenConns(1) return db } -func createTempPostgresDB(t *testing.T) Database { - db := &PostgresDB{db: openTempPostgresDB(t)} - if err := db.upgrade(); err != nil { - t.Fatalf("failed to upgrade PostgreSQL database: %v", err) +func createTempPostgresDB(t *testing.T) database.Database { + source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") + if !ok { + t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") + } + + db, err := database.OpenTempPostgresDB(source) + if err != nil { + t.Fatalf("failed to create temporary PostgreSQL database: %v", err) } return db } -func createTestUser(t *testing.T, db Database) *User { +func createTestUser(t *testing.T, db database.Database) *database.User { hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost) if err != nil { t.Fatalf("failed to generate bcrypt hash: %v", err) } - record := &User{Username: testUsername, Password: string(hashed)} + record := &database.User{Username: testUsername, Password: string(hashed)} if err := db.StoreUser(context.Background(), record); err != nil { t.Fatalf("failed to store test user: %v", err) } @@ -57,13 +61,13 @@ func createTestDownstream(t *testing.T, srv *Server) ircConn { return newNetIRCConn(c2) } -func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) { +func createTestUpstream(t *testing.T, db database.Database, user *database.User) (*database.Network, net.Listener) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to create TCP listener: %v", err) } - network := &Network{ + network := &database.Network{ Name: "testnet", Addr: "irc+insecure://" + ln.Addr().String(), Nick: user.Username, @@ -95,7 +99,7 @@ func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message { return msg } -func registerDownstreamConn(t *testing.T, c ircConn, network *Network) { +func registerDownstreamConn(t *testing.T, c ircConn, network *database.Network) { c.WriteMessage(&irc.Message{ Command: "PASS", Params: []string{testPassword}, @@ -151,7 +155,7 @@ func registerUpstreamConn(t *testing.T, c ircConn) { }) } -func testServer(t *testing.T, db Database) { +func testServer(t *testing.T, db database.Database) { user := createTestUser(t, db) network, upstream := createTestUpstream(t, db, user) defer upstream.Close() diff --git a/service.go b/service.go index 2a94ee9..9eaf3be 100644 --- a/service.go +++ b/service.go @@ -17,6 +17,8 @@ import ( "golang.org/x/crypto/bcrypt" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) const serviceNick = "BouncerServ" @@ -447,7 +449,7 @@ func newNetworkFlagSet() *networkFlagSet { return fs } -func (fs *networkFlagSet) update(network *Network) error { +func (fs *networkFlagSet) update(network *database.Network) error { if fs.Addr != nil { if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { scheme := addrParts[0] @@ -508,7 +510,7 @@ func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params return fmt.Errorf("flag -addr is required") } - record := &Network{ + record := &database.Network{ Addr: *fs.Addr, Enabled: true, } @@ -833,7 +835,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) return fmt.Errorf("failed to hash password: %v", err) } - user := &User{ + user := &database.User{ Username: *username, Password: string(hashed), Realname: *realname, @@ -971,9 +973,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params n := 0 sendNetwork := func(net *network) { - var channels []*Channel + var channels []*database.Channel for _, entry := range net.channels.innerMap { - channels = append(channels, entry.value.(*Channel)) + channels = append(channels, entry.value.(*database.Channel)) } sort.Slice(channels, func(i, j int) bool { @@ -1031,6 +1033,20 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params return nil } +func parseFilter(filter string) (database.MessageFilter, error) { + switch filter { + case "default": + return database.FilterDefault, nil + case "none": + return database.FilterNone, nil + case "highlight": + return database.FilterHighlight, nil + case "message": + return database.FilterMessage, nil + } + return 0, fmt.Errorf("unknown filter: %q", filter) +} + type channelFlagSet struct { *flag.FlagSet RelayDetached, ReattachOn, DetachAfter, DetachOn *string @@ -1045,7 +1061,7 @@ func newChannelFlagSet() *channelFlagSet { return fs } -func (fs *channelFlagSet) update(channel *Channel) error { +func (fs *channelFlagSet) update(channel *database.Channel) error { if fs.RelayDetached != nil { filter, err := parseFilter(*fs.RelayDetached) if err != nil { diff --git a/upstream.go b/upstream.go index 69f563c..cc8a7e3 100644 --- a/upstream.go +++ b/upstream.go @@ -17,6 +17,8 @@ import ( "github.com/emersion/go-sasl" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) // permanentUpstreamCaps is the static list of upstream capabilities always @@ -510,7 +512,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } highlight := uc.network.isHighlight(msg) - if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) { + if ch.DetachOn == database.FilterMessage || ch.DetachOn == database.FilterDefault || (ch.DetachOn == database.FilterHighlight && highlight) { uc.updateChannelAutoDetach(target) } } @@ -765,7 +767,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.network.channels.Len() > 0 { var channels, keys []string for _, entry := range uc.network.channels.innerMap { - ch := entry.value.(*Channel) + ch := entry.value.(*database.Channel) channels = append(channels, ch.Name) keys = append(keys, ch.Key) } @@ -1553,7 +1555,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } // Check if the nick we want is now free - wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNick := database.GetNick(&uc.user.User, &uc.network.Network) wantNickCM := uc.network.casemap(wantNick) if !online && uc.nickCM != wantNickCM { found := false @@ -1796,13 +1798,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } -func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) { +func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *database.Channel, msg *irc.Message) { if uc.network.detachedMessageNeedsRelay(ch, msg) { uc.forEachDownstream(func(dc *downstreamConn) { dc.relayDetachedMessage(uc.network, msg) }) } - if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { + if ch.ReattachOn == database.FilterMessage || (ch.ReattachOn == database.FilterHighlight && uc.network.isHighlight(msg)) { uc.network.attach(ctx, ch) if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) @@ -1960,10 +1962,10 @@ func splitSpace(s string) []string { } func (uc *upstreamConn) register(ctx context.Context) { - uc.nick = GetNick(&uc.user.User, &uc.network.Network) + uc.nick = database.GetNick(&uc.user.User, &uc.network.Network) uc.nickCM = uc.network.casemap(uc.nick) - uc.username = GetUsername(&uc.user.User, &uc.network.Network) - uc.realname = GetRealname(&uc.user.User, &uc.network.Network) + uc.username = database.GetUsername(&uc.user.User, &uc.network.Network) + uc.realname = database.GetRealname(&uc.user.User, &uc.network.Network) uc.SendMessage(ctx, &irc.Message{ Command: "CAP", @@ -2193,7 +2195,7 @@ func (uc *upstreamConn) updateMonitor() { } }) - wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNick := database.GetNick(&uc.user.User, &uc.network.Network) wantNickCM := uc.network.casemap(wantNick) if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) { addList = append(addList, wantNickCM) diff --git a/user.go b/user.go index 9b7879f..e8847e8 100644 --- a/user.go +++ b/user.go @@ -14,6 +14,8 @@ import ( "time" "gopkg.in/irc.v3" + + "git.sr.ht/~emersion/soju/database" ) type event interface{} @@ -123,7 +125,7 @@ func (ds deliveredStore) ForEachClient(f func(clientName string)) { } type network struct { - Network + database.Network user *user logger Logger stopped chan struct{} @@ -135,7 +137,7 @@ type network struct { casemap casemapping } -func newNetwork(user *user, record *Network, channels []Channel) *network { +func newNetwork(user *user, record *database.Network, channels []database.Channel) *network { logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} m := channelCasemapMap{newCasemapMap(0)} @@ -176,7 +178,7 @@ func (net *network) isStopped() bool { } } -func userIdent(u *User) string { +func userIdent(u *database.User) string { // The ident is a string we will send to upstream servers in clear-text. // For privacy reasons, make sure it doesn't expose any meaningful user // metadata. We just use the base64-encoded hashed ID, so that people don't @@ -278,7 +280,7 @@ func (net *network) stop() { } } -func (net *network) detach(ch *Channel) { +func (net *network) detach(ch *database.Channel) { if ch.Detached { return } @@ -312,7 +314,7 @@ func (net *network) detach(ch *Channel) { }) } -func (net *network) attach(ctx context.Context, ch *Channel) { +func (net *network) attach(ctx context.Context, ch *database.Channel) { if !ch.Detached { return } @@ -388,13 +390,13 @@ func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName return } - var receipts []DeliveryReceipt + var receipts []database.DeliveryReceipt net.delivered.ForEachTarget(func(target string) { msgID := net.delivered.LoadID(target, clientName) if msgID == "" { return } - receipts = append(receipts, DeliveryReceipt{ + receipts = append(receipts, database.DeliveryReceipt{ Target: target, InternalMsgID: msgID, }) @@ -421,9 +423,9 @@ func (net *network) isHighlight(msg *irc.Message) bool { return msg.Prefix.Name != nick && isHighlight(text, nick) } -func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool { +func (net *network) detachedMessageNeedsRelay(ch *database.Channel, msg *irc.Message) bool { highlight := net.isHighlight(msg) - return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) + return ch.RelayDetached == database.FilterMessage || ((ch.RelayDetached == database.FilterHighlight || ch.RelayDetached == database.FilterDefault) && highlight) } func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { @@ -443,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st } type user struct { - User + database.User srv *Server logger Logger @@ -455,7 +457,7 @@ type user struct { msgStore messageStore } -func newUser(srv *Server, record *User) *user { +func newUser(srv *Server, record *database.User) *user { logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} var msgStore messageStore @@ -817,7 +819,7 @@ func (u *user) removeNetwork(network *network) { panic("tried to remove a non-existing network") } -func (u *user) checkNetwork(record *Network) error { +func (u *user) checkNetwork(record *database.Network) error { url, err := record.URL() if err != nil { return err @@ -867,7 +869,7 @@ func (u *user) checkNetwork(record *Network) error { return nil } -func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { +func (u *user) createNetwork(ctx context.Context, record *database.Network) (*network, error) { if record.ID != 0 { panic("tried creating an already-existing network") } @@ -894,7 +896,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er return network, nil } -func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { +func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*network, error) { if record.ID == 0 { panic("tried updating a new network") } @@ -920,9 +922,9 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er // Most network changes require us to re-connect to the upstream server - channels := make([]Channel, 0, network.channels.Len()) + channels := make([]database.Channel, 0, network.channels.Len()) for _, entry := range network.channels.innerMap { - ch := entry.value.(*Channel) + ch := entry.value.(*database.Channel) channels = append(channels, *ch) } @@ -992,7 +994,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error { return nil } -func (u *user) updateUser(ctx context.Context, record *User) error { +func (u *user) updateUser(ctx context.Context, record *database.User) error { if u.ID != record.ID { panic("ID mismatch when updating user") } @@ -1005,7 +1007,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error { if realnameUpdated { // Re-connect to networks which use the default realname - var needUpdate []Network + var needUpdate []database.Network for _, net := range u.networks { if net.Realname != "" { continue @@ -1016,7 +1018,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error { if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") { uc.SendMessage(ctx, &irc.Message{ Command: "SETNAME", - Params: []string{GetRealname(&u.User, &net.Network)}, + Params: []string{database.GetRealname(&u.User, &net.Network)}, }) continue }