From ce69f00e3fe5c2eb6b4b10880a3c11ab2754bb2f Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 3 Nov 2021 18:18:04 +0100 Subject: [PATCH] msgstore: add context to messageStore methods This allows setting a hard timeout. --- downstream.go | 18 ++++++++++++------ msgstore.go | 9 +++++---- msgstore_fs.go | 9 +++++---- msgstore_memory.go | 3 ++- server.go | 1 + 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/downstream.go b/downstream.go index 6831d6c..569831f 100644 --- a/downstream.go +++ b/downstream.go @@ -1328,9 +1328,12 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target, msgID string) ch := net.channels.Value(target) + ctx, cancel := context.WithTimeout(context.TODO(), messageStoreTimeout) + defer cancel() + limit := 4000 targetCM := net.casemap(target) - history, err := dc.user.msgStore.LoadLatestID(&net.Network, targetCM, msgID, limit) + history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, limit) if err != nil { dc.logger.Printf("failed to send backlog for %q: %v", target, err) return @@ -2334,21 +2337,24 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { eventPlayback := dc.caps["draft/event-playback"] + ctx, cancel := context.WithTimeout(context.TODO(), messageStoreTimeout) + defer cancel() + var history []*irc.Message switch subcommand { case "BEFORE": - history, err = store.LoadBeforeTime(&network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) case "AFTER": - history, err = store.LoadAfterTime(&network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) case "BETWEEN": if bounds[0].Before(bounds[1]) { - history, err = store.LoadAfterTime(&network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) } else { - history, err = store.LoadBeforeTime(&network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) } case "TARGETS": // TODO: support TARGETS in multi-upstream mode - targets, err := store.ListTargets(&network.Network, bounds[0], bounds[1], limit, eventPlayback) + targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback) if err != nil { dc.logger.Printf("failed fetching targets for chathistory: %v", err) return ircError{&irc.Message{ diff --git a/msgstore.go b/msgstore.go index 79f8e9e..deb63b4 100644 --- a/msgstore.go +++ b/msgstore.go @@ -2,6 +2,7 @@ package soju import ( "bytes" + "context" "encoding/base64" "fmt" "time" @@ -19,7 +20,7 @@ type messageStore interface { LastMsgID(network *Network, entity string, t time.Time) (string, error) // LoadLatestID queries the latest non-event messages for the given network, // entity and date, up to a count of limit messages, sorted from oldest to newest. - LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) + LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) Append(network *Network, entity string, msg *irc.Message) (id string, err error) } @@ -37,17 +38,17 @@ type chatHistoryMessageStore interface { // It returns up to limit targets, starting from start and ending on end, // both excluded. end may be before or after start. // If events is false, only PRIVMSG/NOTICE messages are considered. - ListTargets(network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) + ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) // LoadBeforeTime loads up to limit messages before start down to end. The // returned messages must be between and excluding the provided bounds. // end is before start. // If events is false, only PRIVMSG/NOTICE messages are considered. - LoadBeforeTime(network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) // LoadBeforeTime loads up to limit messages after start up to end. The // returned messages must be between and excluding the provided bounds. // end is after start. // If events is false, only PRIVMSG/NOTICE messages are considered. - LoadAfterTime(network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) } type msgIDType uint diff --git a/msgstore_fs.go b/msgstore_fs.go index 58603c5..c6547b6 100644 --- a/msgstore_fs.go +++ b/msgstore_fs.go @@ -2,6 +2,7 @@ package soju import ( "bufio" + "context" "fmt" "io" "os" @@ -476,7 +477,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re return history, nil } -func (ms *fsMessageStore) LoadBeforeTime(network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { start = start.In(time.Local) end = end.In(time.Local) history := make([]*irc.Message, limit) @@ -501,7 +502,7 @@ func (ms *fsMessageStore) LoadBeforeTime(network *Network, entity string, start return history[remaining:], nil } -func (ms *fsMessageStore) LoadAfterTime(network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { start = start.In(time.Local) end = end.In(time.Local) var history []*irc.Message @@ -525,7 +526,7 @@ func (ms *fsMessageStore) LoadAfterTime(network *Network, entity string, start t return history, nil } -func (ms *fsMessageStore) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { var afterTime time.Time var afterOffset int64 if id != "" { @@ -569,7 +570,7 @@ func (ms *fsMessageStore) LoadLatestID(network *Network, entity, id string, limi return history[remaining:], nil } -func (ms *fsMessageStore) ListTargets(network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { +func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { start = start.In(time.Local) end = end.In(time.Local) rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) diff --git a/msgstore_memory.go b/msgstore_memory.go index 73cc42a..677a684 100644 --- a/msgstore_memory.go +++ b/msgstore_memory.go @@ -1,6 +1,7 @@ package soju import ( + "context" "fmt" "time" @@ -91,7 +92,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M return formatMemoryMsgID(network.ID, entity, seq), nil } -func (ms *memoryMessageStore) LoadLatestID(network *Network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { _, _, seq, err := parseMemoryMsgID(id) if err != nil { return nil, err diff --git a/server.go b/server.go index 5d5562c..bb51dfd 100644 --- a/server.go +++ b/server.go @@ -25,6 +25,7 @@ var connectTimeout = 15 * time.Second var writeTimeout = 10 * time.Second var upstreamMessageDelay = 2 * time.Second var upstreamMessageBurst = 10 +var messageStoreTimeout = 10 * time.Second type Logger interface { Print(v ...interface{})