From f508d36c38e24dbfb66774ce5edd549b9d1275d5 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 9 May 2022 15:36:39 +0200 Subject: [PATCH] msgstore: add loadMessageOptions A struct containing common parameters for all messageStore.Load* functions returning messages. --- downstream.go | 22 ++++++++++---- msgstore.go | 13 +++++++-- msgstore_fs.go | 73 ++++++++++++++++++++++++++-------------------- msgstore_memory.go | 10 +++++-- 4 files changed, 76 insertions(+), 42 deletions(-) diff --git a/downstream.go b/downstream.go index b818d4b..50e64ee 100644 --- a/downstream.go +++ b/downstream.go @@ -1664,7 +1664,12 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t defer cancel() targetCM := net.casemap(target) - history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit) + loadOptions := loadMessageOptions{ + Network: &net.Network, + Entity: targetCM, + Limit: backlogLimit, + } + history, err := dc.user.msgStore.LoadLatestID(ctx, msgID, &loadOptions) if err != nil { dc.logger.Printf("failed to send backlog for %q: %v", target, err) return @@ -2826,17 +2831,24 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. eventPlayback := dc.caps.IsEnabled("draft/event-playback") + options := loadMessageOptions{ + Network: &network.Network, + Entity: entity, + Limit: limit, + Events: eventPlayback, + } + var history []*irc.Message switch subcommand { case "BEFORE", "LATEST": - history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) + history, err = store.LoadBeforeTime(ctx, bounds[0], time.Time{}, &options) case "AFTER": - history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) + history, err = store.LoadAfterTime(ctx, bounds[0], time.Now(), &options) case "BETWEEN": if bounds[0].Before(bounds[1]) { - history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + history, err = store.LoadAfterTime(ctx, bounds[0], bounds[1], &options) } else { - history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + history, err = store.LoadBeforeTime(ctx, bounds[0], bounds[1], &options) } case "TARGETS": // TODO: support TARGETS in multi-upstream mode diff --git a/msgstore.go b/msgstore.go index d6c6379..0ae22fc 100644 --- a/msgstore.go +++ b/msgstore.go @@ -13,6 +13,13 @@ import ( "git.sr.ht/~emersion/soju/database" ) +type loadMessageOptions struct { + Network *database.Network + Entity string + Limit int + Events bool +} + // messageStore is a per-user store for IRC messages. type messageStore interface { Close() error @@ -22,7 +29,7 @@ type messageStore interface { LastMsgID(network *database.Network, entity string, t time.Time) (string, error) // LoadLatestID queries the latest non-event messages for the given network, // entity and date, up to a count of limit messages, sorted from oldest to newest. - LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) + LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) Append(network *database.Network, entity string, msg *irc.Message) (id string, err error) } @@ -45,12 +52,12 @@ type chatHistoryMessageStore interface { // returned messages must be between and excluding the provided bounds. // end is before start. // If events is false, only PRIVMSG/NOTICE messages are considered. - LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + LoadBeforeTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) // LoadBeforeTime loads up to limit messages after start up to end. The // returned messages must be between and excluding the provided bounds. // end is after start. // If events is false, only PRIVMSG/NOTICE messages are considered. - LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + LoadAfterTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) } type searchOptions struct { diff --git a/msgstore_fs.go b/msgstore_fs.go index 078a947..086fcaf 100644 --- a/msgstore_fs.go +++ b/msgstore_fs.go @@ -401,8 +401,8 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e return msg, t, nil } -func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { - path := ms.logPath(network, entity, ref) +func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *loadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { + path := ms.logPath(options.Network, options.Entity, ref) f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { @@ -412,7 +412,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity } defer f.Close() - historyRing := make([]*irc.Message, limit) + historyRing := make([]*irc.Message, options.Limit) cur := 0 sc := bufio.NewScanner(f) @@ -425,7 +425,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity } for sc.Scan() { - msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + msg, t, err := ms.parseMessage(sc.Text(), options.Network, options.Entity, ref, options.Events) if err != nil { return nil, err } else if msg == nil || !t.After(end) { @@ -437,20 +437,20 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity continue } - historyRing[cur%limit] = msg + historyRing[cur%options.Limit] = msg cur++ } if sc.Err() != nil { return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err()) } - n := limit - if cur < limit { + n := options.Limit + if cur < options.Limit { n = cur } - start := (cur - n + limit) % limit + start := (cur - n + options.Limit) % options.Limit - if start+n <= limit { // ring doesnt wrap + if start+n <= options.Limit { // ring doesnt wrap return historyRing[start : start+n], nil } else { // ring wraps history := make([]*irc.Message, n) @@ -460,8 +460,8 @@ func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity } } -func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) { - path := ms.logPath(network, entity, ref) +func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) { + path := ms.logPath(options.Network, options.Entity, ref) f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { @@ -473,8 +473,8 @@ func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity s var history []*irc.Message sc := bufio.NewScanner(f) - for sc.Scan() && len(history) < limit { - msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + for sc.Scan() && len(history) < options.Limit { + msg, t, err := ms.parseMessage(sc.Text(), options.Network, options.Entity, ref, options.Events) if err != nil { return nil, err } else if msg == nil || !t.After(ref) { @@ -495,18 +495,20 @@ func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity s return history, nil } -func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) { if start.IsZero() { start = time.Now() } else { start = start.In(time.Local) } end = end.In(time.Local) - messages := make([]*irc.Message, limit) - remaining := limit + messages := make([]*irc.Message, options.Limit) + remaining := options.Limit tries := 0 for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) { - buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1, selector) + parseOptions := *options + parseOptions.Limit = remaining + buf, err := ms.parseMessagesBefore(start, end, &parseOptions, -1, selector) if err != nil { return nil, err } @@ -528,11 +530,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.N return messages[remaining:], nil } -func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { - return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil) +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) { + return ms.getBeforeTime(ctx, start, end, options, nil) } -func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) { start = start.In(time.Local) if end.IsZero() { end = time.Now() @@ -540,10 +542,12 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Ne end = end.In(time.Local) } var messages []*irc.Message - remaining := limit + remaining := options.Limit tries := 0 for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) { - buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining, selector) + parseOptions := *options + parseOptions.Limit = remaining + buf, err := ms.parseMessagesAfter(start, end, &parseOptions, selector) if err != nil { return nil, err } @@ -564,11 +568,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Ne return messages, nil } -func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { - return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil) +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) { + return ms.getAfterTime(ctx, start, end, options, nil) } -func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) { var afterTime time.Time var afterOffset int64 if id != "" { @@ -579,14 +583,14 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Ne if err != nil { return nil, err } - if idNet != network.ID || idEntity != entity { + if idNet != options.Network.ID || idEntity != options.Entity { return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity") } } - history := make([]*irc.Message, limit) + history := make([]*irc.Message, options.Limit) t := time.Now() - remaining := limit + remaining := options.Limit tries := 0 for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) { var offset int64 = -1 @@ -594,7 +598,9 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Ne offset = afterOffset } - buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset, nil) + parseOptions := *options + parseOptions.Limit = remaining + buf, err := ms.parseMessagesBefore(t, time.Time{}, &parseOptions, offset, nil) if err != nil { return nil, err } @@ -706,10 +712,15 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, } return true } + loadOptions := loadMessageOptions{ + Network: network, + Entity: opts.in, + Limit: opts.limit, + } if !opts.start.IsZero() { - return ms.getAfterTime(ctx, network, opts.in, opts.start, opts.end, opts.limit, false, selector) + return ms.getAfterTime(ctx, opts.start, opts.end, &loadOptions, selector) } else { - return ms.getBeforeTime(ctx, network, opts.in, opts.end, opts.start, opts.limit, false, selector) + return ms.getBeforeTime(ctx, opts.end, opts.start, &loadOptions, selector) } } diff --git a/msgstore_memory.go b/msgstore_memory.go index 02158c0..c7d05e8 100644 --- a/msgstore_memory.go +++ b/msgstore_memory.go @@ -96,19 +96,23 @@ func (ms *memoryMessageStore) Append(network *database.Network, entity string, m return formatMemoryMsgID(network.ID, entity, seq), nil } -func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) { + if options.Events { + return nil, fmt.Errorf("events are unsupported for memory message store") + } + _, _, seq, err := parseMemoryMsgID(id) if err != nil { return nil, err } - k := ringBufferKey{networkID: network.ID, entity: entity} + k := ringBufferKey{networkID: options.Network.ID, entity: options.Entity} rb, ok := ms.buffers[k] if !ok { return nil, nil } - return rb.LoadLatestSeq(seq, limit) + return rb.LoadLatestSeq(seq, options.Limit) } type messageRingBuffer struct {