msgstore: take Network as arg instead of network

The message stores don't need to access the internal network
struct, they just need network metadata such as ID and name.

This can ease moving message stores into a separate package in the
future.
This commit is contained in:
Simon Ser 2021-11-03 16:37:01 +01:00
parent 03f8972305
commit 2b4f0a870f
6 changed files with 32 additions and 32 deletions

View file

@ -1295,7 +1295,7 @@ func (dc *downstreamConn) welcome() error {
// Fast-forward history to last message // Fast-forward history to last message
targetCM := net.casemap(target) targetCM := net.casemap(target)
lastID, err := dc.user.msgStore.LastMsgID(net, targetCM, time.Now()) lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now())
if err != nil { if err != nil {
dc.logger.Printf("failed to get last message ID: %v", err) dc.logger.Printf("failed to get last message ID: %v", err)
return return
@ -1330,7 +1330,7 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target, msgID string)
limit := 4000 limit := 4000
targetCM := net.casemap(target) targetCM := net.casemap(target)
history, err := dc.user.msgStore.LoadLatestID(net, targetCM, msgID, limit) history, err := dc.user.msgStore.LoadLatestID(&net.Network, targetCM, msgID, limit)
if err != nil { if err != nil {
dc.logger.Printf("failed to send backlog for %q: %v", target, err) dc.logger.Printf("failed to send backlog for %q: %v", target, err)
return return
@ -2337,18 +2337,18 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
var history []*irc.Message var history []*irc.Message
switch subcommand { switch subcommand {
case "BEFORE": case "BEFORE":
history, err = store.LoadBeforeTime(network, entity, bounds[0], time.Time{}, limit, eventPlayback) history, err = store.LoadBeforeTime(&network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
case "AFTER": case "AFTER":
history, err = store.LoadAfterTime(network, entity, bounds[0], time.Now(), limit, eventPlayback) history, err = store.LoadAfterTime(&network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
case "BETWEEN": case "BETWEEN":
if bounds[0].Before(bounds[1]) { if bounds[0].Before(bounds[1]) {
history, err = store.LoadAfterTime(network, entity, bounds[0], bounds[1], limit, eventPlayback) history, err = store.LoadAfterTime(&network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
} else { } else {
history, err = store.LoadBeforeTime(network, entity, bounds[0], bounds[1], limit, eventPlayback) history, err = store.LoadBeforeTime(&network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
} }
case "TARGETS": case "TARGETS":
// TODO: support TARGETS in multi-upstream mode // TODO: support TARGETS in multi-upstream mode
targets, err := store.ListTargets(network, bounds[0], bounds[1], limit, eventPlayback) targets, err := store.ListTargets(&network.Network, bounds[0], bounds[1], limit, eventPlayback)
if err != nil { if err != nil {
dc.logger.Printf("failed fetching targets for chathistory: %v", err) dc.logger.Printf("failed fetching targets for chathistory: %v", err)
return ircError{&irc.Message{ return ircError{&irc.Message{

View file

@ -16,11 +16,11 @@ type messageStore interface {
// LastMsgID queries the last message ID for the given network, entity and // LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be // date. The message ID returned may not refer to a valid message, but can be
// used in history queries. // used in history queries.
LastMsgID(network *network, entity string, t time.Time) (string, error) LastMsgID(network *Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network, // LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest. // entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error)
Append(network *network, entity string, msg *irc.Message) (id string, err error) Append(network *Network, entity string, msg *irc.Message) (id string, err error)
} }
type chatHistoryTarget struct { type chatHistoryTarget struct {
@ -37,17 +37,17 @@ type chatHistoryMessageStore interface {
// It returns up to limit targets, starting from start and ending on end, // It returns up to limit targets, starting from start and ending on end,
// both excluded. end may be before or after start. // both excluded. end may be before or after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
ListTargets(network *network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) ListTargets(network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
// LoadBeforeTime loads up to limit messages before start down to end. The // LoadBeforeTime loads up to limit messages before start down to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is before start. // end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(network *network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadBeforeTime(network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The // LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is after start. // end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(network *network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadAfterTime(network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
} }
type msgIDType uint type msgIDType uint

View file

@ -94,14 +94,14 @@ func newFSMessageStore(root, username string) *fsMessageStore {
} }
} }
func (ms *fsMessageStore) logPath(network *network, entity string, t time.Time) string { func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
year, month, day := t.Date() year, month, day := t.Date()
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
} }
// nextMsgID queries the message ID for the next message to be written to f. // nextMsgID queries the message ID for the next message to be written to f.
func nextFSMsgID(network *network, entity string, t time.Time, f *os.File) (string, error) { func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
offset, err := f.Seek(0, io.SeekEnd) offset, err := f.Seek(0, io.SeekEnd)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to query next FS message ID: %v", err) return "", fmt.Errorf("failed to query next FS message ID: %v", err)
@ -109,7 +109,7 @@ func nextFSMsgID(network *network, entity string, t time.Time, f *os.File) (stri
return formatFSMsgID(network.ID, entity, t, offset), nil return formatFSMsgID(network.ID, entity, t, offset), nil
} }
func (ms *fsMessageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) { func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
p := ms.logPath(network, entity, t) p := ms.logPath(network, entity, t)
fi, err := os.Stat(p) fi, err := os.Stat(p)
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -120,7 +120,7 @@ func (ms *fsMessageStore) LastMsgID(network *network, entity string, t time.Time
return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
} }
func (ms *fsMessageStore) Append(network *network, entity string, msg *irc.Message) (string, error) { func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
s := formatMessage(msg) s := formatMessage(msg)
if s == "" { if s == "" {
return "", nil return "", nil
@ -388,7 +388,7 @@ func parseMessage(line, entity string, ref time.Time, events bool) (*irc.Message
return msg, t, nil return msg, t, nil
} }
func (ms *fsMessageStore) parseMessagesBefore(network *network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref) path := ms.logPath(network, entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -444,7 +444,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *network, entity string, r
} }
} }
func (ms *fsMessageStore) parseMessagesAfter(network *network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref) path := ms.logPath(network, entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -476,7 +476,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *network, entity string, re
return history, nil return history, nil
} }
func (ms *fsMessageStore) LoadBeforeTime(network *network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadBeforeTime(network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
history := make([]*irc.Message, limit) history := make([]*irc.Message, limit)
@ -501,7 +501,7 @@ func (ms *fsMessageStore) LoadBeforeTime(network *network, entity string, start
return history[remaining:], nil return history[remaining:], nil
} }
func (ms *fsMessageStore) LoadAfterTime(network *network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadAfterTime(network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
var history []*irc.Message var history []*irc.Message
@ -525,7 +525,7 @@ func (ms *fsMessageStore) LoadAfterTime(network *network, entity string, start t
return history, nil return history, nil
} }
func (ms *fsMessageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) {
var afterTime time.Time var afterTime time.Time
var afterOffset int64 var afterOffset int64
if id != "" { if id != "" {
@ -569,7 +569,7 @@ func (ms *fsMessageStore) LoadLatestID(network *network, entity, id string, limi
return history[remaining:], nil return history[remaining:], nil
} }
func (ms *fsMessageStore) ListTargets(network *network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { func (ms *fsMessageStore) ListTargets(network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
@ -644,7 +644,7 @@ func (ms *fsMessageStore) ListTargets(network *network, start, end time.Time, li
return targets, nil return targets, nil
} }
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *network) error { func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
// Avoid loosing data by overwriting an existing directory // Avoid loosing data by overwriting an existing directory

View file

@ -54,7 +54,7 @@ func (ms *memoryMessageStore) Close() error {
return nil return nil
} }
func (ms *memoryMessageStore) get(network *network, entity string) *messageRingBuffer { func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
k := ringBufferKey{networkID: network.ID, entity: entity} k := ringBufferKey{networkID: network.ID, entity: entity}
if rb, ok := ms.buffers[k]; ok { if rb, ok := ms.buffers[k]; ok {
return rb return rb
@ -64,7 +64,7 @@ func (ms *memoryMessageStore) get(network *network, entity string) *messageRingB
return rb return rb
} }
func (ms *memoryMessageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) { func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
var seq uint64 var seq uint64
k := ringBufferKey{networkID: network.ID, entity: entity} k := ringBufferKey{networkID: network.ID, entity: entity}
if rb, ok := ms.buffers[k]; ok { if rb, ok := ms.buffers[k]; ok {
@ -73,7 +73,7 @@ func (ms *memoryMessageStore) LastMsgID(network *network, entity string, t time.
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) Append(network *network, entity string, msg *irc.Message) (string, error) { func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
switch msg.Command { switch msg.Command {
case "PRIVMSG", "NOTICE": case "PRIVMSG", "NOTICE":
default: default:
@ -91,7 +91,7 @@ func (ms *memoryMessageStore) Append(network *network, entity string, msg *irc.M
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *memoryMessageStore) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) {
_, _, seq, err := parseMemoryMsgID(id) _, _, seq, err := parseMemoryMsgID(id)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -1787,7 +1787,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
// This is the first message we receive from this target. Save the last // This is the first message we receive from this target. Save the last
// message ID in delivery receipts, so that we can send the new message // message ID in delivery receipts, so that we can send the new message
// in the backlog if an offline client reconnects. // in the backlog if an offline client reconnects.
lastID, err := uc.user.msgStore.LastMsgID(uc.network, entityCM, time.Now()) lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now())
if err != nil { if err != nil {
uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
return "" return ""
@ -1798,7 +1798,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
}) })
} }
msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg) msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg)
if err != nil { if err != nil {
uc.logger.Printf("failed to log message: %v", err) uc.logger.Printf("failed to log message: %v", err)
return "" return ""

View file

@ -261,7 +261,7 @@ func (net *network) detach(ch *Channel) {
if net.user.msgStore != nil { if net.user.msgStore != nil {
nameCM := net.casemap(ch.Name) nameCM := net.casemap(ch.Name)
lastID, err := net.user.msgStore.LastMsgID(net, nameCM, time.Now()) lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
if err != nil { if err != nil {
net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err) net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
} }
@ -859,7 +859,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
// is renamed // is renamed
fsMsgStore, isFS := u.msgStore.(*fsMessageStore) fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
if isFS && updatedNetwork.GetName() != network.GetName() { if isFS && updatedNetwork.GetName() != network.GetName() {
if err := fsMsgStore.RenameNetwork(network, updatedNetwork); err != nil { if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err) network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
} }
} }