soju/msgstore_memory.go
2022-05-09 15:08:04 +02:00

163 lines
3.6 KiB
Go

package soju
import (
"context"
"fmt"
"time"
"git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
)
// messageRingBufferCap is the number of message slots in each per-(network,
// entity) ring buffer. Once a buffer has received this many messages, each
// new message overwrites the oldest one.
const messageRingBufferCap = 4096
// memoryMsgID is the store-specific part of a message ID issued by the
// in-memory store: the ring buffer sequence number of the message.
// NOTE(review): it is passed to parseMsgID/formatMsgID and presumably
// BARE-encoded, so field order is part of the wire format — confirm before
// adding or reordering fields.
type memoryMsgID struct {
	Seq bare.Uint
}
// msgIDType tags memoryMsgID payloads as belonging to the in-memory message
// store.
func (memoryMsgID) msgIDType() msgIDType {
	var kind msgIDType = msgIDMemory
	return kind
}
// parseMemoryMsgID decodes a message ID created by formatMemoryMsgID and
// returns the network ID, entity and ring buffer sequence number it carries.
// On failure all values are zeroed and the parse error is returned.
func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
	var decoded memoryMsgID
	if netID, entity, err = parseMsgID(s, &decoded); err != nil {
		return 0, "", 0, err
	}
	seq = uint64(decoded.Seq)
	return netID, entity, seq, nil
}
// formatMemoryMsgID builds the message ID string that identifies the message
// with sequence number seq in the ring buffer of (netID, entity).
func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
	var payload memoryMsgID
	payload.Seq = bare.Uint(seq)
	return formatMsgID(netID, entity, &payload)
}
// ringBufferKey identifies one ring buffer: messages are kept separately per
// network and per entity within that network.
type ringBufferKey struct {
	networkID int64
	entity    string
}
// memoryMessageStore is a messageStore that keeps a bounded number of recent
// messages per (network, entity) in memory only.
// NOTE(review): no locking is done here; presumably callers serialize access
// — confirm.
type memoryMessageStore struct {
	buffers map[ringBufferKey]*messageRingBuffer
}
// Compile-time check that *memoryMessageStore implements messageStore.
var _ messageStore = (*memoryMessageStore)(nil)
// newMemoryMessageStore returns an empty in-memory message store. Ring
// buffers are created lazily, one per (network, entity).
func newMemoryMessageStore() *memoryMessageStore {
	ms := new(memoryMessageStore)
	ms.buffers = map[ringBufferKey]*messageRingBuffer{}
	return ms
}
// Close drops every ring buffer held by the store. It always returns nil.
// The buffer map is set to nil, so storing into the store afterwards (which
// writes to that map) would panic; the store must not be reused.
func (ms *memoryMessageStore) Close() error {
	ms.buffers = nil
	return nil
}
// get returns the ring buffer for (network, entity). If none exists yet, an
// empty buffer of messageRingBufferCap slots is created and registered.
func (ms *memoryMessageStore) get(network *database.Network, entity string) *messageRingBuffer {
	key := ringBufferKey{networkID: network.ID, entity: entity}
	rb, found := ms.buffers[key]
	if !found {
		rb = newMessageRingBuffer(messageRingBufferCap)
		ms.buffers[key] = rb
	}
	return rb
}
// LastMsgID returns the ID of the most recent message stored for
// (network, entity). Passing this ID to LoadLatestID returns only messages
// appended after this call. The time t is not used by this implementation.
// It never returns an error.
func (ms *memoryMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
	// Sequence numbers start at 1, so 0 stands for "no message yet". This
	// matches both a missing buffer and a freshly created empty buffer.
	var seq uint64
	k := ringBufferKey{networkID: network.ID, entity: entity}
	if rb, ok := ms.buffers[k]; ok {
		// rb.cur is the sequence number the *next* message will receive; the
		// last stored message has rb.cur-1. Returning rb.cur would make
		// LoadLatestID (which excludes the given seq) skip the next message.
		seq = rb.cur - 1
	}
	return formatMemoryMsgID(network.ID, entity, seq), nil
}
// Append stores msg in the ring buffer for (network, entity) and returns the
// message ID assigned to it. Only PRIVMSG and NOTICE messages are stored; for
// any other command nothing is stored and ("", nil) is returned.
func (ms *memoryMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
	switch msg.Command {
	case "PRIVMSG", "NOTICE":
		// Only append these messages, because LoadLatestID shouldn't return
		// other kinds of message.
	default:
		return "", nil
	}

	// Reuse get's lookup-or-create logic instead of duplicating it here.
	seq := ms.get(network, entity).Append(msg)
	return formatMemoryMsgID(network.ID, entity, seq), nil
}
// LoadLatestID returns up to limit of the most recent messages stored for
// (network, entity) whose sequence number is greater than the one encoded in
// id, oldest first. It returns (nil, nil) when no buffer exists for the pair.
// The network ID and entity embedded in id are not compared against the
// arguments, and ctx is not used.
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
	_, _, after, err := parseMemoryMsgID(id)
	if err != nil {
		return nil, err
	}

	rb, exists := ms.buffers[ringBufferKey{networkID: network.ID, entity: entity}]
	if exists {
		return rb.LoadLatestSeq(after, limit)
	}
	return nil, nil
}
// messageRingBuffer is a fixed-size circular buffer of messages. The message
// with sequence number seq lives in slot seq % len(buf); once the buffer is
// full, each append overwrites the oldest message.
type messageRingBuffer struct {
	buf []*irc.Message
	// cur is the sequence number the next appended message will receive.
	// It starts at 1 and only ever increases.
	cur uint64
}
// newMessageRingBuffer returns an empty ring buffer with the given number of
// slots. The first message appended to it receives sequence number 1.
// capacity must be positive: Append reduces sequence numbers modulo it.
func newMessageRingBuffer(capacity int) *messageRingBuffer {
	rb := &messageRingBuffer{cur: 1}
	rb.buf = make([]*irc.Message, capacity)
	return rb
}
// cap reports the number of slots in the ring buffer.
func (rb *messageRingBuffer) cap() uint64 {
	slots := len(rb.buf)
	return uint64(slots)
}
// Append stores msg in the slot for the next sequence number, overwriting
// whatever message previously occupied that slot. It returns the sequence
// number assigned to msg.
func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
	assigned := rb.cur
	rb.buf[assigned%rb.cap()] = msg
	rb.cur = assigned + 1
	return assigned
}
// LoadLatestSeq returns, oldest first, up to limit of the most recent
// messages whose sequence number is strictly greater than seq. Messages that
// have already been overwritten are not returned.
//
// It returns (nil, nil) when seq equals the next sequence number to be
// assigned. It returns an error when seq is beyond that, or when limit is
// negative.
func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
	if seq > rb.cur {
		return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
	} else if seq == rb.cur {
		return nil, nil
	}
	if limit < 0 {
		// uint64(limit) below would wrap around to a huge length and make
		// the allocation panic.
		return nil, fmt.Errorf("loading messages with negative limit (%v)", limit)
	}

	// The query excludes the message with the sequence number seq, so the
	// candidates are seq+1 .. rb.cur-1.
	n := rb.cur - seq - 1
	if n > rb.cap() {
		// Older entries have been overwritten; only the last cap() remain.
		n = rb.cap()
	}
	if n > uint64(limit) {
		n = uint64(limit)
	}

	// Copy the n newest messages, i.e. sequence numbers rb.cur-n .. rb.cur-1.
	l := make([]*irc.Message, n)
	first := rb.cur - n
	for i := range l {
		l[i] = rb.buf[(first+uint64(i))%rb.cap()]
	}
	return l, nil
}