diff --git a/upstream.go b/upstream.go index eba7dc8..8d76cc7 100644 --- a/upstream.go +++ b/upstream.go @@ -421,6 +421,12 @@ func (uc *upstreamConn) isOurNick(nick string) bool { return uc.network.equalCasemap(uc.nick, nick) } +func (uc *upstreamConn) forwardMsgByID(id uint64, msg *irc.Message) { + uc.forEachDownstreamByID(id, func(dc *downstreamConn) { + dc.SendMessage(msg) + }) +} + func (uc *upstreamConn) abortPendingCommands() { for _, l := range uc.pendingCmds { for _, pendingCmd := range l { @@ -1005,13 +1011,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(&irc.Message{ - Prefix: uc.srv.prefix(), - Command: msg.Command, - Params: msg.Params, - }) - }) + uc.forwardMsgByID(downstreamID, msg) case "BATCH": var tag string if err := parseMessageParams(msg, &tag); err != nil { @@ -1471,9 +1471,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err ch := uc.channels.Get(name) if ch == nil { // NAMES on a channel we have not joined, forward to downstream - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) return nil } @@ -1496,9 +1494,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err ch := uc.channels.Get(name) if ch == nil { // NAMES on a channel we have not joined, forward to downstream - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) return nil } @@ -1628,9 +1624,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: var targetsStr string if err := parseMessageParams(msg, nil, &targetsStr); err != nil { @@ -1695,9 +1689,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } }) case irc.RPL_AWAY: - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) case "AWAY": // Update user flags, if we already have the flags cached uu := uc.users.Get(msg.Prefix.Name) @@ -1728,9 +1720,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err dc.SendMessage(msg) }) case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST, irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) case irc.ERR_NOSUCHNICK: var nick, reason string if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { @@ -1760,13 +1750,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.saslStarted = false } - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(&irc.Message{ - Prefix: uc.srv.prefix(), - Command: msg.Command, - Params: []string{dc.nick, command, reason}, - }) - }) + uc.forwardMsgByID(downstreamID, msg) case "FAIL": var command, code string if err := parseMessageParams(msg, &command, &code); err != nil { @@ -1781,9 +1765,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err downstreamID = dc.id } - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) case "ACK": // Ignore case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: @@ -1803,13 +1785,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(&irc.Message{ - Prefix: uc.srv.prefix(), - Command: msg.Command, - Params: msg.Params, - }) - }) + uc.forwardMsgByID(downstreamID, msg) case irc.RPL_LISTSTART: // Ignore case "ERROR": @@ -1852,10 +1828,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err fallthrough default: uc.logger.Printf("unhandled message: %v", msg) - - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.forwardMsgByID(downstreamID, msg) } return nil }