diff --git a/cmd/soju/main.go b/cmd/soju/main.go
index ec036d0..4657e78 100644
--- a/cmd/soju/main.go
+++ b/cmd/soju/main.go
@@ -92,6 +92,7 @@ func loadConfig() (*config.Server, *soju.Config, error) {
 	cfg := &soju.Config{
 		Hostname:       raw.Hostname,
 		Title:          raw.Title,
+		LogDriver:      raw.MsgStore.Driver,
 		LogPath:        raw.MsgStore.Source,
 		HTTPOrigins:    raw.HTTPOrigins,
 		AcceptProxyIPs: raw.AcceptProxyIPs,
diff --git a/config/config.go b/config/config.go
index e3f9c7e..3579ca3 100644
--- a/config/config.go
+++ b/config/config.go
@@ -150,8 +150,7 @@ func parse(cfg scfg.Block) (*Server, error) {
 			return nil, err
 		}
 		switch srv.MsgStore.Driver {
-		case "memory":
-			srv.MsgStore.Source = ""
+		case "memory", "db":
 		case "fs":
			if err := d.ParseParams(nil, &srv.MsgStore.Source); err != nil {
 				return nil, err
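Review note (not part of this patch): with the parser accepting "db" above, database-backed logs are enabled purely from the config file. A rough sketch of the relevant directives, assuming the usual soju setup (the listen/hostname/db values are placeholders; see the doc/soju.1.scd hunk further down for the directive description):

    listen ircs://
    hostname example.org
    db sqlite3 /var/lib/soju/main.db
    message-store db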
diff --git a/contrib/migrate-logs/main.go b/contrib/migrate-logs/main.go
new file mode 100644
index 0000000..014e119
--- /dev/null
+++ b/contrib/migrate-logs/main.go
@@ -0,0 +1,148 @@
+package main
+
+import (
+	"bufio"
+	"context"
+	"flag"
+	"fmt"
+	"log"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"time"
+
+	"git.sr.ht/~emersion/soju/database"
+	"git.sr.ht/~emersion/soju/msgstore"
+)
+
+const usage = `usage: migrate-logs <log root> <database>
+
+Migrates existing Soju logs stored on disk to a Soju database. The database is
+specified in the form "driver:source", where driver is sqlite3 or postgres and
+source is the string that would be in the Soju config file.
+
+Options:
+
+  -help      Show this help message
+`
+
+var logRoot string
+
+func init() {
+	flag.Usage = func() {
+		fmt.Fprint(flag.CommandLine.Output(), usage)
+	}
+}
+
+func migrateNetwork(ctx context.Context, db database.Database, user *database.User, network *database.Network) error {
+	log.Printf("Migrating logs for network: %s\n", network.Name)
+
+	rootPath := filepath.Join(logRoot, msgstore.EscapeFilename(user.Username), msgstore.EscapeFilename(network.GetName()))
+	root, err := os.Open(rootPath)
+	if os.IsNotExist(err) {
+		return nil
+	}
+	if err != nil {
+		return fmt.Errorf("unable to open network folder: %s", rootPath)
+	}
+
+	// The returned targets are escaped, and there is no way to un-escape
+	// TODO: switch to ReadDir (Go 1.16+)
+	targets, err := root.Readdirnames(0)
+	root.Close()
+	if err != nil {
+		return fmt.Errorf("unable to read network folder: %s", rootPath)
+	}
+
+	for _, target := range targets {
+		log.Printf("Migrating logs for target: %s\n", target)
+
+		// target is already escaped here
+		targetPath := filepath.Join(rootPath, target)
+		targetDir, err := os.Open(targetPath)
+		if err != nil {
+			return fmt.Errorf("unable to open target folder: %s", targetPath)
+		}
+
+		entryNames, err := targetDir.Readdirnames(0)
+		targetDir.Close()
+		if err != nil {
+			return fmt.Errorf("unable to read target folder: %s", targetPath)
+		}
+		sort.Strings(entryNames)
+
+		for _, entryName := range entryNames {
+			entryPath := filepath.Join(targetPath, entryName)
+
+			var year, month, day int
+			_, err := fmt.Sscanf(entryName, "%04d-%02d-%02d.log", &year, &month, &day)
+			if err != nil {
+				return fmt.Errorf("invalid entry name: %s", entryName)
+			}
+			ref := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
+
+			entry, err := os.Open(entryPath)
+			if err != nil {
+				return fmt.Errorf("unable to open entry: %s", entryPath)
+			}
+			sc := bufio.NewScanner(entry)
+			for sc.Scan() {
+				msg, _, err := msgstore.FSParseMessage(sc.Text(), user, network, target, ref, true)
+				if err != nil {
+					return fmt.Errorf("unable to parse entry: %s: %s", entryPath, sc.Text())
+				} else if msg == nil {
+					continue
+				}
+				_, err = db.StoreMessage(ctx, network.ID, target, msg)
+				if err != nil {
+					return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err)
+				}
+			}
+			if sc.Err() != nil {
+				return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err())
+			}
+			entry.Close()
+		}
+	}
+	return nil
+}
+
+func main() {
+	flag.Parse()
+
+	ctx := context.Background()
+
+	logRoot = flag.Arg(0)
+	dbParams := strings.Split(flag.Arg(1), ":")
+
+	if len(dbParams) != 2 {
+		log.Fatalf("database not properly specified: %s", flag.Arg(1))
+	}
+
+	db, err := database.Open(dbParams[0], dbParams[1])
+	if err != nil {
+		log.Fatalf("failed to open database: %v", err)
+	}
+	defer db.Close()
+
+	users, err := db.ListUsers(ctx)
+	if err != nil {
+		log.Fatalf("unable to get users: %v", err)
+	}
+
+	for _, user := range users {
+		log.Printf("Migrating logs for user: %s\n", user.Username)
+
+		networks, err := db.ListNetworks(ctx, user.ID)
+		if err != nil {
+			log.Fatalf("unable to get networks for user: #%d %s", user.ID, user.Username)
+		}
+
+		for _, network := range networks {
+			if err := migrateNetwork(ctx, db, &user, &network); err != nil {
+				log.Fatalf("migrating %v: %v", network.Name, err)
+			}
+		}
+	}
+}
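Review note (not part of this patch): the tool takes the fs log root as its first argument and a "driver:source" pair as its second, matching flag.Arg(0) and flag.Arg(1) in main() above. A typical invocation would look roughly like this (paths are placeholders):

    go run ./contrib/migrate-logs /var/lib/soju/logs "sqlite3:/var/lib/soju/main.db"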
diff --git a/database/database.go b/database/database.go
index 93d80b5..c680a5f 100644
--- a/database/database.go
+++ b/database/database.go
@@ -10,8 +10,25 @@ import (
 
 	"github.com/prometheus/client_golang/prometheus"
 	"golang.org/x/crypto/bcrypt"
+	"gopkg.in/irc.v4"
 )
 
+type MessageTarget struct {
+	Name          string
+	LatestMessage time.Time
+}
+
+type MessageOptions struct {
+	AfterID    int64
+	AfterTime  time.Time
+	BeforeTime time.Time
+	Limit      int
+	Events     bool
+	Sender     string
+	Text       string
+	TakeLast   bool
+}
+
 type Database interface {
 	Close() error
 	Stats(ctx context.Context) (*DatabaseStats, error)
@@ -41,6 +58,11 @@ type Database interface {
 	ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
 	StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
 	DeleteWebPushSubscription(ctx context.Context, id int64) error
+
+	GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error)
+	StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error)
+	ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error)
+	ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error)
 }
 
 type MetricsCollectorDatabase interface {
diff --git a/database/postgres.go b/database/postgres.go
index b7bc53e..a4597d3 100644
--- a/database/postgres.go
+++ b/database/postgres.go
@@ -9,9 +9,11 @@ import (
 	"strings"
 	"time"
 
+	"git.sr.ht/~emersion/soju/xirc"
 	_ "github.com/lib/pq"
 	"github.com/prometheus/client_golang/prometheus"
 	promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
+	"gopkg.in/irc.v4"
 )
 
 const postgresQueryTimeout = 5 * time.Second
@@ -112,6 +114,30 @@ CREATE TABLE "WebPushSubscription" (
 	key_p256dh TEXT,
 	UNIQUE(network, endpoint)
 );
+
+CREATE TABLE "MessageTarget" (
+	id SERIAL PRIMARY KEY,
+	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+	target TEXT NOT NULL,
+	UNIQUE(network, target)
+);
+
+CREATE TEXT SEARCH DICTIONARY "search_simple_dictionary" (
+	TEMPLATE = pg_catalog.simple
+);
+CREATE TEXT SEARCH CONFIGURATION "search_simple" ( COPY = pg_catalog.simple );
+ALTER TEXT SEARCH CONFIGURATION "search_simple" ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH "search_simple_dictionary";
+CREATE TABLE "Message" (
+	id SERIAL PRIMARY KEY,
+	target INTEGER NOT NULL REFERENCES "MessageTarget"(id) ON DELETE CASCADE,
+	raw TEXT NOT NULL,
+	time TIMESTAMP WITH TIME ZONE NOT NULL,
+	sender TEXT NOT NULL,
+	text TEXT,
+	text_search tsvector GENERATED ALWAYS AS (to_tsvector('search_simple', text)) STORED
+);
+CREATE INDEX "MessageIndex" ON "Message" (target, time);
+CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
 `
 
 var postgresMigrations = []string{
@@ -173,6 +199,30 @@
 	`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
 	`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
 	`ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`,
+	`
+	CREATE TABLE "MessageTarget" (
+		id SERIAL PRIMARY KEY,
+		network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+		target TEXT NOT NULL,
+		UNIQUE(network, target)
+	);
+	CREATE TEXT SEARCH DICTIONARY "search_simple_dictionary" (
+		TEMPLATE = pg_catalog.simple
+	);
+	CREATE TEXT SEARCH CONFIGURATION "search_simple" ( COPY = pg_catalog.simple );
+	ALTER TEXT SEARCH CONFIGURATION "search_simple" ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH "search_simple_dictionary";
+	CREATE TABLE "Message" (
+		id SERIAL PRIMARY KEY,
+		target INTEGER NOT NULL REFERENCES "MessageTarget"(id) ON DELETE CASCADE,
+		raw TEXT NOT NULL,
+		time TIMESTAMP WITH TIME ZONE NOT NULL,
+		sender TEXT NOT NULL,
+		text TEXT,
+		text_search tsvector GENERATED ALWAYS AS (to_tsvector('search_simple', text)) STORED
+	);
+	CREATE INDEX "MessageIndex" ON "Message" (target, time);
+	CREATE INDEX "MessageSearchIndex" ON "Message" USING GIN (text_search);
+	`,
 }
 
 type PostgresDB struct {
@@ -847,6 +897,229 @@ func (db *PostgresDB) DeleteWebPushSubscription(ctx context.Context, id int64) e
 	return err
 }
 
+func (db *PostgresDB) GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) {
+	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+	defer cancel()
+
+	var msgID int64
+	row := db.db.QueryRowContext(ctx, `
+		SELECT m.id FROM "Message" AS m, "MessageTarget" as t
+		WHERE t.network = $1 AND t.target = $2 AND m.target = t.id
+		ORDER BY m.time DESC LIMIT 1`,
+		networkID,
+		name,
+	)
+	if err := row.Scan(&msgID); err != nil {
+		if err == sql.ErrNoRows {
+			return 0, nil
+		}
+		return 0, err
+	}
+	return msgID, nil
+}
+
+func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) {
+	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+	defer cancel()
+
+	var t time.Time
+	if tag, ok := msg.Tags["time"]; ok {
+		var err error
+		t, err = time.Parse(xirc.ServerTimeLayout, tag)
+		if err != nil {
+			return 0, fmt.Errorf("failed to parse message time tag: %v", err)
+		}
+	} else {
+		t = time.Now()
+	}
+
+	var text sql.NullString
+	switch msg.Command {
+	case "PRIVMSG", "NOTICE":
+		if len(msg.Params) > 1 {
+			text.Valid = true
+			text.String = msg.Params[1]
+		}
+	}
+
+	_, err := db.db.ExecContext(ctx, `
+		INSERT INTO "MessageTarget" (network, target)
+		VALUES ($1, $2)
+		ON CONFLICT DO NOTHING`,
+		networkID,
+		name,
+	)
+	if err != nil {
+		return 0, err
+	}
+
+	var id int64
+	err = db.db.QueryRowContext(ctx, `
+		INSERT INTO "Message" (target, raw, time, sender, text)
+		SELECT id, $1, $2, $3, $4
+		FROM "MessageTarget" as t
+		WHERE network = $5 AND target = $6
+		RETURNING id`,
+		msg.String(),
+		t,
+		msg.Name,
+		text,
+		networkID,
+		name,
+	).Scan(&id)
+	if err != nil {
+		return 0, err
+	}
+	return id, nil
+}
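Review note (not part of this patch): the generated text_search column and the GIN index added in the schema above exist for the Text filter in ListMessages below; that branch effectively builds a query of this shape (values are illustrative):

    SELECT m.raw FROM "Message" AS m
    WHERE m.text_search @@ plainto_tsquery('search_simple', 'hello world')
    ORDER BY m.time DESC LIMIT 10;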
+
+func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {
+	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+	defer cancel()
+
+	parameters := []interface{}{
+		networkID,
+	}
+	query := `
+		SELECT t.target, MAX(m.time) AS latest
+		FROM "Message" m, "MessageTarget" t
+		WHERE m.target = t.id AND t.network = $1
+	`
+	if !options.Events {
+		query += `AND m.text IS NOT NULL `
+	}
+	query += `
+		GROUP BY t.target
+		HAVING true
+	`
+	if !options.AfterTime.IsZero() {
+		parameters = append(parameters, options.AfterTime)
+		query += fmt.Sprintf(`AND MAX(m.time) > $%d `, len(parameters))
+	}
+	if !options.BeforeTime.IsZero() {
+		parameters = append(parameters, options.BeforeTime)
+		query += fmt.Sprintf(`AND MAX(m.time) < $%d `, len(parameters))
+	}
+	if options.TakeLast {
+		query += `ORDER BY latest DESC `
+	} else {
+		query += `ORDER BY latest ASC `
+	}
+	parameters = append(parameters, options.Limit)
+	query += fmt.Sprintf(`LIMIT $%d`, len(parameters))
+
+	rows, err := db.db.QueryContext(ctx, query, parameters...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var l []MessageTarget
+	for rows.Next() {
+		var mt MessageTarget
+		if err := rows.Scan(&mt.Name, &mt.LatestMessage); err != nil {
+			return nil, err
+		}
+
+		l = append(l, mt)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	if options.TakeLast {
+		// We ordered by DESC to limit to the last lines.
+		// Reverse the list to return these last lines in ASC order.
+		for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
+			l[i], l[j] = l[j], l[i]
+		}
+	}
+
+	return l, nil
+}
+
+func (db *PostgresDB) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) {
+	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+	defer cancel()
+
+	parameters := []interface{}{
+		networkID,
+		name,
+	}
+	query := `
+		SELECT m.raw
+		FROM "Message" AS m, "MessageTarget" AS t
+		WHERE m.target = t.id AND t.network = $1 AND t.target = $2 `
+	if options.AfterID > 0 {
+		parameters = append(parameters, options.AfterID)
+		query += fmt.Sprintf(`AND m.id > $%d `, len(parameters))
+	}
+	if !options.AfterTime.IsZero() {
+		parameters = append(parameters, options.AfterTime)
+		query += fmt.Sprintf(`AND m.time > $%d `, len(parameters))
+	}
+	if !options.BeforeTime.IsZero() {
+		parameters = append(parameters, options.BeforeTime)
+		query += fmt.Sprintf(`AND m.time < $%d `, len(parameters))
+	}
+	if options.Sender != "" {
+		parameters = append(parameters, options.Sender)
+		query += fmt.Sprintf(`AND m.sender = $%d `, len(parameters))
+	}
+	if options.Text != "" {
+		parameters = append(parameters, options.Text)
+		query += fmt.Sprintf(`AND text_search @@ plainto_tsquery('search_simple', $%d) `, len(parameters))
+	}
+	if !options.Events {
+		query += `AND m.text IS NOT NULL `
+	}
+	if options.TakeLast {
+		query += `ORDER BY m.time DESC `
+	} else {
+		query += `ORDER BY m.time ASC `
+	}
+	parameters = append(parameters, options.Limit)
+	query += fmt.Sprintf(`LIMIT $%d`, len(parameters))
+
+	rows, err := db.db.QueryContext(ctx, query, parameters...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var l []*irc.Message
+	for rows.Next() {
+		var raw string
+		if err := rows.Scan(&raw); err != nil {
+			return nil, err
+		}
+
+		msg, err := irc.ParseMessage(raw)
+		if err != nil {
+			return nil, err
+		}
+
+		l = append(l, msg)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	if options.TakeLast {
+		// We ordered by DESC to limit to the last lines.
+		// Reverse the list to return these last lines in ASC order.
+		for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
+			l[i], l[j] = l[j], l[i]
+		}
+	}
+
+	return l, nil
+}
+
 var postgresNetworksTotalDesc = prometheus.NewDesc("soju_networks_total", "Number of networks", []string{"hostname"}, nil)
 
 type postgresMetricsCollector struct {
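Review note (not part of this patch): with the interface additions in database/database.go and the Postgres implementation above in place, a minimal sketch of how a caller exercises the new methods; the network ID and target are hypothetical, and the context, database and gopkg.in/irc.v4 imports are assumed:

    // Store one PRIVMSG, then fetch the ten most recent stored messages.
    func storeAndList(ctx context.Context, db database.Database, networkID int64) error {
        msg := &irc.Message{
            Prefix:  &irc.Prefix{Name: "alice"},
            Command: "PRIVMSG",
            Params:  []string{"#soju", "hello world"},
        }
        if _, err := db.StoreMessage(ctx, networkID, "#soju", msg); err != nil {
            return err
        }
        // TakeLast picks the newest Limit rows but still returns them oldest-first.
        _, err := db.ListMessages(ctx, networkID, "#soju", &database.MessageOptions{
            Limit:    10,
            TakeLast: true,
        })
        return err
    }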
diff --git a/database/sqlite.go b/database/sqlite.go
index 58801b0..932e1ee 100644
--- a/database/sqlite.go
+++ b/database/sqlite.go
@@ -1,4 +1,5 @@
 //go:build !nosqlite
+// +build !nosqlite
 
 package database
 
@@ -11,8 +12,10 @@ import (
 	"strings"
 	"time"
 
+	"git.sr.ht/~emersion/soju/xirc"
 	"github.com/prometheus/client_golang/prometheus"
 	promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
+	"gopkg.in/irc.v4"
 )
 
 const SqliteEnabled = true
@@ -146,6 +149,41 @@ CREATE TABLE WebPushSubscription (
 	FOREIGN KEY(network) REFERENCES Network(id),
 	UNIQUE(network, endpoint)
 );
+
+CREATE TABLE Message (
+	id INTEGER PRIMARY KEY,
+	target INTEGER NOT NULL,
+	raw TEXT NOT NULL,
+	time TEXT NOT NULL,
+	sender TEXT NOT NULL,
+	text TEXT,
+	FOREIGN KEY(target) REFERENCES MessageTarget(id)
+);
+CREATE INDEX MessageIndex ON Message(target, time);
+
+CREATE TABLE MessageTarget (
+	id INTEGER PRIMARY KEY,
+	network INTEGER NOT NULL,
+	target TEXT NOT NULL,
+	FOREIGN KEY(network) REFERENCES Network(id),
+	UNIQUE(network, target)
+);
+
+CREATE VIRTUAL TABLE MessageFTS USING fts5 (
+	text,
+	content=Message,
+	content_rowid=id
+);
+CREATE TRIGGER MessageFTSInsert AFTER INSERT ON Message BEGIN
+	INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
+END;
+CREATE TRIGGER MessageFTSDelete AFTER DELETE ON Message BEGIN
+	INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
+END;
+CREATE TRIGGER MessageFTSUpdate AFTER UPDATE ON Message BEGIN
+	INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
+	INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
+END;
 `
 
 var sqliteMigrations = []string{
@@ -293,6 +331,42 @@
 	`,
 	"ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
 	"ALTER TABLE User ADD COLUMN downstream_interacted_at TEXT;",
+	`
+	CREATE TABLE Message (
+		id INTEGER PRIMARY KEY,
+		target INTEGER NOT NULL,
+		raw TEXT NOT NULL,
+		time TEXT NOT NULL,
+		sender TEXT NOT NULL,
+		text TEXT,
+		FOREIGN KEY(target) REFERENCES MessageTarget(id)
+	);
+	CREATE INDEX MessageIndex ON Message(target, time);
+
+	CREATE TABLE MessageTarget (
+		id INTEGER PRIMARY KEY,
+		network INTEGER NOT NULL,
+		target TEXT NOT NULL,
+		FOREIGN KEY(network) REFERENCES Network(id),
+		UNIQUE(network, target)
+	);
+
+	CREATE VIRTUAL TABLE MessageFTS USING fts5 (
+		text,
+		content=Message,
+		content_rowid=id
+	);
+	CREATE TRIGGER MessageFTSInsert AFTER INSERT ON Message BEGIN
+		INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
+	END;
+	CREATE TRIGGER MessageFTSDelete AFTER DELETE ON Message BEGIN
+		INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
+	END;
+	CREATE TRIGGER MessageFTSUpdate AFTER UPDATE ON Message BEGIN
+		INSERT INTO MessageFTS(MessageFTS, rowid, text) VALUES ('delete', old.id, old.text);
+		INSERT INTO MessageFTS(rowid, text) VALUES (new.id, new.text);
+	END;
+	`,
 }
 
 type SqliteDB struct {
@@ -697,6 +771,16 @@ func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
 	}
 	defer tx.Rollback()
 
+	_, err = tx.ExecContext(ctx, "DELETE FROM Message WHERE target IN (SELECT id FROM MessageTarget WHERE network = ?)", id)
+	if err != nil {
+		return err
+	}
+
+	_, err = tx.ExecContext(ctx, "DELETE FROM MessageTarget WHERE network = ?", id)
+	if err != nil {
+		return err
+	}
+
 	_, err = tx.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE network = ?", id)
 	if err != nil {
 		return err
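Review note (not part of this patch): MessageFTS is an external-content FTS5 table, so the triggers above are what keep it in sync with Message; a text search then goes through the FTS index and joins back on rowid, which is exactly what the MATCH subquery in ListMessages below does. Roughly:

    SELECT m.raw FROM Message AS m
    WHERE m.id IN (SELECT ROWID FROM MessageFTS WHERE MessageFTS MATCH 'hello')
    ORDER BY m.time DESC LIMIT 10;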
@@ -1054,3 +1138,232 @@ func (db *SqliteDB) DeleteWebPushSubscription(ctx context.Context, id int64) err
 	_, err := db.db.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE id = ?", id)
 	return err
 }
+
+func (db *SqliteDB) GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) {
+	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+	defer cancel()
+
+	var msgID int64
+	row := db.db.QueryRowContext(ctx, `
+		SELECT m.id FROM Message AS m, MessageTarget AS t
+		WHERE t.network = :network AND t.target = :target AND m.target = t.id
+		ORDER BY m.time DESC LIMIT 1`,
+		sql.Named("network", networkID),
+		sql.Named("target", name),
+	)
+	if err := row.Scan(&msgID); err != nil {
+		if err == sql.ErrNoRows {
+			return 0, nil
+		}
+		return 0, err
+	}
+	return msgID, nil
+}
+
+func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) {
+	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+	defer cancel()
+
+	var t time.Time
+	if tag, ok := msg.Tags["time"]; ok {
+		var err error
+		t, err = time.Parse(xirc.ServerTimeLayout, tag)
+		if err != nil {
+			return 0, fmt.Errorf("failed to parse message time tag: %v", err)
+		}
+	} else {
+		t = time.Now()
+	}
+
+	var text sql.NullString
+	switch msg.Command {
+	case "PRIVMSG", "NOTICE":
+		if len(msg.Params) > 1 {
+			text.Valid = true
+			text.String = msg.Params[1]
+		}
+	}
+
+	_, err := db.db.ExecContext(ctx, `
+		INSERT INTO MessageTarget(network, target)
+		VALUES (:network, :target)
+		ON CONFLICT DO NOTHING`,
+		sql.Named("network", networkID),
+		sql.Named("target", name),
+	)
+	if err != nil {
+		return 0, err
+	}
+
+	res, err := db.db.ExecContext(ctx, `
+		INSERT INTO Message(target, raw, time, sender, text)
+		SELECT id, :raw, :time, :sender, :text
+		FROM MessageTarget as t
+		WHERE network = :network AND target = :target`,
+		sql.Named("network", networkID),
+		sql.Named("target", name),
+		sql.Named("raw", msg.String()),
+		sql.Named("time", sqliteTime{t}),
+		sql.Named("sender", msg.Name),
+		sql.Named("text", text),
+	)
+	if err != nil {
+		return 0, err
+	}
+	id, err := res.LastInsertId()
+	if err != nil {
+		return 0, err
+	}
+	return id, nil
+}
+
+func (db *SqliteDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) {
+	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+	defer cancel()
+
+	innerQuery := `
+		SELECT time
+		FROM Message
+		WHERE target = MessageTarget.id `
+	if !options.Events {
+		innerQuery += `AND text IS NOT NULL `
+	}
+	innerQuery += `
+		ORDER BY time DESC
+		LIMIT 1
+	`
+
+	query := `
+		SELECT target, (` + innerQuery + `) latest
+		FROM MessageTarget
+		WHERE network = :network `
+	if !options.AfterTime.IsZero() {
+		// compares time strings by lexicographical order
+		query += `AND latest > :after `
+	}
+	if !options.BeforeTime.IsZero() {
+		// compares time strings by lexicographical order
+		query += `AND latest < :before `
+	}
+	if options.TakeLast {
+		query += `ORDER BY latest DESC `
+	} else {
+		query += `ORDER BY latest ASC `
+	}
+	query += `LIMIT :limit`
+
+	rows, err := db.db.QueryContext(ctx, query,
+		sql.Named("network", networkID),
+		sql.Named("after", sqliteTime{options.AfterTime}),
+		sql.Named("before", sqliteTime{options.BeforeTime}),
+		sql.Named("limit", options.Limit),
+	)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var l []MessageTarget
+	for rows.Next() {
+		var mt MessageTarget
+		var ts sqliteTime
+		if err := rows.Scan(&mt.Name, &ts); err != nil {
+			return nil, err
+		}
+
+		mt.LatestMessage = ts.Time
+		l = append(l, mt)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	if options.TakeLast {
+		// We ordered by DESC to limit to the last lines.
+		// Reverse the list to return these last lines in ASC order.
+		for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
+			l[i], l[j] = l[j], l[i]
+		}
+	}
+
+	return l, nil
+}
+
+func (db *SqliteDB) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) {
+	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+	defer cancel()
+
+	query := `
+		SELECT m.raw
+		FROM Message AS m, MessageTarget AS t
+		WHERE m.target = t.id AND t.network = :network AND t.target = :target `
+	if options.AfterID > 0 {
+		query += `AND m.id > :afterID `
+	}
+	if !options.AfterTime.IsZero() {
+		// compares time strings by lexicographical order
+		query += `AND m.time > :after `
+	}
+	if !options.BeforeTime.IsZero() {
+		// compares time strings by lexicographical order
+		query += `AND m.time < :before `
+	}
+	if options.Sender != "" {
+		query += `AND m.sender = :sender `
+	}
+	if options.Text != "" {
+		query += `AND m.id IN (SELECT ROWID FROM MessageFTS WHERE MessageFTS MATCH :text) `
+	}
+	if !options.Events {
+		query += `AND m.text IS NOT NULL `
+	}
+	if options.TakeLast {
+		query += `ORDER BY m.time DESC `
+	} else {
+		query += `ORDER BY m.time ASC `
+	}
+	query += `LIMIT :limit`
+
+	rows, err := db.db.QueryContext(ctx, query,
+		sql.Named("network", networkID),
+		sql.Named("target", name),
+		sql.Named("afterID", options.AfterID),
+		sql.Named("after", sqliteTime{options.AfterTime}),
+		sql.Named("before", sqliteTime{options.BeforeTime}),
+		sql.Named("sender", options.Sender),
+		sql.Named("text", options.Text),
+		sql.Named("limit", options.Limit),
+	)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var l []*irc.Message
+	for rows.Next() {
+		var raw string
+		if err := rows.Scan(&raw); err != nil {
+			return nil, err
+		}
+
+		msg, err := irc.ParseMessage(raw)
+		if err != nil {
+			return nil, err
+		}
+
+		l = append(l, msg)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	if options.TakeLast {
+		// We ordered by DESC to limit to the last lines.
+		// Reverse the list to return these last lines in ASC order.
+		for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
+			l[i], l[j] = l[j], l[i]
+		}
+	}
+
+	return l, nil
+}
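Review note (not part of this patch): the "compares time strings by lexicographical order" comments rely on Message.time being stored as a fixed-width UTC string (produced by sqliteTime, defined elsewhere in this file), so that string order and chronological order coincide. A tiny illustration of that property, using RFC 3339 as a stand-in for the actual layout:

    a := time.Date(2023, 5, 1, 10, 0, 0, 0, time.UTC).Format(time.RFC3339)
    b := time.Date(2023, 5, 1, 11, 0, 0, 0, time.UTC).Format(time.RFC3339)
    fmt.Println(a < b) // true: the later timestamp also sorts later as a string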
diff --git a/database/sqlite_mattn.go b/database/sqlite_mattn.go
index bf5244a..a95a2bf 100644
--- a/database/sqlite_mattn.go
+++ b/database/sqlite_mattn.go
@@ -3,6 +3,7 @@
 package database
 
 import (
+	_ "git.sr.ht/~emersion/go-sqlite3-fts5"
 	_ "github.com/mattn/go-sqlite3"
 )
 
diff --git a/doc/soju.1.scd b/doc/soju.1.scd
index 8bc7d11..6996a41 100644
--- a/doc/soju.1.scd
+++ b/doc/soju.1.scd
@@ -137,6 +137,7 @@ The following directives are supported:
	- _memory_ stores messages in memory.
	- _fs_ stores messages on disk, in the same format as ZNC. _source_ is
	  required and is the root directory path for the database.
+	- _db_ stores messages in the database.
 
	(_log_ is a deprecated alias for this directive.)
diff --git a/downstream.go b/downstream.go
index 7d8e132..2d2acab 100644
--- a/downstream.go
+++ b/downstream.go
@@ -401,7 +401,8 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
 	// TODO: this is racy, we should only enable chathistory after
 	// authentication and then check that user.msgStore implements
 	// chatHistoryMessageStore
-	if srv.Config().LogPath != "" {
+	switch srv.Config().LogDriver {
+	case "fs", "db":
 		dc.caps.Available["draft/chathistory"] = ""
 		dc.caps.Available["soju.im/search"] = ""
 	}
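Review note (not part of this patch): with the "db" driver now advertising draft/chathistory and soju.im/search, a client can replay history straight from the database. Illustrative raw commands (target, timestamp and limit are arbitrary):

    CAP REQ :draft/chathistory
    CHATHISTORY LATEST #soju * 50
    CHATHISTORY BEFORE #soju timestamp=2023-05-01T10:00:00.000Z 50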
diff --git a/go.mod b/go.mod
index de29298..c549fb3 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,7 @@ go 1.15
 
 require (
 	git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
+	git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc
 	git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
 	github.com/SherClockHolmes/webpush-go v1.2.0
 	github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead
diff --git a/go.sum b/go.sum
index bff3e20..97f5b6b 100644
--- a/go.sum
+++ b/go.sum
@@ -33,6 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
 git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao=
 git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U=
+git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc h1:+y3OijpLl4rgbFsqMBmYUTCsGCkxQUWpWaqfS8j9Ygc=
+git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc/go.mod h1:PCl1xjl7iC6x35TKKubKRyo/3TT0dGI66jyNI6vmYnU=
 git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=
 git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE=
 git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=
diff --git a/msgstore/db.go b/msgstore/db.go
new file mode 100644
index 0000000..289253a
--- /dev/null
+++ b/msgstore/db.go
@@ -0,0 +1,151 @@
+package msgstore
+
+import (
+	"context"
+	"time"
+
+	"git.sr.ht/~emersion/soju/database"
+	"git.sr.ht/~sircmpwn/go-bare"
+	"gopkg.in/irc.v4"
+)
+
+type dbMsgID struct {
+	ID bare.Uint
+}
+
+func (dbMsgID) msgIDType() msgIDType {
+	return msgIDDB
+}
+
+func parseDBMsgID(s string) (msgID int64, err error) {
+	var id dbMsgID
+	_, _, err = ParseMsgID(s, &id)
+	if err != nil {
+		return 0, err
+	}
+	return int64(id.ID), nil
+}
+
+func formatDBMsgID(netID int64, target string, msgID int64) string {
+	id := dbMsgID{bare.Uint(msgID)}
+	return formatMsgID(netID, target, &id)
+}
+
+// dbMessageStore is a persistent store for IRC messages, that stores
+// messages in the soju database.
+type dbMessageStore struct {
+	db database.Database
+}
+
+var (
+	_ Store            = (*dbMessageStore)(nil)
+	_ ChatHistoryStore = (*dbMessageStore)(nil)
+	_ SearchStore      = (*dbMessageStore)(nil)
+)
+
+func NewDBStore(db database.Database) *dbMessageStore {
+	return &dbMessageStore{
+		db: db,
+	}
+}
+
+func (ms *dbMessageStore) Close() error {
+	return nil
+}
+
+func (ms *dbMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
+	// TODO: what should we do with t?
+
+	id, err := ms.db.GetMessageLastID(context.TODO(), network.ID, entity)
+	if err != nil {
+		return "", err
+	}
+	return formatDBMsgID(network.ID, entity, id), nil
+}
+
+func (ms *dbMessageStore) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error) {
+	msgID, err := parseDBMsgID(id)
+	if err != nil {
+		return nil, err
+	}
+
+	l, err := ms.db.ListMessages(ctx, options.Network.ID, options.Entity, &database.MessageOptions{
+		AfterID:  msgID,
+		Limit:    options.Limit,
+		TakeLast: true,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return l, nil
+}
+
+func (ms *dbMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
+	id, err := ms.db.StoreMessage(context.TODO(), network.ID, entity, msg)
+	if err != nil {
+		return "", err
+	}
+	return formatDBMsgID(network.ID, entity, id), nil
+}
+
+func (ms *dbMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {
+	l, err := ms.db.ListMessageLastPerTarget(ctx, network.ID, &database.MessageOptions{
+		AfterTime:  start,
+		BeforeTime: end,
+		Limit:      limit,
+		Events:     events,
+	})
+	if err != nil {
+		return nil, err
+	}
+	targets := make([]ChatHistoryTarget, len(l))
+	for i, v := range l {
+		targets[i] = ChatHistoryTarget{
+			Name:          v.Name,
+			LatestMessage: v.LatestMessage,
+		}
+	}
+	return targets, nil
+}
+
+func (ms *dbMessageStore) LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
+	l, err := ms.db.ListMessages(ctx, options.Network.ID, options.Entity, &database.MessageOptions{
+		AfterTime:  end,
+		BeforeTime: start,
+		Limit:      options.Limit,
+		Events:     options.Events,
+		TakeLast:   true,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return l, nil
+}
+
+func (ms *dbMessageStore) LoadAfterTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
+	l, err := ms.db.ListMessages(ctx, options.Network.ID, options.Entity, &database.MessageOptions{
+		AfterTime:  start,
+		BeforeTime: end,
+		Limit:      options.Limit,
+		Events:     options.Events,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return l, nil
+}
+
+func (ms *dbMessageStore) Search(ctx context.Context, network *database.Network, options *SearchMessageOptions) ([]*irc.Message, error) {
+	l, err := ms.db.ListMessages(ctx, network.ID, options.In, &database.MessageOptions{
+		AfterTime:  options.Start,
+		BeforeTime: options.End,
+		Limit:      options.Limit,
+		Sender:     options.From,
+		Text:       options.Text,
+		TakeLast:   true,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return l, nil
+}
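Review note (not part of this patch): user.go below does the real wiring, but for completeness here is a standalone sketch of using the new store, assuming an open database.Database, a *database.Network and an *irc.Message are already in hand:

    store := msgstore.NewDBStore(db) // satisfies Store, ChatHistoryStore and SearchStore
    id, err := store.Append(network, "#soju", msg)
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("stored message, opaque ID: %s", id)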
diff --git a/msgstore/fs.go b/msgstore/fs.go
index 81ec58e..2d2899c 100644
--- a/msgstore/fs.go
+++ b/msgstore/fs.go
@@ -23,7 +23,7 @@ const (
 	fsMessageStoreMaxTries = 100
 )
 
-func escapeFilename(unsafe string) (safe string) {
+func EscapeFilename(unsafe string) (safe string) {
 	if unsafe == "." {
 		return "-"
 	} else if unsafe == ".." {
@@ -103,7 +103,7 @@ func IsFSStore(store Store) bool {
 
 func NewFSStore(root string, user *database.User) *fsMessageStore {
 	return &fsMessageStore{
-		root:  filepath.Join(root, escapeFilename(user.Username)),
+		root:  filepath.Join(root, EscapeFilename(user.Username)),
 		user:  user,
 		files: make(map[string]*fsMessageStoreFile),
 	}
@@ -112,7 +112,7 @@ func NewFSStore(root string, user *database.User) *fsMessageStore {
 func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string {
 	year, month, day := t.Date()
 	filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
-	return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
+	return filepath.Join(ms.root, EscapeFilename(network.GetName()), EscapeFilename(entity), filename)
 }
 
 // nextMsgID queries the message ID for the next message to be written to f.
@@ -265,6 +265,10 @@ func formatMessage(msg *irc.Message) string {
 }
 
 func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
+	return FSParseMessage(line, ms.user, network, entity, ref, events)
+}
+
+func FSParseMessage(line string, user *database.User, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
 	var hour, minute, second int
 	_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
 	if err != nil {
@@ -391,7 +395,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e
 		// our nickname in the logs, so grab it from the network settings.
 		// Not very accurate since this may not match our nick at the time
 		// the message was received, but we can't do a lot better.
-		entity = database.GetNick(ms.user, network)
+		entity = database.GetNick(user, network)
 	}
 	params = []string{entity, text}
 }
@@ -634,7 +638,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *
 func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {
 	start = start.In(time.Local)
 	end = end.In(time.Local)
-	rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
+	rootPath := filepath.Join(ms.root, EscapeFilename(network.GetName()))
 	root, err := os.Open(rootPath)
 	if os.IsNotExist(err) {
 		return nil, nil
@@ -713,7 +717,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
 func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *SearchMessageOptions) ([]*irc.Message, error) {
 	text := strings.ToLower(opts.Text)
 	selector := func(m *irc.Message) bool {
-		if opts.From != "" && m.User != opts.From {
+		if opts.From != "" && m.Name != opts.From {
 			return false
 		}
 		if text != "" && !strings.Contains(strings.ToLower(m.Params[1]), text) {
@@ -734,8 +738,8 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network,
 }
 
 func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error {
-	oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
-	newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
+	oldDir := filepath.Join(ms.root, EscapeFilename(oldNet.GetName()))
+	newDir := filepath.Join(ms.root, EscapeFilename(newNet.GetName()))
 	// Avoid loosing data by overwriting an existing directory
 	if _, err := os.Stat(newDir); err == nil {
 		return fmt.Errorf("destination %q already exists", newDir)
diff --git a/msgstore/msgstore.go b/msgstore/msgstore.go
index 4585388..fb32ddd 100644
--- a/msgstore/msgstore.go
+++ b/msgstore/msgstore.go
@@ -52,7 +52,7 @@ type ChatHistoryStore interface {
 	// end is before start.
 	// If events is false, only PRIVMSG/NOTICE messages are considered.
 	LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error)
-	// LoadBeforeTime loads up to limit messages after start up to end. The
+	// LoadAfterTime loads up to limit messages after start up to end. The
 	// returned messages must be between and excluding the provided bounds.
 	// end is after start.
 	// If events is false, only PRIVMSG/NOTICE messages are considered.
@@ -90,6 +90,7 @@ const (
 	msgIDNone msgIDType = iota
 	msgIDMemory
 	msgIDFS
+	msgIDDB
 )
 
 const msgIDVersion uint = 0
diff --git a/server.go b/server.go
index 41d5c19..d7bf856 100644
--- a/server.go
+++ b/server.go
@@ -138,6 +138,7 @@ func (ln *retryListener) Accept() (net.Conn, error) {
 type Config struct {
 	Hostname       string
 	Title          string
+	LogDriver      string
 	LogPath        string
 	HTTPOrigins    []string
 	AcceptProxyIPs config.IPSet
diff --git a/user.go b/user.go
index f0e82ea..1e14c46 100644
--- a/user.go
+++ b/user.go
@@ -516,9 +516,12 @@ func newUser(srv *Server, record *database.User) *user {
 	logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
 
 	var msgStore msgstore.Store
-	if logPath := srv.Config().LogPath; logPath != "" {
-		msgStore = msgstore.NewFSStore(logPath, record)
-	} else {
+	switch srv.Config().LogDriver {
+	case "fs":
+		msgStore = msgstore.NewFSStore(srv.Config().LogPath, record)
+	case "db":
+		msgStore = msgstore.NewDBStore(srv.db)
+	case "memory":
 		msgStore = msgstore.NewMemoryStore()
 	}