Save delivery receipts in DB

This avoids loosing history on restart for clients that don't
support chathistory.

Closes: https://todo.sr.ht/~emersion/soju/80
This commit is contained in:
Simon Ser 2021-02-10 18:16:08 +01:00
parent 5b4469fcb7
commit 1e4ff49472
3 changed files with 161 additions and 11 deletions

90
db.go
View file

@ -120,6 +120,13 @@ type Channel struct {
DetachOn MessageFilter DetachOn MessageFilter
} }
type DeliveryReceipt struct {
ID int64
Target string // channel or nick
Client string
InternalMsgID string
}
const schema = ` const schema = `
CREATE TABLE User ( CREATE TABLE User (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
@ -161,6 +168,16 @@ CREATE TABLE Channel (
FOREIGN KEY(network) REFERENCES Network(id), FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, name) UNIQUE(network, name)
); );
CREATE TABLE DeliveryReceipt (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
target VARCHAR(255) NOT NULL,
client VARCHAR(255),
internal_msgid VARCHAR(255) NOT NULL,
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, target, client)
);
` `
var migrations = []string{ var migrations = []string{
@ -217,6 +234,17 @@ var migrations = []string{
ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0; ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0; ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
`, `,
`
CREATE TABLE DeliveryReceipt (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
target VARCHAR(255) NOT NULL,
client VARCHAR(255),
internal_msgid VARCHAR(255) NOT NULL,
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, target, client)
);
`,
} }
type DB struct { type DB struct {
@ -578,3 +606,65 @@ func (db *DB) DeleteChannel(id int64) error {
_, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id) _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
return err return err
} }
func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
db.lock.RLock()
defer db.lock.RUnlock()
rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
FROM DeliveryReceipt
WHERE network = ?`, networkID)
if err != nil {
return nil, err
}
defer rows.Close()
var receipts []DeliveryReceipt
for rows.Next() {
var rcpt DeliveryReceipt
var client sql.NullString
if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
return nil, err
}
rcpt.Client = client.String
receipts = append(receipts, rcpt)
}
if err := rows.Err(); err != nil {
return nil, err
}
return receipts, nil
}
func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
db.lock.Lock()
defer db.lock.Unlock()
tx, err := db.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
networkID, toNullString(client))
if err != nil {
return err
}
for i := range receipts {
rcpt := &receipts[i]
res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
if err != nil {
return err
}
rcpt.ID, err = res.LastInsertId()
if err != nil {
return err
}
}
return tx.Commit()
}

View file

@ -1752,9 +1752,9 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
return "" return ""
} }
for clientName, _ := range uc.user.clientNames { uc.network.delivered.ForEachClient(func(clientName string) {
uc.network.delivered.StoreID(entity, clientName, lastID) uc.network.delivered.StoreID(entity, clientName, lastID)
} })
} }
msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg) msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg)

78
user.go
View file

@ -92,6 +92,20 @@ func (ds deliveredStore) ForEachTarget(f func(target string)) {
} }
} }
func (ds deliveredStore) ForEachClient(f func(clientName string)) {
clients := make(map[string]struct{})
for _, entry := range ds.m.innerMap {
delivered := entry.value.(deliveredClientMap)
for clientName := range delivered {
clients[clientName] = struct{}{}
}
}
for clientName := range clients {
f(clientName)
}
}
type network struct { type network struct {
Network Network
user *user user *user
@ -298,6 +312,28 @@ func (net *network) updateCasemapping(newCasemap casemapping) {
} }
} }
func (net *network) storeClientDeliveryReceipts(clientName string) {
if !net.user.hasPersistentMsgStore() {
return
}
var receipts []DeliveryReceipt
net.delivered.ForEachTarget(func(target string) {
msgID := net.delivered.LoadID(target, clientName)
if msgID == "" {
return
}
receipts = append(receipts, DeliveryReceipt{
Target: target,
InternalMsgID: msgID,
})
})
if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
net.user.srv.Logger.Printf("failed to store delivery receipts for user %q, client %q, network %q: %v", net.user.Username, clientName, net.GetName(), err)
}
}
type user struct { type user struct {
User User
srv *Server srv *Server
@ -308,7 +344,6 @@ type user struct {
networks []*network networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
msgStore messageStore msgStore messageStore
clientNames map[string]struct{}
// LIST commands in progress // LIST commands in progress
pendingLISTs []pendingLIST pendingLISTs []pendingLIST
@ -329,12 +364,11 @@ func newUser(srv *Server, record *User) *user {
} }
return &user{ return &user{
User: *record, User: *record,
srv: srv, srv: srv,
events: make(chan event, 64), events: make(chan event, 64),
done: make(chan struct{}), done: make(chan struct{}),
msgStore: msgStore, msgStore: msgStore,
clientNames: make(map[string]struct{}),
} }
} }
@ -407,6 +441,18 @@ func (u *user) run() {
network := newNetwork(u, &record, channels) network := newNetwork(u, &record, channels)
u.networks = append(u.networks, network) u.networks = append(u.networks, network)
if u.hasPersistentMsgStore() {
receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
if err != nil {
u.srv.Logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
return
}
for _, rcpt := range receipts {
network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
}
}
go network.run() go network.run()
} }
@ -489,8 +535,6 @@ func (u *user) run() {
u.forEachUpstream(func(uc *upstreamConn) { u.forEachUpstream(func(uc *upstreamConn) {
uc.updateAway() uc.updateAway()
}) })
u.clientNames[dc.clientName] = struct{}{}
case eventDownstreamDisconnected: case eventDownstreamDisconnected:
dc := e.dc dc := e.dc
@ -501,6 +545,10 @@ func (u *user) run() {
} }
} }
dc.forEachNetwork(func(net *network) {
net.storeClientDeliveryReceipts(dc.clientName)
})
u.forEachUpstream(func(uc *upstreamConn) { u.forEachUpstream(func(uc *upstreamConn) {
uc.updateAway() uc.updateAway()
}) })
@ -524,6 +572,10 @@ func (u *user) run() {
}) })
for _, n := range u.networks { for _, n := range u.networks {
n.stop() n.stop()
n.delivered.ForEachClient(func(clientName string) {
n.storeClientDeliveryReceipts(clientName)
})
} }
return return
default: default:
@ -665,3 +717,11 @@ func (u *user) stop() {
u.events <- eventStop{} u.events <- eventStop{}
<-u.done <-u.done
} }
func (u *user) hasPersistentMsgStore() bool {
if u.msgStore == nil {
return false
}
_, isMem := u.msgStore.(*memoryMessageStore)
return !isMem
}