Introduce a database package

This commit is contained in:
Simon Ser 2022-05-09 12:34:43 +02:00
parent 27f21eab94
commit 3a7dee8128
18 changed files with 206 additions and 152 deletions

View file

@ -24,6 +24,7 @@ import (
"git.sr.ht/~emersion/soju"
"git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
)
// TCP keep-alive interval for downstream TCP connections
@ -116,7 +117,7 @@ func main() {
log.Printf("failed to bump max number of opened files: %v", err)
}
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
@ -308,7 +309,7 @@ func main() {
log.Printf("server listening on %q", listen)
}
if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
if db, ok := db.(database.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil {
log.Fatalf("failed to register database metrics: %v", err)
}

View file

@ -9,10 +9,11 @@ import (
"log"
"os"
"git.sr.ht/~emersion/soju"
"git.sr.ht/~emersion/soju/config"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/ssh/terminal"
"git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
)
const usage = `usage: sojuctl [-config path] <action> [options...]
@ -44,7 +45,7 @@ func main() {
cfg = config.Defaults()
}
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
@ -73,7 +74,7 @@ func main() {
log.Fatalf("failed to hash password: %v", err)
}
user := soju.User{
user := database.User{
Username: username,
Password: string(hashed),
Admin: *admin,

View file

@ -12,8 +12,8 @@ import (
"strings"
"unicode"
"git.sr.ht/~emersion/soju"
"git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
)
const usage = `usage: znc-import [options...] <znc config path>
@ -64,7 +64,7 @@ func main() {
ctx := context.Background()
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
@ -86,7 +86,7 @@ func main() {
if err != nil {
log.Fatalf("failed to list users in DB: %v", err)
}
existingUsers := make(map[string]*soju.User, len(l))
existingUsers := make(map[string]*database.User, len(l))
for i, u := range l {
existingUsers[u.Username] = &l[i]
}
@ -107,7 +107,7 @@ func main() {
log.Printf("user %q: updating existing user", username)
} else {
// "!!" is an invalid crypt format, thus disables password auth
u = &soju.User{Username: username, Password: "!!"}
u = &database.User{Username: username, Password: "!!"}
usersCreated++
log.Printf("user %q: creating new user", username)
}
@ -123,7 +123,7 @@ func main() {
if err != nil {
log.Fatalf("failed to list networks for user %q: %v", username, err)
}
existingNetworks := make(map[string]*soju.Network, len(l))
existingNetworks := make(map[string]*database.Network, len(l))
for i, n := range l {
existingNetworks[n.GetName()] = &l[i]
}
@ -175,7 +175,7 @@ func main() {
if ok {
logger.Printf("updating existing network")
} else {
n = &soju.Network{Name: netName}
n = &database.Network{Name: netName}
logger.Printf("creating new network")
}
@ -194,7 +194,7 @@ func main() {
if err != nil {
logger.Fatalf("failed to list channels: %v", err)
}
existingChannels := make(map[string]*soju.Channel, len(l))
existingChannels := make(map[string]*database.Channel, len(l))
for i, ch := range l {
existingChannels[ch.Name] = &l[i]
}
@ -213,7 +213,7 @@ func main() {
if ok {
logger.Printf("channel %q: updating existing channel", chName)
} else {
ch = &soju.Channel{Name: chName}
ch = &database.Channel{Name: chName}
logger.Printf("channel %q: creating new channel", chName)
}

View file

@ -1,4 +1,4 @@
package soju
package database
import (
"context"
@ -38,7 +38,7 @@ type MetricsCollectorDatabase interface {
RegisterMetrics(r prometheus.Registerer) error
}
func OpenDB(driver, source string) (Database, error) {
func Open(driver, source string) (Database, error) {
switch driver {
case "sqlite3":
return OpenSqliteDB(source)
@ -149,20 +149,6 @@ const (
FilterMessage
)
func parseFilter(filter string) (MessageFilter, error) {
switch filter {
case "default":
return FilterDefault, nil
case "none":
return FilterNone, nil
case "highlight":
return FilterHighlight, nil
case "message":
return FilterMessage, nil
}
return 0, fmt.Errorf("unknown filter: %q", filter)
}
type Channel struct {
ID int64
Name string

View file

@ -1,4 +1,4 @@
package soju
package database
import (
"context"
@ -127,6 +127,37 @@ func OpenPostgresDB(source string) (Database, error) {
return db, nil
}
func openTempPostgresDB(source string) (*sql.DB, error) {
db, err := sql.Open("postgres", source)
if err != nil {
return nil, fmt.Errorf("failed to connect to PostgreSQL: %v", err)
}
// Store all tables in a temporary schema which will be dropped when the
// connection to PostgreSQL is closed.
db.SetMaxOpenConns(1)
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
return nil, fmt.Errorf("failed to set PostgreSQL search_path: %v", err)
}
return db, nil
}
func OpenTempPostgresDB(source string) (Database, error) {
sqlPostgresDB, err := openTempPostgresDB(source)
if err != nil {
return nil, err
}
db := &PostgresDB{db: sqlPostgresDB}
if err := db.upgrade(); err != nil {
sqlPostgresDB.Close()
return nil, err
}
return db, nil
}
func (db *PostgresDB) upgrade() error {
tx, err := db.db.Begin()
if err != nil {

View file

@ -1,7 +1,6 @@
package soju
package database
import (
"database/sql"
"os"
"testing"
)
@ -68,29 +67,17 @@ CREATE TABLE "DeliveryReceipt" (
);
`
func openTempPostgresDB(t *testing.T) *sql.DB {
func TestPostgresMigrations(t *testing.T) {
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
if !ok {
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
}
db, err := sql.Open("postgres", source)
sqlDB, err := openTempPostgresDB(source)
if err != nil {
t.Fatalf("failed to connect to PostgreSQL: %v", err)
t.Fatalf("openTempPostgresDB() failed: %v", err)
}
// Store all tables in a temporary schema which will be dropped when the
// connection to PostgreSQL is closed.
db.SetMaxOpenConns(1)
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
t.Fatalf("failed to set PostgreSQL search_path: %v", err)
}
return db
}
func TestPostgresMigrations(t *testing.T) {
sqlDB := openTempPostgresDB(t)
if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
}

View file

@ -1,4 +1,4 @@
package soju
package database
import (
"context"
@ -15,6 +15,12 @@ import (
const sqliteQueryTimeout = 5 * time.Second
const sqliteTimeLayout = "2006-01-02T15:04:05.000Z"
func formatSqliteTime(t time.Time) string {
return t.UTC().Format(sqliteTimeLayout)
}
const sqliteSchema = `
CREATE TABLE User (
id INTEGER PRIMARY KEY,
@ -212,6 +218,13 @@ func OpenSqliteDB(source string) (Database, error) {
return db, nil
}
func OpenTempSqliteDB() (Database, error) {
// :memory: will open a separate database for each new connection. Make
// sure the sql package only uses a single connection via SetMaxOpenConns.
// An alternative solution is to use "file::memory:?cache=shared".
return OpenSqliteDB(":memory:")
}
func (db *SqliteDB) Close() error {
return db.db.Close()
}
@ -732,7 +745,7 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
}
return nil, err
}
if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil {
return nil, err
} else {
receipt.Timestamp = t
@ -746,7 +759,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei
args := []interface{}{
sql.Named("id", receipt.ID),
sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)),
sql.Named("network", networkID),
sql.Named("target", receipt.Target),
}

View file

@ -1,4 +1,4 @@
package soju
package database
import (
"database/sql"

View file

@ -16,6 +16,8 @@ import (
"github.com/emersion/go-sasl"
"golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
type ircError struct {
@ -100,7 +102,7 @@ func parseBouncerNetID(subcommand, s string) (int64, error) {
return id, nil
}
func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
func fillNetworkAddrAttrs(attrs irc.Tags, network *database.Network) {
u, err := network.URL()
if err != nil {
return
@ -132,13 +134,13 @@ func getNetworkAttrs(network *network) irc.Tags {
attrs := irc.Tags{
"name": irc.TagValue(network.GetName()),
"state": irc.TagValue(state),
"nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
"nickname": irc.TagValue(database.GetNick(&network.user.User, &network.Network)),
}
if network.Username != "" {
attrs["username"] = irc.TagValue(network.Username)
}
if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
if realname := database.GetRealname(&network.user.User, &network.Network); realname != "" {
attrs["realname"] = irc.TagValue(realname)
}
@ -169,7 +171,7 @@ func networkAddrFromAttrs(attrs irc.Tags) string {
return addr
}
func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
func updateNetworkAttrs(record *database.Network, attrs irc.Tags, subcommand string) error {
addrAttrs := irc.Tags{}
fillNetworkAddrAttrs(addrAttrs, record)
@ -414,7 +416,7 @@ func isOurNick(net *network, nick string) bool {
// know whether this name is our nickname. Best-effort: use the network's
// configured nickname and hope it was the one being used when we were
// connected.
return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network))
}
// marshalEntity converts an upstream entity name (ie. channel or nick) into a
@ -1146,9 +1148,9 @@ func (dc *downstreamConn) updateNick() {
if uc := dc.upstream(); uc != nil {
nick = uc.nick
} else if dc.network != nil {
nick = GetNick(&dc.user.User, &dc.network.Network)
nick = database.GetNick(&dc.user.User, &dc.network.Network)
} else {
nick = GetNick(&dc.user.User, nil)
nick = database.GetNick(&dc.user.User, nil)
}
if nick == dc.nick {
@ -1201,9 +1203,9 @@ func (dc *downstreamConn) updateRealname() {
if uc := dc.upstream(); uc != nil {
realname = uc.realname
} else if dc.network != nil {
realname = GetRealname(&dc.user.User, &dc.network.Network)
realname = database.GetRealname(&dc.user.User, &dc.network.Network)
} else {
realname = GetRealname(&dc.user.User, nil)
realname = database.GetRealname(&dc.user.User, nil)
}
if realname != dc.realname {
@ -1439,7 +1441,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
var err error
network, err = dc.user.createNetwork(ctx, &Network{
network, err = dc.user.createNetwork(ctx, &database.Network{
Addr: dc.registration.networkName,
Nick: nick,
Enabled: true,
@ -1475,7 +1477,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
if uc := dc.upstream(); uc != nil {
dc.nick = uc.nick
} else if dc.network != nil {
dc.nick = GetNick(&dc.user.User, &dc.network.Network)
dc.nick = database.GetNick(&dc.user.User, &dc.network.Network)
} else {
dc.nick = dc.user.Username
}
@ -1931,7 +1933,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
uc.network.attach(ctx, ch)
} else {
ch = &Channel{
ch = &database.Channel{
Name: upstreamName,
Key: key,
}
@ -1963,7 +1965,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if ch != nil {
uc.network.detach(ch)
} else {
ch = &Channel{
ch = &database.Channel{
Name: name,
Detached: true,
}
@ -2911,7 +2913,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
}}
} else if r == nil {
r = &ReadReceipt{
r = &database.ReadReceipt{
Target: entityCM,
}
}
@ -3082,7 +3084,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
attrs := irc.ParseTags(attrsStr)
record := &Network{Nick: dc.nick, Enabled: true}
record := &database.Network{Nick: dc.nick, Enabled: true}
if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
return err
}

6
irc.go
View file

@ -9,6 +9,8 @@ import (
"unicode/utf8"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
const (
@ -653,12 +655,12 @@ func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Value(name string) *Channel {
func (cm *channelCasemapMap) Value(name string) *database.Channel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*Channel)
return entry.value.(*database.Channel)
}
type membershipsCasemapMap struct{ casemapMap }

View file

@ -9,6 +9,8 @@ import (
"git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
// messageStore is a per-user store for IRC messages.
@ -17,11 +19,11 @@ type messageStore interface {
// LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be
// used in history queries.
LastMsgID(network *Network, entity string, t time.Time) (string, error)
LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
Append(network *Network, entity string, msg *irc.Message) (id string, err error)
LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error)
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
}
type chatHistoryTarget struct {
@ -38,17 +40,17 @@ type chatHistoryMessageStore interface {
// It returns up to limit targets, starting from start and ending on end,
// both excluded. end may be before or after start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
// LoadBeforeTime loads up to limit messages before start down to end. The
// returned messages must be between and excluding the provided bounds.
// end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds.
// end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
}
type searchOptions struct {
@ -66,7 +68,7 @@ type searchMessageStore interface {
messageStore
// Search returns messages matching the specified options.
Search(ctx context.Context, network *Network, search searchOptions) ([]*irc.Message, error)
Search(ctx context.Context, network *database.Network, search searchOptions) ([]*irc.Message, error)
}
type msgIDType uint

View file

@ -13,6 +13,8 @@ import (
"git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
const (
@ -80,7 +82,7 @@ type fsMessageStoreFile struct {
// https://github.com/znc/znc/blob/master/modules/log.cpp
type fsMessageStore struct {
root string
user *User
user *database.User
// Write-only files used by Append
files map[string]*fsMessageStoreFile // indexed by entity
@ -90,7 +92,7 @@ var _ messageStore = (*fsMessageStore)(nil)
var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
var _ searchMessageStore = (*fsMessageStore)(nil)
func newFSMessageStore(root string, user *User) *fsMessageStore {
func newFSMessageStore(root string, user *database.User) *fsMessageStore {
return &fsMessageStore{
root: filepath.Join(root, escapeFilename(user.Username)),
user: user,
@ -98,14 +100,14 @@ func newFSMessageStore(root string, user *User) *fsMessageStore {
}
}
func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string {
year, month, day := t.Date()
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
}
// nextMsgID queries the message ID for the next message to be written to f.
func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
func nextFSMsgID(network *database.Network, entity string, t time.Time, f *os.File) (string, error) {
offset, err := f.Seek(0, io.SeekEnd)
if err != nil {
return "", fmt.Errorf("failed to query next FS message ID: %v", err)
@ -113,7 +115,7 @@ func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (stri
return formatFSMsgID(network.ID, entity, t, offset), nil
}
func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
func (ms *fsMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
p := ms.logPath(network, entity, t)
fi, err := os.Stat(p)
if os.IsNotExist(err) {
@ -124,7 +126,7 @@ func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time
return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
}
func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
func (ms *fsMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
s := formatMessage(msg)
if s == "" {
return "", nil
@ -253,7 +255,7 @@ func formatMessage(msg *irc.Message) string {
}
}
func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
var hour, minute, second int
_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
if err != nil {
@ -380,7 +382,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
// our nickname in the logs, so grab it from the network settings.
// Not very accurate since this may not match our nick at the time
// the message was received, but we can't do a lot better.
entity = GetNick(ms.user, network)
entity = database.GetNick(ms.user, network)
}
params = []string{entity, text}
}
@ -399,7 +401,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
return msg, t, nil
}
func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref)
f, err := os.Open(path)
if err != nil {
@ -458,7 +460,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, r
}
}
func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref)
f, err := os.Open(path)
if err != nil {
@ -493,7 +495,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re
return history, nil
}
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
if start.IsZero() {
start = time.Now()
} else {
@ -526,11 +528,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, e
return messages[remaining:], nil
}
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil)
}
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
start = start.In(time.Local)
if end.IsZero() {
end = time.Now()
@ -562,11 +564,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, en
return messages, nil
}
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil)
}
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
var afterTime time.Time
var afterOffset int64
if id != "" {
@ -614,7 +616,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, en
return history[remaining:], nil
}
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
start = start.In(time.Local)
end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
@ -693,7 +695,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, sta
return targets, nil
}
func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts searchOptions) ([]*irc.Message, error) {
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts searchOptions) ([]*irc.Message, error) {
text := strings.ToLower(opts.text)
selector := func(m *irc.Message) bool {
if opts.from != "" && m.User != opts.from {
@ -711,7 +713,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts sea
}
}
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error {
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
// Avoid loosing data by overwriting an existing directory

View file

@ -7,6 +7,8 @@ import (
"git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
const messageRingBufferCap = 4096
@ -55,7 +57,7 @@ func (ms *memoryMessageStore) Close() error {
return nil
}
func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
func (ms *memoryMessageStore) get(network *database.Network, entity string) *messageRingBuffer {
k := ringBufferKey{networkID: network.ID, entity: entity}
if rb, ok := ms.buffers[k]; ok {
return rb
@ -65,7 +67,7 @@ func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingB
return rb
}
func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
func (ms *memoryMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
var seq uint64
k := ringBufferKey{networkID: network.ID, entity: entity}
if rb, ok := ms.buffers[k]; ok {
@ -74,7 +76,7 @@ func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.
return formatMemoryMsgID(network.ID, entity, seq), nil
}
func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
func (ms *memoryMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
switch msg.Command {
case "PRIVMSG", "NOTICE":
// Only append these messages, because LoadLatestID shouldn't return
@ -94,7 +96,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M
return formatMemoryMsgID(network.ID, entity, seq), nil
}
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
_, _, seq, err := parseMemoryMsgID(id)
if err != nil {
return nil, err

View file

@ -20,6 +20,7 @@ import (
"nhooyr.io/websocket"
"git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
)
// TODO: make configurable
@ -141,7 +142,7 @@ type Server struct {
MetricsRegistry prometheus.Registerer // can be nil
config atomic.Value // *Config
db Database
db database.Database
stopWG sync.WaitGroup
lock sync.Mutex
@ -161,7 +162,7 @@ type Server struct {
}
}
func NewServer(db Database) *Server {
func NewServer(db database.Database) *Server {
srv := &Server{
Logger: NewLogger(log.Writer(), true),
db: db,
@ -273,7 +274,7 @@ func (s *Server) Shutdown() {
}
}
func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
func (s *Server) createUser(ctx context.Context, user *database.User) (*user, error) {
s.lock.Lock()
defer s.lock.Unlock()
@ -304,7 +305,7 @@ func (s *Server) getUser(name string) *user {
return u
}
func (s *Server) addUserLocked(user *User) *user {
func (s *Server) addUserLocked(user *database.User) *user {
s.Logger.Printf("starting bouncer for user %q", user.Username)
u := newUser(s, user)
s.users[u.Username] = u

View file

@ -3,10 +3,13 @@ package soju
import (
"context"
"net"
"os"
"testing"
"golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
var testServerPrefix = &irc.Prefix{Name: "soju-test-server"}
@ -16,34 +19,35 @@ const (
testPassword = testUsername
)
func createTempSqliteDB(t *testing.T) Database {
db, err := OpenDB("sqlite3", ":memory:")
func createTempSqliteDB(t *testing.T) database.Database {
db, err := database.OpenTempSqliteDB()
if err != nil {
t.Fatalf("failed to create temporary SQLite database: %v", err)
}
// :memory: will open a separate database for each new connection. Make
// sure the sql package only uses a single connection. An alternative
// solution is to use "file::memory:?cache=shared".
db.(*SqliteDB).db.SetMaxOpenConns(1)
return db
}
func createTempPostgresDB(t *testing.T) Database {
db := &PostgresDB{db: openTempPostgresDB(t)}
if err := db.upgrade(); err != nil {
t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
func createTempPostgresDB(t *testing.T) database.Database {
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
if !ok {
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
}
db, err := database.OpenTempPostgresDB(source)
if err != nil {
t.Fatalf("failed to create temporary PostgreSQL database: %v", err)
}
return db
}
func createTestUser(t *testing.T, db Database) *User {
func createTestUser(t *testing.T, db database.Database) *database.User {
hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate bcrypt hash: %v", err)
}
record := &User{Username: testUsername, Password: string(hashed)}
record := &database.User{Username: testUsername, Password: string(hashed)}
if err := db.StoreUser(context.Background(), record); err != nil {
t.Fatalf("failed to store test user: %v", err)
}
@ -57,13 +61,13 @@ func createTestDownstream(t *testing.T, srv *Server) ircConn {
return newNetIRCConn(c2)
}
func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
func createTestUpstream(t *testing.T, db database.Database, user *database.User) (*database.Network, net.Listener) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to create TCP listener: %v", err)
}
network := &Network{
network := &database.Network{
Name: "testnet",
Addr: "irc+insecure://" + ln.Addr().String(),
Nick: user.Username,
@ -95,7 +99,7 @@ func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
return msg
}
func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
func registerDownstreamConn(t *testing.T, c ircConn, network *database.Network) {
c.WriteMessage(&irc.Message{
Command: "PASS",
Params: []string{testPassword},
@ -151,7 +155,7 @@ func registerUpstreamConn(t *testing.T, c ircConn) {
})
}
func testServer(t *testing.T, db Database) {
func testServer(t *testing.T, db database.Database) {
user := createTestUser(t, db)
network, upstream := createTestUpstream(t, db, user)
defer upstream.Close()

View file

@ -17,6 +17,8 @@ import (
"golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
const serviceNick = "BouncerServ"
@ -447,7 +449,7 @@ func newNetworkFlagSet() *networkFlagSet {
return fs
}
func (fs *networkFlagSet) update(network *Network) error {
func (fs *networkFlagSet) update(network *database.Network) error {
if fs.Addr != nil {
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
scheme := addrParts[0]
@ -508,7 +510,7 @@ func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params
return fmt.Errorf("flag -addr is required")
}
record := &Network{
record := &database.Network{
Addr: *fs.Addr,
Enabled: true,
}
@ -833,7 +835,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string)
return fmt.Errorf("failed to hash password: %v", err)
}
user := &User{
user := &database.User{
Username: *username,
Password: string(hashed),
Realname: *realname,
@ -971,9 +973,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
n := 0
sendNetwork := func(net *network) {
var channels []*Channel
var channels []*database.Channel
for _, entry := range net.channels.innerMap {
channels = append(channels, entry.value.(*Channel))
channels = append(channels, entry.value.(*database.Channel))
}
sort.Slice(channels, func(i, j int) bool {
@ -1031,6 +1033,20 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
return nil
}
func parseFilter(filter string) (database.MessageFilter, error) {
switch filter {
case "default":
return database.FilterDefault, nil
case "none":
return database.FilterNone, nil
case "highlight":
return database.FilterHighlight, nil
case "message":
return database.FilterMessage, nil
}
return 0, fmt.Errorf("unknown filter: %q", filter)
}
type channelFlagSet struct {
*flag.FlagSet
RelayDetached, ReattachOn, DetachAfter, DetachOn *string
@ -1045,7 +1061,7 @@ func newChannelFlagSet() *channelFlagSet {
return fs
}
func (fs *channelFlagSet) update(channel *Channel) error {
func (fs *channelFlagSet) update(channel *database.Channel) error {
if fs.RelayDetached != nil {
filter, err := parseFilter(*fs.RelayDetached)
if err != nil {

View file

@ -17,6 +17,8 @@ import (
"github.com/emersion/go-sasl"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
// permanentUpstreamCaps is the static list of upstream capabilities always
@ -510,7 +512,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
highlight := uc.network.isHighlight(msg)
if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
if ch.DetachOn == database.FilterMessage || ch.DetachOn == database.FilterDefault || (ch.DetachOn == database.FilterHighlight && highlight) {
uc.updateChannelAutoDetach(target)
}
}
@ -765,7 +767,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.network.channels.Len() > 0 {
var channels, keys []string
for _, entry := range uc.network.channels.innerMap {
ch := entry.value.(*Channel)
ch := entry.value.(*database.Channel)
channels = append(channels, ch.Name)
keys = append(keys, ch.Key)
}
@ -1553,7 +1555,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
// Check if the nick we want is now free
wantNick := GetNick(&uc.user.User, &uc.network.Network)
wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
wantNickCM := uc.network.casemap(wantNick)
if !online && uc.nickCM != wantNickCM {
found := false
@ -1796,13 +1798,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil
}
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *database.Channel, msg *irc.Message) {
if uc.network.detachedMessageNeedsRelay(ch, msg) {
uc.forEachDownstream(func(dc *downstreamConn) {
dc.relayDetachedMessage(uc.network, msg)
})
}
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
if ch.ReattachOn == database.FilterMessage || (ch.ReattachOn == database.FilterHighlight && uc.network.isHighlight(msg)) {
uc.network.attach(ctx, ch)
if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
@ -1960,10 +1962,10 @@ func splitSpace(s string) []string {
}
func (uc *upstreamConn) register(ctx context.Context) {
uc.nick = GetNick(&uc.user.User, &uc.network.Network)
uc.nick = database.GetNick(&uc.user.User, &uc.network.Network)
uc.nickCM = uc.network.casemap(uc.nick)
uc.username = GetUsername(&uc.user.User, &uc.network.Network)
uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
uc.username = database.GetUsername(&uc.user.User, &uc.network.Network)
uc.realname = database.GetRealname(&uc.user.User, &uc.network.Network)
uc.SendMessage(ctx, &irc.Message{
Command: "CAP",
@ -2193,7 +2195,7 @@ func (uc *upstreamConn) updateMonitor() {
}
})
wantNick := GetNick(&uc.user.User, &uc.network.Network)
wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
wantNickCM := uc.network.casemap(wantNick)
if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
addList = append(addList, wantNickCM)

40
user.go
View file

@ -14,6 +14,8 @@ import (
"time"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
type event interface{}
@ -123,7 +125,7 @@ func (ds deliveredStore) ForEachClient(f func(clientName string)) {
}
type network struct {
Network
database.Network
user *user
logger Logger
stopped chan struct{}
@ -135,7 +137,7 @@ type network struct {
casemap casemapping
}
func newNetwork(user *user, record *Network, channels []Channel) *network {
func newNetwork(user *user, record *database.Network, channels []database.Channel) *network {
logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
m := channelCasemapMap{newCasemapMap(0)}
@ -176,7 +178,7 @@ func (net *network) isStopped() bool {
}
}
func userIdent(u *User) string {
func userIdent(u *database.User) string {
// The ident is a string we will send to upstream servers in clear-text.
// For privacy reasons, make sure it doesn't expose any meaningful user
// metadata. We just use the base64-encoded hashed ID, so that people don't
@ -278,7 +280,7 @@ func (net *network) stop() {
}
}
func (net *network) detach(ch *Channel) {
func (net *network) detach(ch *database.Channel) {
if ch.Detached {
return
}
@ -312,7 +314,7 @@ func (net *network) detach(ch *Channel) {
})
}
func (net *network) attach(ctx context.Context, ch *Channel) {
func (net *network) attach(ctx context.Context, ch *database.Channel) {
if !ch.Detached {
return
}
@ -388,13 +390,13 @@ func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName
return
}
var receipts []DeliveryReceipt
var receipts []database.DeliveryReceipt
net.delivered.ForEachTarget(func(target string) {
msgID := net.delivered.LoadID(target, clientName)
if msgID == "" {
return
}
receipts = append(receipts, DeliveryReceipt{
receipts = append(receipts, database.DeliveryReceipt{
Target: target,
InternalMsgID: msgID,
})
@ -421,9 +423,9 @@ func (net *network) isHighlight(msg *irc.Message) bool {
return msg.Prefix.Name != nick && isHighlight(text, nick)
}
func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
func (net *network) detachedMessageNeedsRelay(ch *database.Channel, msg *irc.Message) bool {
highlight := net.isHighlight(msg)
return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
return ch.RelayDetached == database.FilterMessage || ((ch.RelayDetached == database.FilterHighlight || ch.RelayDetached == database.FilterDefault) && highlight)
}
func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
@ -443,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
}
type user struct {
User
database.User
srv *Server
logger Logger
@ -455,7 +457,7 @@ type user struct {
msgStore messageStore
}
func newUser(srv *Server, record *User) *user {
func newUser(srv *Server, record *database.User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore messageStore
@ -817,7 +819,7 @@ func (u *user) removeNetwork(network *network) {
panic("tried to remove a non-existing network")
}
func (u *user) checkNetwork(record *Network) error {
func (u *user) checkNetwork(record *database.Network) error {
url, err := record.URL()
if err != nil {
return err
@ -867,7 +869,7 @@ func (u *user) checkNetwork(record *Network) error {
return nil
}
func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
func (u *user) createNetwork(ctx context.Context, record *database.Network) (*network, error) {
if record.ID != 0 {
panic("tried creating an already-existing network")
}
@ -894,7 +896,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
return network, nil
}
func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*network, error) {
if record.ID == 0 {
panic("tried updating a new network")
}
@ -920,9 +922,9 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er
// Most network changes require us to re-connect to the upstream server
channels := make([]Channel, 0, network.channels.Len())
channels := make([]database.Channel, 0, network.channels.Len())
for _, entry := range network.channels.innerMap {
ch := entry.value.(*Channel)
ch := entry.value.(*database.Channel)
channels = append(channels, *ch)
}
@ -992,7 +994,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
return nil
}
func (u *user) updateUser(ctx context.Context, record *User) error {
func (u *user) updateUser(ctx context.Context, record *database.User) error {
if u.ID != record.ID {
panic("ID mismatch when updating user")
}
@ -1005,7 +1007,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
if realnameUpdated {
// Re-connect to networks which use the default realname
var needUpdate []Network
var needUpdate []database.Network
for _, net := range u.networks {
if net.Realname != "" {
continue
@ -1016,7 +1018,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") {
uc.SendMessage(ctx, &irc.Message{
Command: "SETNAME",
Params: []string{GetRealname(&u.User, &net.Network)},
Params: []string{database.GetRealname(&u.User, &net.Network)},
})
continue
}