diff --git a/downstream.go b/downstream.go index d4e9564..8643fcd 100644 --- a/downstream.go +++ b/downstream.go @@ -1735,7 +1735,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { }} } - if dc.user.msgStore == nil { + store, ok := dc.user.msgStore.(chatHistoryMessageStore) + if !ok { return ircError{&irc.Message{ Command: irc.ERR_UNKNOWNCOMMAND, Params: []string{dc.nick, subcommand, "Unknown command"}, @@ -1775,9 +1776,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { var history []*irc.Message switch subcommand { case "BEFORE": - history, err = dc.user.msgStore.LoadBeforeTime(uc.network, entity, timestamp, limit) + history, err = store.LoadBeforeTime(uc.network, entity, timestamp, limit) case "AFTER": - history, err = dc.user.msgStore.LoadAfterTime(uc.network, entity, timestamp, limit) + history, err = store.LoadAfterTime(uc.network, entity, timestamp, limit) default: // TODO: support LATEST, BETWEEN return ircError{&irc.Message{ diff --git a/msgstore.go b/msgstore.go index 2be4f18..18cd0c9 100644 --- a/msgstore.go +++ b/msgstore.go @@ -16,12 +16,19 @@ type messageStore interface { // date. The message ID returned may not refer to a valid message, but can be // used in history queries. LastMsgID(network *network, entity string, t time.Time) (string, error) - LoadBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) - LoadAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) Append(network *network, entity string, msg *irc.Message) (id string, err error) } +// chatHistoryMessageStore is a message store that supports chat history +// operations. +type chatHistoryMessageStore interface { + messageStore + + LoadBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) + LoadAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) +} + func formatMsgID(netID int64, entity, extra string) string { return fmt.Sprintf("%v %v %v", netID, entity, extra) }