diff --git a/downstream.go b/downstream.go index 26069cf..6352f23 100644 --- a/downstream.go +++ b/downstream.go @@ -863,7 +863,7 @@ func (dc *downstreamConn) welcome() error { continue } - lastID, err := lastMsgID(net, target, time.Now()) + lastID, err := dc.user.msgStore.LastMsgID(net, target, time.Now()) if err != nil { dc.logger.Printf("failed to get last message ID: %v", err) continue @@ -876,7 +876,7 @@ func (dc *downstreamConn) welcome() error { } func (dc *downstreamConn) sendNetworkHistory(net *network) { - if dc.caps["draft/chathistory"] || dc.srv.LogPath == "" { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { return } for target, history := range net.history { @@ -890,7 +890,7 @@ func (dc *downstreamConn) sendNetworkHistory(net *network) { } limit := 4000 - history, err := loadHistoryLatestID(net, target, lastDelivered, limit) + history, err := dc.user.msgStore.LoadLatestID(net, target, lastDelivered, limit) if err != nil { dc.logger.Printf("failed to send implicit history for %q: %v", target, err) continue @@ -1601,7 +1601,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { }} } - if dc.srv.LogPath == "" { + if dc.user.msgStore == nil { return ircError{&irc.Message{ Command: irc.ERR_UNKNOWNCOMMAND, Params: []string{dc.nick, subcommand, "Unknown command"}, @@ -1641,9 +1641,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { var history []*irc.Message switch subcommand { case "BEFORE": - history, err = loadHistoryBeforeTime(uc.network, entity, timestamp, limit) + history, err = dc.user.msgStore.LoadBeforeTime(uc.network, entity, timestamp, limit) case "AFTER": - history, err = loadHistoryAfterTime(uc.network, entity, timestamp, limit) + history, err = dc.user.msgStore.LoadAfterTime(uc.network, entity, timestamp, limit) default: // TODO: support LATEST, BETWEEN return ircError{&irc.Message{ diff --git a/logger.go b/msgstore.go similarity index 74% rename from logger.go rename to msgstore.go index 86381a3..e0bb779 100644 --- a/logger.go +++ b/msgstore.go @@ -12,32 +12,28 @@ import ( "gopkg.in/irc.v3" ) -const messageLoggerMaxTries = 100 - -type messageLogger struct { - network *network - entity string - - path string - file *os.File -} - -func newMessageLogger(network *network, entity string) *messageLogger { - return &messageLogger{ - network: network, - entity: entity, - } -} +const messageStoreMaxTries = 100 var escapeFilename = strings.NewReplacer("/", "-", "\\", "-") -func logPath(network *network, entity string, t time.Time) string { - user := network.user - srv := user.srv +// messageStore is a per-user store for IRC messages. +type messageStore struct { + root string + files map[string]*os.File // indexed by entity +} + +func newMessageStore(root, username string) *messageStore { + return &messageStore{ + root: filepath.Join(root, escapeFilename.Replace(username)), + files: make(map[string]*os.File), + } +} + +func (ms *messageStore) logPath(network *network, entity string, t time.Time) string { year, month, day := t.Date() filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) - return filepath.Join(srv.LogPath, escapeFilename.Replace(user.Username), escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename) + return filepath.Join(ms.root, escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename) } func parseMsgID(s string) (network, entity string, t time.Time, offset int64, err error) { @@ -64,11 +60,11 @@ func nextMsgID(network *network, entity string, t time.Time, f *os.File) (string return formatMsgID(network.GetName(), entity, t, offset), nil } -// lastMsgID queries the last message ID for the given network, entity and +// LastMsgID queries the last message ID for the given network, entity and // date. The message ID returned may not refer to a valid message, but can be // used in history queries. -func lastMsgID(network *network, entity string, t time.Time) (string, error) { - p := logPath(network, entity, t) +func (ms *messageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) { + p := ms.logPath(network, entity, t) fi, err := os.Stat(p) if os.IsNotExist(err) { return formatMsgID(network.GetName(), entity, t, -1), nil @@ -78,7 +74,7 @@ func lastMsgID(network *network, entity string, t time.Time) (string, error) { return formatMsgID(network.GetName(), entity, t, fi.Size()-1), nil } -func (ml *messageLogger) Append(msg *irc.Message) (string, error) { +func (ms *messageStore) Append(network *network, entity string, msg *irc.Message) (string, error) { s := formatMessage(msg) if s == "" { return "", nil @@ -97,44 +93,50 @@ func (ml *messageLogger) Append(msg *irc.Message) (string, error) { } // TODO: enforce maximum open file handles (LRU cache of file handles) + f := ms.files[entity] + // TODO: handle non-monotonic clock behaviour - path := logPath(ml.network, ml.entity, t) - if ml.path != path { - if ml.file != nil { - ml.file.Close() + path := ms.logPath(network, entity, t) + if f == nil || f.Name() != path { + if f != nil { + f.Close() } dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0700); err != nil { - return "", fmt.Errorf("failed to create logs directory %q: %v", dir, err) + return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err) } - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) + var err error + f, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) if err != nil { - return "", fmt.Errorf("failed to open log file %q: %v", path, err) + return "", fmt.Errorf("failed to open message log file %q: %v", path, err) } - ml.path = path - ml.file = f + ms.files[entity] = f } - msgID, err := nextMsgID(ml.network, ml.entity, t, ml.file) + msgID, err := nextMsgID(network, entity, t, f) if err != nil { return "", fmt.Errorf("failed to generate message ID: %v", err) } - _, err = fmt.Fprintf(ml.file, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) + _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) if err != nil { - return "", fmt.Errorf("failed to log message to %q: %v", ml.path, err) + return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err) } + return msgID, nil } -func (ml *messageLogger) Close() error { - if ml.file == nil { - return nil +func (ms *messageStore) Close() error { + var closeErr error + for _, f := range ms.files { + if err := f.Close(); err != nil { + closeErr = fmt.Errorf("failed to close message store: %v", err) + } } - return ml.file.Close() + return closeErr } // formatMessage formats a message log line. It assumes a well-formed IRC @@ -233,8 +235,8 @@ func parseMessage(line, entity string, ref time.Time) (*irc.Message, time.Time, return msg, t, nil } -func parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) { - path := logPath(network, entity, ref) +func (ms *messageStore) parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { @@ -289,8 +291,8 @@ func parseMessagesBefore(network *network, entity string, ref time.Time, limit i } } -func parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) { - path := logPath(network, entity, ref) +func (ms *messageStore) parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { @@ -319,12 +321,12 @@ func parseMessagesAfter(network *network, entity string, ref time.Time, limit in return history, nil } -func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) { +func (ms *messageStore) LoadBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) { history := make([]*irc.Message, limit) remaining := limit tries := 0 - for remaining > 0 && tries < messageLoggerMaxTries { - buf, err := parseMessagesBefore(network, entity, t, remaining, -1) + for remaining > 0 && tries < messageStoreMaxTries { + buf, err := ms.parseMessagesBefore(network, entity, t, remaining, -1) if err != nil { return nil, err } @@ -342,13 +344,13 @@ func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit i return history[remaining:], nil } -func loadHistoryAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) { +func (ms *messageStore) LoadAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) { var history []*irc.Message remaining := limit tries := 0 now := time.Now() - for remaining > 0 && tries < messageLoggerMaxTries && t.Before(now) { - buf, err := parseMessagesAfter(network, entity, t, remaining) + for remaining > 0 && tries < messageStoreMaxTries && t.Before(now) { + buf, err := ms.parseMessagesAfter(network, entity, t, remaining) if err != nil { return nil, err } @@ -370,7 +372,7 @@ func truncateDay(t time.Time) time.Time { return time.Date(year, month, day, 0, 0, 0, 0, t.Location()) } -func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) { +func (ms *messageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) { var afterTime time.Time var afterOffset int64 if id != "" { @@ -389,13 +391,13 @@ func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc t := time.Now() remaining := limit tries := 0 - for remaining > 0 && tries < messageLoggerMaxTries && !truncateDay(t).Before(afterTime) { + for remaining > 0 && tries < messageStoreMaxTries && !truncateDay(t).Before(afterTime) { var offset int64 = -1 if afterOffset >= 0 && truncateDay(t).Equal(afterTime) { offset = afterOffset } - buf, err := parseMessagesBefore(network, entity, t, remaining, offset) + buf, err := ms.parseMessagesBefore(network, entity, t, remaining, offset) if err != nil { return nil, err } diff --git a/upstream.go b/upstream.go index 60c2d2b..729471d 100644 --- a/upstream.go +++ b/upstream.go @@ -81,8 +81,6 @@ type upstreamConn struct { // set of LIST commands in progress, per downstream pendingLISTDownstreamSet map[uint64]struct{} - - messageLoggers map[string]*messageLogger } func connectToUpstream(network *network) (*upstreamConn, error) { @@ -182,7 +180,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) { availableChannelModes: stdChannelModes, availableMemberships: stdMemberships, pendingLISTDownstreamSet: make(map[uint64]struct{}), - messageLoggers: make(map[string]*messageLogger), } return uc, nil } @@ -1611,16 +1608,10 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message } func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) { - if uc.srv.LogPath == "" { + if uc.user.msgStore == nil { return } - ml, ok := uc.messageLoggers[entity] - if !ok { - ml = newMessageLogger(uc.network, entity) - uc.messageLoggers[entity] = ml - } - detached := false if ch, ok := uc.network.channels[entity]; ok { detached = ch.Detached @@ -1628,7 +1619,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) { history, ok := uc.network.history[entity] if !ok { - lastID, err := lastMsgID(uc.network, entity, time.Now()) + lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now()) if err != nil { uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) return @@ -1652,7 +1643,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) { } } - msgID, err := ml.Append(msg) + msgID, err := uc.user.msgStore.Append(uc.network, entity, msg) if err != nil { uc.logger.Printf("failed to log message: %v", err) return diff --git a/user.go b/user.go index 3c74022..b1d0b65 100644 --- a/user.go +++ b/user.go @@ -249,6 +249,7 @@ type user struct { networks []*network downstreamConns []*downstreamConn + msgStore *messageStore // LIST commands in progress pendingLISTs []pendingLIST @@ -261,11 +262,17 @@ type pendingLIST struct { } func newUser(srv *Server, record *User) *user { + var msgStore *messageStore + if srv.LogPath != "" { + msgStore = newMessageStore(srv.LogPath, record.Username) + } + return &user{ - User: *record, - srv: srv, - events: make(chan event, 64), - done: make(chan struct{}), + User: *record, + srv: srv, + events: make(chan event, 64), + done: make(chan struct{}), + msgStore: msgStore, } } @@ -312,7 +319,14 @@ func (u *user) getNetworkByID(id int64) *network { } func (u *user) run() { - defer close(u.done) + defer func() { + if u.msgStore != nil { + if err := u.msgStore.Close(); err != nil { + u.srv.Logger.Printf("failed to close message store for user %q: %v", u.Username, err) + } + } + close(u.done) + }() networks, err := u.srv.db.ListNetworks(u.ID) if err != nil { @@ -459,12 +473,6 @@ func (u *user) run() { func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { uc.network.conn = nil - for _, ml := range uc.messageLoggers { - if err := ml.Close(); err != nil { - uc.logger.Printf("failed to close message logger: %v", err) - } - } - uc.endPendingLISTs(true) uc.forEachDownstream(func(dc *downstreamConn) {