Add msgstore package

This commit is contained in:
Simon Ser 2022-05-09 16:25:57 +02:00
parent b92afa7cca
commit 620a8789b0
5 changed files with 75 additions and 60 deletions

View file

@ -18,6 +18,7 @@ import (
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore"
"git.sr.ht/~emersion/soju/xirc" "git.sr.ht/~emersion/soju/xirc"
) )
@ -650,7 +651,7 @@ func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
// ackMsgID acknowledges that a message has been received. // ackMsgID acknowledges that a message has been received.
func (dc *downstreamConn) ackMsgID(id string) { func (dc *downstreamConn) ackMsgID(id string) {
netID, entity, err := parseMsgID(id, nil) netID, entity, err := msgstore.ParseMsgID(id, nil)
if err != nil { if err != nil {
dc.logger.Printf("failed to ACK message ID %q: %v", id, err) dc.logger.Printf("failed to ACK message ID %q: %v", id, err)
return return
@ -1137,7 +1138,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
dc.unsetSupportedCap("draft/account-registration") dc.unsetSupportedCap("draft/account-registration")
} }
if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { if _, ok := dc.user.msgStore.(msgstore.ChatHistoryStore); ok && dc.network != nil {
dc.setSupportedCap("draft/event-playback", "") dc.setSupportedCap("draft/event-playback", "")
} else { } else {
dc.unsetSupportedCap("draft/event-playback") dc.unsetSupportedCap("draft/event-playback")
@ -1665,7 +1666,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
defer cancel() defer cancel()
targetCM := net.casemap(target) targetCM := net.casemap(target)
loadOptions := loadMessageOptions{ loadOptions := msgstore.LoadMessageOptions{
Network: &net.Network, Network: &net.Network,
Entity: targetCM, Entity: targetCM,
Limit: backlogLimit, Limit: backlogLimit,
@ -2786,7 +2787,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil return nil
} }
store, ok := dc.user.msgStore.(chatHistoryMessageStore) store, ok := dc.user.msgStore.(msgstore.ChatHistoryStore)
if !ok { if !ok {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND, Command: irc.ERR_UNKNOWNCOMMAND,
@ -2832,7 +2833,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
eventPlayback := dc.caps.IsEnabled("draft/event-playback") eventPlayback := dc.caps.IsEnabled("draft/event-playback")
options := loadMessageOptions{ options := msgstore.LoadMessageOptions{
Network: &network.Network, Network: &network.Network,
Entity: entity, Entity: entity,
Limit: limit, Limit: limit,
@ -2980,7 +2981,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
}) })
case "SEARCH": case "SEARCH":
store, ok := dc.user.msgStore.(searchMessageStore) store, ok := dc.user.msgStore.(msgstore.SearchStore)
if !ok { if !ok {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND, Command: irc.ERR_UNKNOWNCOMMAND,
@ -2995,7 +2996,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
var uc *upstreamConn var uc *upstreamConn
const searchMaxLimit = 100 const searchMaxLimit = 100
opts := searchMessageOptions{ opts := msgstore.SearchMessageOptions{
Limit: searchMaxLimit, Limit: searchMaxLimit,
} }
for name, v := range attrs { for name, v := range attrs {

View file

@ -1,4 +1,4 @@
package soju package msgstore
import ( import (
"bufio" "bufio"
@ -57,7 +57,7 @@ func (fsMsgID) msgIDType() msgIDType {
func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) { func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) {
var id fsMsgID var id fsMsgID
netID, entity, err = parseMsgID(s, &id) netID, entity, err = ParseMsgID(s, &id)
if err != nil { if err != nil {
return 0, "", time.Time{}, 0, err return 0, "", time.Time{}, 0, err
} }
@ -89,11 +89,14 @@ type fsMessageStore struct {
files map[string]*fsMessageStoreFile // indexed by entity files map[string]*fsMessageStoreFile // indexed by entity
} }
var _ messageStore = (*fsMessageStore)(nil) var (
var _ chatHistoryMessageStore = (*fsMessageStore)(nil) _ Store = (*fsMessageStore)(nil)
var _ searchMessageStore = (*fsMessageStore)(nil) _ ChatHistoryStore = (*fsMessageStore)(nil)
_ SearchStore = (*fsMessageStore)(nil)
_ RenameNetworkStore = (*fsMessageStore)(nil)
)
func newFSMessageStore(root string, user *database.User) *fsMessageStore { func NewFSStore(root string, user *database.User) *fsMessageStore {
return &fsMessageStore{ return &fsMessageStore{
root: filepath.Join(root, escapeFilename(user.Username)), root: filepath.Join(root, escapeFilename(user.Username)),
user: user, user: user,
@ -402,7 +405,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *database.Network, e
return msg, t, nil return msg, t, nil
} }
func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *loadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, options *LoadMessageOptions, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(options.Network, options.Entity, ref) path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -461,7 +464,7 @@ func (ms *fsMessageStore) parseMessagesBefore(ref time.Time, end time.Time, opti
} }
} }
func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, options *LoadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(options.Network, options.Entity, ref) path := ms.logPath(options.Network, options.Entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -496,7 +499,7 @@ func (ms *fsMessageStore) parseMessagesAfter(ref time.Time, end time.Time, optio
return history, nil return history, nil
} }
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
if start.IsZero() { if start.IsZero() {
start = time.Now() start = time.Now()
} else { } else {
@ -531,11 +534,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, start time.Time, en
return messages[remaining:], nil return messages[remaining:], nil
} }
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, start, end, options, nil) return ms.getBeforeTime(ctx, start, end, options, nil)
} }
func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
if end.IsZero() { if end.IsZero() {
end = time.Now() end = time.Now()
@ -569,11 +572,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, start time.Time, end
return messages, nil return messages, nil
} }
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, start time.Time, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, start, end, options, nil) return ms.getAfterTime(ctx, start, end, options, nil)
} }
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error) {
var afterTime time.Time var afterTime time.Time
var afterOffset int64 var afterOffset int64
if id != "" { if id != "" {
@ -623,7 +626,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, id string, options *
return history[remaining:], nil return history[remaining:], nil
} }
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
@ -642,7 +645,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
return nil, err return nil, err
} }
var targets []chatHistoryTarget var targets []ChatHistoryTarget
for _, target := range targetNames { for _, target := range targetNames {
// target is already escaped here // target is already escaped here
targetPath := filepath.Join(rootPath, target) targetPath := filepath.Join(rootPath, target)
@ -673,7 +676,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
continue continue
} }
targets = append(targets, chatHistoryTarget{ targets = append(targets, ChatHistoryTarget{
Name: target, Name: target,
LatestMessage: t, LatestMessage: t,
}) })
@ -702,7 +705,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Net
return targets, nil return targets, nil
} }
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *searchMessageOptions) ([]*irc.Message, error) { func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts *SearchMessageOptions) ([]*irc.Message, error) {
text := strings.ToLower(opts.Text) text := strings.ToLower(opts.Text)
selector := func(m *irc.Message) bool { selector := func(m *irc.Message) bool {
if opts.From != "" && m.User != opts.From { if opts.From != "" && m.User != opts.From {
@ -713,7 +716,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network,
} }
return true return true
} }
loadOptions := loadMessageOptions{ loadOptions := LoadMessageOptions{
Network: network, Network: network,
Entity: opts.In, Entity: opts.In,
Limit: opts.Limit, Limit: opts.Limit,

View file

@ -1,4 +1,4 @@
package soju package msgstore
import ( import (
"context" "context"
@ -23,7 +23,7 @@ func (memoryMsgID) msgIDType() msgIDType {
func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) { func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
var id memoryMsgID var id memoryMsgID
netID, entity, err = parseMsgID(s, &id) netID, entity, err = ParseMsgID(s, &id)
if err != nil { if err != nil {
return 0, "", 0, err return 0, "", 0, err
} }
@ -40,13 +40,18 @@ type ringBufferKey struct {
entity string entity string
} }
func IsMemoryStore(store Store) bool {
_, ok := store.(*memoryMessageStore)
return ok
}
type memoryMessageStore struct { type memoryMessageStore struct {
buffers map[ringBufferKey]*messageRingBuffer buffers map[ringBufferKey]*messageRingBuffer
} }
var _ messageStore = (*memoryMessageStore)(nil) var _ Store = (*memoryMessageStore)(nil)
func newMemoryMessageStore() *memoryMessageStore { func NewMemoryStore() *memoryMessageStore {
return &memoryMessageStore{ return &memoryMessageStore{
buffers: make(map[ringBufferKey]*messageRingBuffer), buffers: make(map[ringBufferKey]*messageRingBuffer),
} }
@ -96,7 +101,7 @@ func (ms *memoryMessageStore) Append(network *database.Network, entity string, m
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) { func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error) {
if options.Events { if options.Events {
return nil, fmt.Errorf("events are unsupported for memory message store") return nil, fmt.Errorf("events are unsupported for memory message store")
} }

View file

@ -1,4 +1,4 @@
package soju package msgstore
import ( import (
"bytes" "bytes"
@ -13,15 +13,15 @@ import (
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
) )
type loadMessageOptions struct { type LoadMessageOptions struct {
Network *database.Network Network *database.Network
Entity string Entity string
Limit int Limit int
Events bool Events bool
} }
// messageStore is a per-user store for IRC messages. // Store is a per-user store for IRC messages.
type messageStore interface { type Store interface {
Close() error Close() error
// LastMsgID queries the last message ID for the given network, entity and // LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be // date. The message ID returned may not refer to a valid message, but can be
@ -29,38 +29,37 @@ type messageStore interface {
LastMsgID(network *database.Network, entity string, t time.Time) (string, error) LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network, // LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest. // entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(ctx context.Context, id string, options *loadMessageOptions) ([]*irc.Message, error) LoadLatestID(ctx context.Context, id string, options *LoadMessageOptions) ([]*irc.Message, error)
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error) Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
} }
type chatHistoryTarget struct { type ChatHistoryTarget struct {
Name string Name string
LatestMessage time.Time LatestMessage time.Time
} }
// chatHistoryMessageStore is a message store that supports chat history // ChatHistoryStore is a message store that supports chat history operations.
// operations. type ChatHistoryStore interface {
type chatHistoryMessageStore interface { Store
messageStore
// ListTargets lists channels and nicknames by time of the latest message. // ListTargets lists channels and nicknames by time of the latest message.
// It returns up to limit targets, starting from start and ending on end, // It returns up to limit targets, starting from start and ending on end,
// both excluded. end may be before or after start. // both excluded. end may be before or after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error)
// LoadBeforeTime loads up to limit messages before start down to end. The // LoadBeforeTime loads up to limit messages before start down to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is before start. // end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) LoadBeforeTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The // LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is after start. // end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(ctx context.Context, start, end time.Time, options *loadMessageOptions) ([]*irc.Message, error) LoadAfterTime(ctx context.Context, start, end time.Time, options *LoadMessageOptions) ([]*irc.Message, error)
} }
type searchMessageOptions struct { type SearchMessageOptions struct {
Start time.Time Start time.Time
End time.Time End time.Time
Limit int Limit int
@ -69,13 +68,20 @@ type searchMessageOptions struct {
Text string Text string
} }
// searchMessageStore is a message store that supports server-side search // SearchStore is a message store that supports server-side search operations.
// operations. type SearchStore interface {
type searchMessageStore interface { Store
messageStore
// Search returns messages matching the specified options. // Search returns messages matching the specified options.
Search(ctx context.Context, network *database.Network, options *searchMessageOptions) ([]*irc.Message, error) Search(ctx context.Context, network *database.Network, options *SearchMessageOptions) ([]*irc.Message, error)
}
// RenameNetworkStore is a message store which needs to be notified of network
// name changes.
type RenameNetworkStore interface {
Store
RenameNetwork(oldNet, newNet *database.Network) error
} }
type msgIDType uint type msgIDType uint
@ -118,7 +124,7 @@ func formatMsgID(netID int64, target string, body msgIDBody) string {
return base64.RawURLEncoding.EncodeToString(buf.Bytes()) return base64.RawURLEncoding.EncodeToString(buf.Bytes())
} }
func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) { func ParseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
b, err := base64.RawURLEncoding.DecodeString(s) b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil { if err != nil {
return 0, "", fmt.Errorf("invalid internal message ID: %v", err) return 0, "", fmt.Errorf("invalid internal message ID: %v", err)

20
user.go
View file

@ -16,6 +16,7 @@ import (
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/msgstore"
) )
type event interface{} type event interface{}
@ -454,17 +455,17 @@ type user struct {
networks []*network networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
msgStore messageStore msgStore msgstore.Store
} }
func newUser(srv *Server, record *database.User) *user { func newUser(srv *Server, record *database.User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore messageStore var msgStore msgstore.Store
if logPath := srv.Config().LogPath; logPath != "" { if logPath := srv.Config().LogPath; logPath != "" {
msgStore = newFSMessageStore(logPath, record) msgStore = msgstore.NewFSStore(logPath, record)
} else { } else {
msgStore = newMemoryMessageStore() msgStore = msgstore.NewMemoryStore()
} }
return &user{ return &user{
@ -951,10 +952,10 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne
// The filesystem message store needs to be notified whenever the network // The filesystem message store needs to be notified whenever the network
// is renamed // is renamed
fsMsgStore, isFS := u.msgStore.(*fsMessageStore) renameNetMsgStore, ok := u.msgStore.(msgstore.RenameNetworkStore)
if isFS && updatedNetwork.GetName() != network.GetName() { if ok && updatedNetwork.GetName() != network.GetName() {
if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil { if err := renameNetMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err) network.logger.Printf("failed to update message store network name to %q: %v", updatedNetwork.GetName(), err)
} }
} }
@ -1049,8 +1050,7 @@ func (u *user) hasPersistentMsgStore() bool {
if u.msgStore == nil { if u.msgStore == nil {
return false return false
} }
_, isMem := u.msgStore.(*memoryMessageStore) return !msgstore.IsMemoryStore(u.msgStore)
return !isMem
} }
// localAddrForHost returns the local address to use when connecting to host. // localAddrForHost returns the local address to use when connecting to host.