soju/msgstore_memory.go
2021-01-10 22:48:58 +01:00

145 lines
3.3 KiB
Go

package soju
import (
"fmt"
"strconv"
"time"
"gopkg.in/irc.v3"
)
// messageRingBufferCap is the number of message slots (4096) given to every
// ring buffer created by the in-memory message store.
const messageRingBufferCap = 1 << 12
// parseMemoryMsgID decodes a memory-store message ID into its network ID,
// entity and sequence number.
//
// The generic part of the ID is handled by parseMsgID; the remaining "extra"
// field must be a base-10 unsigned 64-bit sequence number. On failure the
// zero values are returned together with a non-nil error.
func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
	var extra string
	netID, entity, extra, err = parseMsgID(s)
	if err != nil {
		return 0, "", 0, err
	}

	seq, err = strconv.ParseUint(extra, 10, 64)
	if err != nil {
		// %w keeps the strconv error in the chain so callers can inspect it
		// with errors.Is / errors.As; the message text is unchanged.
		return 0, "", 0, fmt.Errorf("failed to parse message ID %q: %w", s, err)
	}

	return netID, entity, seq, nil
}
// formatMemoryMsgID builds a message ID for the in-memory store: the sequence
// number is rendered in base 10 and handed to formatMsgID as the extra field,
// alongside the network ID and entity.
func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
	return formatMsgID(netID, entity, strconv.FormatUint(seq, 10))
}
// ringBufferKey is the key of memoryMessageStore.buffers: it selects the ring
// buffer belonging to one entity on one network.
type ringBufferKey struct {
// networkID is taken from network.ID.
networkID int64
// entity is the entity string passed to the store methods, used as-is.
entity string
}
// memoryMessageStore keeps messages in memory, in one ring buffer per
// (network ID, entity) pair.
type memoryMessageStore struct {
// buffers maps each (network ID, entity) pair to its ring buffer.
// Close sets it to nil.
buffers map[ringBufferKey]*messageRingBuffer
}
// newMemoryMessageStore returns a store with an empty, ready-to-use buffer map.
func newMemoryMessageStore() *memoryMessageStore {
	ms := new(memoryMessageStore)
	ms.buffers = map[ringBufferKey]*messageRingBuffer{}
	return ms
}
// Close drops the store's reference to all ring buffers. It always returns nil.
//
// NOTE(review): buffers is a nil map afterwards, so a later get or Append that
// needs to create a buffer would panic on the map write.
func (ms *memoryMessageStore) Close() error {
ms.buffers = nil
return nil
}
// get returns the ring buffer for entity on network, creating and registering
// a new one of messageRingBufferCap slots when none exists yet.
func (ms *memoryMessageStore) get(network *network, entity string) *messageRingBuffer {
	key := ringBufferKey{networkID: network.ID, entity: entity}

	buffer, found := ms.buffers[key]
	if !found {
		buffer = newMessageRingBuffer(messageRingBufferCap)
		ms.buffers[key] = buffer
	}
	return buffer
}
// LastMsgID returns the ID of the most recent message stored for entity on
// network. Passing that ID to LoadLatestID yields only messages appended
// afterwards.
//
// t is ignored by this implementation. When no buffer exists for the entity,
// or nothing has been appended to it, the ID carries sequence number 0, which
// precedes the first real sequence number (1). The returned error is always
// nil.
func (ms *memoryMessageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) {
	var seq uint64
	k := ringBufferKey{networkID: network.ID, entity: entity}
	if rb, ok := ms.buffers[k]; ok && rb.cur > 0 {
		// rb.cur is the sequence number the *next* appended message will
		// receive, so the last stored message is rb.cur - 1. Returning rb.cur
		// itself would make LoadLatestSeq, which excludes the sequence number
		// it is given, skip the first message appended after this call.
		seq = rb.cur - 1
	}
	return formatMemoryMsgID(network.ID, entity, seq), nil
}
// Append stores msg in the ring buffer for entity on network, creating the
// buffer on first use, and returns the ID assigned to the message. The
// returned error is always nil.
func (ms *memoryMessageStore) Append(network *network, entity string, msg *irc.Message) (string, error) {
	// get performs the same lookup-or-create step this method used to inline.
	seq := ms.get(network, entity).Append(msg)
	return formatMemoryMsgID(network.ID, entity, seq), nil
}
// LoadLatestID returns at most limit of the latest messages stored for entity
// on network that come after the message identified by id.
//
// Only the sequence number encoded in id is used; the network ID and entity
// embedded in it are discarded. An unparsable id yields an error. If no
// buffer exists for the entity, nil is returned without an error.
func (ms *memoryMessageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) {
	_, _, seq, err := parseMemoryMsgID(id)
	if err != nil {
		return nil, err
	}

	key := ringBufferKey{networkID: network.ID, entity: entity}
	if buffer, found := ms.buffers[key]; found {
		return buffer.LoadLatestSeq(seq, limit)
	}
	return nil, nil
}
// messageRingBuffer is a fixed-size circular buffer of messages addressed by
// an ever-increasing sequence number.
type messageRingBuffer struct {
// buf is the backing storage; the message with sequence number s lives at
// index s % len(buf), so older entries are overwritten once the buffer wraps.
buf []*irc.Message
// cur is the sequence number the next appended message will receive.
// newMessageRingBuffer initialises it to 1.
cur uint64
}
// newMessageRingBuffer allocates a ring buffer with capacity slots. Sequence
// numbers start at 1.
func newMessageRingBuffer(capacity int) *messageRingBuffer {
	rb := new(messageRingBuffer)
	rb.buf = make([]*irc.Message, capacity)
	rb.cur = 1
	return rb
}
// cap reports the number of slots in the buffer as a uint64, ready to be mixed
// with sequence-number arithmetic.
func (rb *messageRingBuffer) cap() uint64 {
	slots := len(rb.buf)
	return uint64(slots)
}
// Append stores msg under the current sequence number, overwriting whatever
// occupied that slot, advances the counter and returns the sequence number
// that was assigned to msg.
func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
	assigned := rb.cur
	slot := int(assigned % rb.cap())
	rb.buf[slot], rb.cur = msg, assigned+1
	return assigned
}
// LoadLatestSeq returns the latest messages whose sequence number is strictly
// greater than seq, oldest first.
//
// At most limit messages are returned; when more are available, the most
// recent ones are kept. A negative limit is treated as 0. Messages that have
// already been overwritten in the ring are silently omitted.
//
// It returns an error when seq is greater than the current sequence number,
// and (nil, nil) when seq equals it.
func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
	if seq > rb.cur {
		return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
	}
	if seq == rb.cur {
		return nil, nil
	}
	if limit < 0 {
		// A negative limit would otherwise be converted to a huge uint64
		// below and make the allocation panic.
		limit = 0
	}

	size := rb.cap()

	// The query excludes the message with the sequence number seq
	count := rb.cur - seq - 1
	if count > size {
		// We dropped count - size entries
		count = size
	}
	if count > uint64(limit) {
		count = uint64(limit)
	}

	msgs := make([]*irc.Message, count)
	first := rb.cur - count // sequence number of the oldest message returned
	for i := range msgs {
		msgs[i] = rb.buf[(first+uint64(i))%size]
	}
	return msgs, nil
}