downstream: pass context to SendMessage

Just like upstream does.
This commit is contained in:
Simon Ser 2023-04-06 13:23:20 +02:00
parent 51768c256a
commit 6f01bd86c3
5 changed files with 226 additions and 223 deletions

File diff suppressed because it is too large Load diff

View file

@ -462,7 +462,7 @@ func (s *Server) Handle(ic ircConn) {
defer dc.Close() defer dc.Close()
if shutdown { if shutdown {
dc.SendMessage(&irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{
Command: "ERROR", Command: "ERROR",
Params: []string{"Server is shutting down"}, Params: []string{"Server is shutting down"},
}) })
@ -478,7 +478,7 @@ func (s *Server) Handle(ic ircConn) {
user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername) user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername)
if err != nil { if err != nil {
dc.SendMessage(&irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{
Command: "ERROR", Command: "ERROR",
Params: []string{"Internal server error"}, Params: []string{"Internal server error"},
}) })

View file

@ -56,7 +56,7 @@ type serviceCommand struct {
} }
func sendServiceNOTICE(dc *downstreamConn, text string) { func sendServiceNOTICE(dc *downstreamConn, text string) {
dc.SendMessage(&irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{
Prefix: servicePrefix, Prefix: servicePrefix,
Command: "NOTICE", Command: "NOTICE",
Params: []string{dc.nick, text}, Params: []string{dc.nick, text},
@ -64,7 +64,7 @@ func sendServiceNOTICE(dc *downstreamConn, text string) {
} }
func sendServicePRIVMSG(dc *downstreamConn, text string) { func sendServicePRIVMSG(dc *downstreamConn, text string) {
dc.SendMessage(&irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{
Prefix: servicePrefix, Prefix: servicePrefix,
Command: "PRIVMSG", Command: "PRIVMSG",
Params: []string{dc.nick, text}, Params: []string{dc.nick, text},

View file

@ -421,19 +421,20 @@ func (uc *upstreamConn) isOurNick(nick string) bool {
return uc.network.equalCasemap(uc.nick, nick) return uc.network.equalCasemap(uc.nick, nick)
} }
func (uc *upstreamConn) forwardMessage(msg *irc.Message) { func (uc *upstreamConn) forwardMessage(ctx context.Context, msg *irc.Message) {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
}) })
} }
func (uc *upstreamConn) forwardMsgByID(id uint64, msg *irc.Message) { func (uc *upstreamConn) forwardMsgByID(ctx context.Context, id uint64, msg *irc.Message) {
uc.forEachDownstreamByID(id, func(dc *downstreamConn) { uc.forEachDownstreamByID(id, func(dc *downstreamConn) {
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
}) })
} }
func (uc *upstreamConn) abortPendingCommands() { func (uc *upstreamConn) abortPendingCommands() {
ctx := context.TODO()
for _, l := range uc.pendingCmds { for _, l := range uc.pendingCmds {
for _, pendingCmd := range l { for _, pendingCmd := range l {
dc := uc.downstreamByID(pendingCmd.downstreamID) dc := uc.downstreamByID(pendingCmd.downstreamID)
@ -443,7 +444,7 @@ func (uc *upstreamConn) abortPendingCommands() {
switch pendingCmd.msg.Command { switch pendingCmd.msg.Command {
case "LIST": case "LIST":
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_LISTEND, Command: irc.RPL_LISTEND,
Params: []string{dc.nick, "Command aborted"}, Params: []string{dc.nick, "Command aborted"},
@ -453,26 +454,26 @@ func (uc *upstreamConn) abortPendingCommands() {
if len(pendingCmd.msg.Params) > 0 { if len(pendingCmd.msg.Params) > 0 {
mask = pendingCmd.msg.Params[0] mask = pendingCmd.msg.Params[0]
} }
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFWHO, Command: irc.RPL_ENDOFWHO,
Params: []string{dc.nick, mask, "Command aborted"}, Params: []string{dc.nick, mask, "Command aborted"},
}) })
case "WHOIS": case "WHOIS":
nick := pendingCmd.msg.Params[len(pendingCmd.msg.Params)-1] nick := pendingCmd.msg.Params[len(pendingCmd.msg.Params)-1]
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFWHOIS, Command: irc.RPL_ENDOFWHOIS,
Params: []string{dc.nick, nick, "Command aborted"}, Params: []string{dc.nick, nick, "Command aborted"},
}) })
case "AUTHENTICATE": case "AUTHENTICATE":
dc.endSASL(&irc.Message{ dc.endSASL(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_SASLABORTED, Command: irc.ERR_SASLABORTED,
Params: []string{dc.nick, "SASL authentication aborted"}, Params: []string{dc.nick, "SASL authentication aborted"},
}) })
case "REGISTER", "VERIFY": case "REGISTER", "VERIFY":
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "FAIL", Command: "FAIL",
Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"},
@ -732,7 +733,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.registered { if uc.registered {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps() dc.updateSupportedCaps(ctx)
}) })
} }
case "NEW": case "NEW":
@ -753,7 +754,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.registered { if uc.registered {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps() dc.updateSupportedCaps(ctx)
}) })
} }
default: default:
@ -818,8 +819,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("logged in with account %q", uc.account) uc.logger.Printf("logged in with account %q", uc.account)
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateAccount() dc.updateAccount(ctx)
dc.updateHost() dc.updateHost(ctx)
}) })
case irc.RPL_LOGGEDOUT: case irc.RPL_LOGGEDOUT:
var rawPrefix string var rawPrefix string
@ -835,8 +836,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("logged out") uc.logger.Printf("logged out")
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateAccount() dc.updateAccount(ctx)
dc.updateHost() dc.updateHost(ctx)
}) })
case xirc.RPL_VISIBLEHOST: case xirc.RPL_VISIBLEHOST:
var rawHost string var rawHost string
@ -852,7 +853,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateHost() dc.updateHost(ctx)
}) })
case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED: case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
var info string var info string
@ -876,7 +877,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password) uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password)
} }
dc.endSASL(msg) dc.endSASL(ctx, msg)
} }
if !uc.registered { if !uc.registered {
@ -898,7 +899,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.network.autoSaveSASLPlain(ctx, account, password) uc.network.autoSaveSASLPlain(ctx, account, password)
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
} }
case irc.RPL_WELCOME: case irc.RPL_WELCOME:
if err := parseMessageParams(msg, &uc.nick); err != nil { if err := parseMessageParams(msg, &uc.nick); err != nil {
@ -993,7 +994,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
msgs := xirc.GenerateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport) msgs := xirc.GenerateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport)
for _, msg := range msgs { for _, msg := range msgs {
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
} }
}) })
case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
@ -1017,7 +1018,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case "BATCH": case "BATCH":
var tag string var tag string
if err := parseMessageParams(msg, &tag); err != nil { if err := parseMessageParams(msg, &tag); err != nil {
@ -1088,10 +1089,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
if !me { if !me {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} else { } else {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateNick() dc.updateNick(ctx)
}) })
uc.updateMonitor() uc.updateMonitor()
} }
@ -1112,10 +1113,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.realname = newRealname uc.realname = newRealname
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateRealname() dc.updateRealname(ctx)
}) })
} else { } else {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
case "CHGHOST": case "CHGHOST":
var newUsername, newHostname string var newUsername, newHostname string
@ -1135,11 +1136,11 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.hostname = newHostname uc.hostname = newHostname
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateHost() dc.updateHost(ctx)
}) })
} else { } else {
// TODO: add fallback with QUIT/JOIN/MODE messages // TODO: add fallback with QUIT/JOIN/MODE messages
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
case "JOIN": case "JOIN":
var channels string var channels string
@ -1274,7 +1275,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.users.Del(msg.Prefix.Name) uc.users.Del(msg.Prefix.Name)
if msg.Prefix.Name != uc.nick { if msg.Prefix.Name != uc.nick {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
case irc.RPL_TOPIC, irc.RPL_NOTOPIC: case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
var name, topic string var name, topic string
@ -1322,7 +1323,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err return err
} }
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} else { // channel mode change } else { // channel mode change
ch, err := uc.getChannel(name) ch, err := uc.getChannel(name)
if err != nil { if err != nil {
@ -1338,7 +1339,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
c := uc.network.channels.Get(name) c := uc.network.channels.Get(name)
if c == nil || !c.Detached { if c == nil || !c.Detached {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
} }
case irc.RPL_UMODEIS: case irc.RPL_UMODEIS:
@ -1355,7 +1356,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err return err
} }
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
case irc.RPL_CHANNELMODEIS: case irc.RPL_CHANNELMODEIS:
var channel string var channel string
if err := parseMessageParams(msg, nil, &channel); err != nil { if err := parseMessageParams(msg, nil, &channel); err != nil {
@ -1379,7 +1380,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
c := uc.network.channels.Get(channel) c := uc.network.channels.Get(channel)
if firstMode && (c == nil || !c.Detached) { if firstMode && (c == nil || !c.Detached) {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
case xirc.RPL_CREATIONTIME: case xirc.RPL_CREATIONTIME:
var channel, creationTime string var channel, creationTime string
@ -1397,7 +1398,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
c := uc.network.channels.Get(channel) c := uc.network.channels.Get(channel)
if firstCreationTime && (c == nil || !c.Detached) { if firstCreationTime && (c == nil || !c.Detached) {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
case xirc.RPL_TOPICWHOTIME: case xirc.RPL_TOPICWHOTIME:
var channel, who, timeStr string var channel, who, timeStr string
@ -1420,7 +1421,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
c := uc.network.channels.Get(channel) c := uc.network.channels.Get(channel)
if firstTopicWhoTime && (c == nil || !c.Detached) { if firstTopicWhoTime && (c == nil || !c.Detached) {
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
} }
case irc.RPL_LIST: case irc.RPL_LIST:
dc, cmd := uc.currentPendingCommand("LIST") dc, cmd := uc.currentPendingCommand("LIST")
@ -1430,7 +1431,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
case irc.RPL_LISTEND: case irc.RPL_LISTEND:
dc, cmd := uc.dequeueCommand("LIST") dc, cmd := uc.dequeueCommand("LIST")
if cmd == nil { if cmd == nil {
@ -1439,7 +1440,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
case irc.RPL_NAMREPLY: case irc.RPL_NAMREPLY:
var name, statusStr, members string var name, statusStr, members string
if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
@ -1449,7 +1450,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
ch := uc.channels.Get(name) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
return nil return nil
} }
@ -1472,7 +1473,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
ch := uc.channels.Get(name) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
return nil return nil
} }
@ -1506,7 +1507,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
realname := parts[1] realname := parts[1]
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
if uc.shouldCacheUserInfo(nick) { if uc.shouldCacheUserInfo(nick) {
uc.cacheUserInfo(nick, &upstreamUser{ uc.cacheUserInfo(nick, &upstreamUser{
@ -1526,7 +1527,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
if len(cmd.Params) > 1 { if len(cmd.Params) > 1 {
fields, _ := xirc.ParseWHOXOptions(cmd.Params[1]) fields, _ := xirc.ParseWHOXOptions(cmd.Params[1])
@ -1559,7 +1560,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
case xirc.RPL_WHOISCERTFP, xirc.RPL_WHOISREGNICK, irc.RPL_WHOISUSER, irc.RPL_WHOISSERVER, irc.RPL_WHOISCHANNELS, irc.RPL_WHOISOPERATOR, irc.RPL_WHOISIDLE, xirc.RPL_WHOISSPECIAL, xirc.RPL_WHOISACCOUNT, xirc.RPL_WHOISACTUALLY, xirc.RPL_WHOISHOST, xirc.RPL_WHOISMODES, xirc.RPL_WHOISSECURE: case xirc.RPL_WHOISCERTFP, xirc.RPL_WHOISREGNICK, irc.RPL_WHOISUSER, irc.RPL_WHOISSERVER, irc.RPL_WHOISCHANNELS, irc.RPL_WHOISOPERATOR, irc.RPL_WHOISIDLE, xirc.RPL_WHOISSPECIAL, xirc.RPL_WHOISACCOUNT, xirc.RPL_WHOISACTUALLY, xirc.RPL_WHOISHOST, xirc.RPL_WHOISMODES, xirc.RPL_WHOISSECURE:
dc, cmd := uc.currentPendingCommand("WHOIS") dc, cmd := uc.currentPendingCommand("WHOIS")
if cmd == nil { if cmd == nil {
@ -1568,7 +1569,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
case irc.RPL_ENDOFWHOIS: case irc.RPL_ENDOFWHOIS:
dc, cmd := uc.dequeueCommand("WHOIS") dc, cmd := uc.dequeueCommand("WHOIS")
if cmd == nil { if cmd == nil {
@ -1577,7 +1578,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
case "INVITE": case "INVITE":
var nick, channel string var nick, channel string
if err := parseMessageParams(msg, &nick, &channel); err != nil { if err := parseMessageParams(msg, &nick, &channel); err != nil {
@ -1590,7 +1591,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if !weAreInvited && !dc.caps.IsEnabled("invite-notify") { if !weAreInvited && !dc.caps.IsEnabled("invite-notify") {
return return
} }
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
}) })
if weAreInvited { if weAreInvited {
@ -1602,7 +1603,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err return err
} }
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE:
var targetsStr string var targetsStr string
if err := parseMessageParams(msg, nil, &targetsStr); err != nil { if err := parseMessageParams(msg, nil, &targetsStr); err != nil {
@ -1640,7 +1641,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
for _, target := range targets { for _, target := range targets {
prefix := irc.ParsePrefix(target) prefix := irc.ParsePrefix(target)
if dc.monitored.Has(prefix.Name) { if dc.monitored.Has(prefix.Name) {
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: msg.Command, Command: msg.Command,
Params: []string{dc.nick, target}, Params: []string{dc.nick, target},
@ -1658,7 +1659,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
for _, target := range targets { for _, target := range targets {
if dc.monitored.Has(target) { if dc.monitored.Has(target) {
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: msg.Command, Command: msg.Command,
Params: []string{dc.nick, limit, target}, Params: []string{dc.nick, limit, target},
@ -1667,7 +1668,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
}) })
case irc.RPL_AWAY: case irc.RPL_AWAY:
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case "AWAY": case "AWAY":
// Update user flags, if we already have the flags cached // Update user flags, if we already have the flags cached
uu := uc.users.Get(msg.Prefix.Name) uu := uc.users.Get(msg.Prefix.Name)
@ -1683,7 +1684,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
} }
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
case "ACCOUNT": case "ACCOUNT":
var account string var account string
if err := parseMessageParams(msg, &account); err != nil { if err := parseMessageParams(msg, &account); err != nil {
@ -1692,9 +1693,9 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{ uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{
Account: account, Account: account,
}) })
uc.forwardMessage(msg) uc.forwardMessage(ctx, msg)
case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST, irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST, irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST:
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case irc.ERR_NOSUCHNICK: case irc.ERR_NOSUCHNICK:
var nick, reason string var nick, reason string
if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { if err := parseMessageParams(msg, nil, &nick, &reason); err != nil {
@ -1706,10 +1707,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if cmd != nil && cm(cmd.Params[len(cmd.Params)-1]) == cm(nick) { if cmd != nil && cm(cmd.Params[len(cmd.Params)-1]) == cm(nick) {
uc.dequeueCommand("WHOIS") uc.dequeueCommand("WHOIS")
if dc != nil { if dc != nil {
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
} }
} else { } else {
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
} }
case xirc.ERR_UNKNOWNERROR, irc.ERR_UNKNOWNCOMMAND, irc.ERR_NEEDMOREPARAMS, irc.RPL_TRYAGAIN: case xirc.ERR_UNKNOWNERROR, irc.ERR_UNKNOWNCOMMAND, irc.ERR_NEEDMOREPARAMS, irc.RPL_TRYAGAIN:
var command, reason string var command, reason string
@ -1726,7 +1727,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.saslStarted = false uc.saslStarted = false
} }
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case "FAIL": case "FAIL":
var command, code string var command, code string
if err := parseMessageParams(msg, &command, &code); err != nil { if err := parseMessageParams(msg, &command, &code); err != nil {
@ -1741,7 +1742,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
downstreamID = dc.id downstreamID = dc.id
} }
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case "ACK": case "ACK":
// Ignore // Ignore
case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
@ -1761,7 +1762,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
case irc.RPL_LISTSTART: case irc.RPL_LISTSTART:
// Ignore // Ignore
case "ERROR": case "ERROR":
@ -1801,10 +1802,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if !uc.registered { if !uc.registered {
return registrationError{msg} return registrationError{msg}
} }
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
default: default:
uc.logger.Printf("unhandled message: %v", msg) uc.logger.Printf("unhandled message: %v", msg)
uc.forwardMsgByID(downstreamID, msg) uc.forwardMsgByID(ctx, downstreamID, msg)
} }
return nil return nil
} }
@ -2146,12 +2147,13 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64
ch := uc.network.channels.Get(target) ch := uc.network.channels.Get(target)
detached := ch != nil && ch.Detached detached := ch != nil && ch.Detached
ctx := context.TODO()
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
echo := dc.id == originID && msg.Prefix != nil && uc.isOurNick(msg.Prefix.Name) echo := dc.id == originID && msg.Prefix != nil && uc.isOurNick(msg.Prefix.Name)
if !detached && (!echo || dc.caps.IsEnabled("echo-message")) { if !detached && (!echo || dc.caps.IsEnabled("echo-message")) {
dc.sendMessageWithID(msg, msgID) dc.sendMessageWithID(ctx, msg, msgID)
} else { } else {
dc.advanceMessageWithID(msg, msgID) dc.advanceMessageWithID(ctx, msg, msgID)
} }
}) })
} }

31
user.go
View file

@ -337,7 +337,7 @@ func (net *network) detach(ch *database.Channel) {
} }
net.forEachDownstream(func(dc *downstreamConn) { net.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{
Prefix: dc.prefix(), Prefix: dc.prefix(),
Command: "PART", Command: "PART",
Params: []string{ch.Name, "Detach"}, Params: []string{ch.Name, "Detach"},
@ -364,7 +364,7 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
} }
net.forEachDownstream(func(dc *downstreamConn) { net.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.prefix(), Prefix: dc.prefix(),
Command: "JOIN", Command: "JOIN",
Params: []string{ch.Name}, Params: []string{ch.Name},
@ -642,17 +642,18 @@ func (u *user) run() {
uc.updateAway() uc.updateAway()
uc.updateMonitor() uc.updateMonitor()
ctx := context.TODO()
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps() dc.updateSupportedCaps(ctx)
if !dc.caps.IsEnabled("soju.im/bouncer-networks") { if !dc.caps.IsEnabled("soju.im/bouncer-networks") {
sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
} }
dc.updateNick() dc.updateNick(ctx)
dc.updateHost() dc.updateHost(ctx)
dc.updateRealname() dc.updateRealname(ctx)
dc.updateAccount() dc.updateAccount(ctx)
dc.updateCasemapping() dc.updateCasemapping()
}) })
u.notifyBouncerNetworkState(uc.network.ID, irc.Tags{ u.notifyBouncerNetworkState(uc.network.ID, irc.Tags{
@ -729,7 +730,7 @@ func (u *user) run() {
} }
if !u.Enabled { if !u.Enabled {
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Command: "ERROR", Command: "ERROR",
Params: []string{"This bouncer account is disabled"}, Params: []string{"This bouncer account is disabled"},
}) })
@ -741,9 +742,9 @@ func (u *user) run() {
if ircErr, ok := err.(ircError); ok { if ircErr, ok := err.(ircError); ok {
msg := ircErr.Message.Copy() msg := ircErr.Message.Copy()
msg.Prefix = dc.srv.prefix() msg.Prefix = dc.srv.prefix()
dc.SendMessage(msg) dc.SendMessage(ctx, msg)
} else { } else {
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Command: "ERROR", Command: "ERROR",
Params: []string{"Internal server error"}, Params: []string{"Internal server error"},
}) })
@ -799,7 +800,7 @@ func (u *user) run() {
err := dc.handleMessage(context.TODO(), msg) err := dc.handleMessage(context.TODO(), msg)
if ircErr, ok := err.(ircError); ok { if ircErr, ok := err.(ircError); ok {
ircErr.Message.Prefix = dc.srv.prefix() ircErr.Message.Prefix = dc.srv.prefix()
dc.SendMessage(ircErr.Message) dc.SendMessage(context.TODO(), ircErr.Message)
} else if err != nil { } else if err != nil {
dc.logger.Printf("failed to handle message %q: %v", msg, err) dc.logger.Printf("failed to handle message %q: %v", msg, err)
dc.Close() dc.Close()
@ -807,7 +808,7 @@ func (u *user) run() {
case eventBroadcast: case eventBroadcast:
msg := e.msg msg := e.msg
for _, dc := range u.downstreamConns { for _, dc := range u.downstreamConns {
dc.SendMessage(msg) dc.SendMessage(context.TODO(), msg)
} }
case eventUserUpdate: case eventUserUpdate:
e.done <- u.updateUser(context.TODO(), func(record *database.User) error { e.done <- u.updateUser(context.TODO(), func(record *database.User) error {
@ -882,7 +883,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
}) })
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps() dc.updateSupportedCaps(context.TODO())
}) })
// If the network has been removed, don't send a state change notification // If the network has been removed, don't send a state change notification
@ -912,7 +913,7 @@ func (u *user) notifyBouncerNetworkState(netID int64, attrs irc.Tags) {
netIDStr := fmt.Sprintf("%v", netID) netIDStr := fmt.Sprintf("%v", netID)
for _, dc := range u.downstreamConns { for _, dc := range u.downstreamConns {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",
Params: []string{"NETWORK", netIDStr, attrs.String()}, Params: []string{"NETWORK", netIDStr, attrs.String()},
@ -1116,7 +1117,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
idStr := fmt.Sprintf("%v", network.ID) idStr := fmt.Sprintf("%v", network.ID)
for _, dc := range u.downstreamConns { for _, dc := range u.downstreamConns {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{ dc.SendMessage(ctx, &irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BOUNCER", Command: "BOUNCER",
Params: []string{"NETWORK", idStr, "*"}, Params: []string{"NETWORK", idStr, "*"},