Add a queue for WHO commands

This has the following upsides:

- We can now routes WHO replies to the correct client, without
  broadcasting them to everybody.
- We are less likely to hit server rate limits when multiple downstreams
  are issuing WHO commands at the same time.
This commit is contained in:
Simon Ser 2021-11-09 22:09:17 +01:00
parent 0c360d24c5
commit 0b6ff2e61a
3 changed files with 134 additions and 86 deletions

View file

@ -1864,7 +1864,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return nil return nil
} }
uc.enqueueLIST(dc, msg) uc.enqueueCommand(dc, msg)
case "NAMES": case "NAMES":
if len(msg.Params) == 0 { if len(msg.Params) == 0 {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -1986,7 +1986,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
params = append(params, options) params = append(params, options)
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.enqueueCommand(dc, &irc.Message{
Command: "WHO", Command: "WHO",
Params: params, Params: params,
}) })

View file

@ -74,7 +74,7 @@ func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) {
type pendingUpstreamCommand struct { type pendingUpstreamCommand struct {
downstreamID uint64 downstreamID uint64
cmd *irc.Message msg *irc.Message
} }
type upstreamConn struct { type upstreamConn struct {
@ -109,10 +109,10 @@ type upstreamConn struct {
casemapIsSet bool casemapIsSet bool
// Queue of LIST commands in progress. The first entry has been sent to the // Queue of commands in progress, indexed by type. The first entry has been
// server and is awaiting reply. The following entries have not been sent // sent to the server and is awaiting reply. The following entries have not
// yet. // been sent yet.
pendingLIST []pendingUpstreamCommand pendingCmds map[string][]pendingUpstreamCommand
gotMotd bool gotMotd bool
} }
@ -208,6 +208,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
availableChannelModes: stdChannelModes, availableChannelModes: stdChannelModes,
availableMemberships: stdMemberships, availableMemberships: stdMemberships,
isupport: make(map[string]*string), isupport: make(map[string]*string),
pendingCmds: make(map[string][]pendingUpstreamCommand),
} }
return uc, nil return uc, nil
} }
@ -225,6 +226,15 @@ func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)
}) })
} }
func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
for _, dc := range uc.user.downstreamConns {
if dc.id == id {
return dc
}
}
return nil
}
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch := uc.channels.Value(name) ch := uc.channels.Value(name)
if ch == nil { if ch == nil {
@ -241,63 +251,85 @@ func (uc *upstreamConn) isOurNick(nick string) bool {
return uc.nickCM == uc.network.casemap(nick) return uc.nickCM == uc.network.casemap(nick)
} }
func (uc *upstreamConn) endPendingLISTs() { func (uc *upstreamConn) endPendingCommands() {
for _, pendingCmd := range uc.pendingLIST { for _, l := range uc.pendingCmds {
uc.forEachDownstreamByID(pendingCmd.downstreamID, func(dc *downstreamConn) { for _, pendingCmd := range l {
dc.SendMessage(&irc.Message{ dc := uc.downstreamByID(pendingCmd.downstreamID)
Prefix: dc.srv.prefix(), if dc == nil {
Command: irc.RPL_LISTEND, continue
Params: []string{dc.nick, "End of /LIST"}, }
})
})
}
uc.pendingLIST = nil
}
func (uc *upstreamConn) sendNextPendingLIST() { switch pendingCmd.msg.Command {
if len(uc.pendingLIST) == 0 { case "LIST":
return dc.SendMessage(&irc.Message{
} Prefix: dc.srv.prefix(),
uc.SendMessage(uc.pendingLIST[0].cmd) Command: irc.RPL_LISTEND,
} Params: []string{dc.nick, "End of /LIST"},
})
func (uc *upstreamConn) enqueueLIST(dc *downstreamConn, cmd *irc.Message) { case "WHO":
uc.pendingLIST = append(uc.pendingLIST, pendingUpstreamCommand{ mask := "*"
downstreamID: dc.id, if len(pendingCmd.msg.Params) > 0 {
cmd: cmd, mask = pendingCmd.msg.Params[0]
}) }
dc.SendMessage(&irc.Message{
if len(uc.pendingLIST) == 1 { Prefix: dc.srv.prefix(),
uc.sendNextPendingLIST() Command: irc.RPL_ENDOFWHO,
} Params: []string{dc.nick, mask, "End of /WHO"},
} })
default:
func (uc *upstreamConn) currentPendingLIST() (*downstreamConn, *irc.Message) { panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command))
if len(uc.pendingLIST) == 0 { }
return nil, nil
}
pendingCmd := uc.pendingLIST[0]
for _, dc := range uc.user.downstreamConns {
if dc.id == pendingCmd.downstreamID {
return dc, pendingCmd.cmd
} }
} }
return nil, pendingCmd.cmd uc.pendingCmds = make(map[string][]pendingUpstreamCommand)
} }
func (uc *upstreamConn) dequeueLIST() (*downstreamConn, *irc.Message) { func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
dc, cmd := uc.currentPendingLIST() if len(uc.pendingCmds[cmd]) == 0 {
return
}
uc.SendMessage(uc.pendingCmds[cmd][0].msg)
}
if len(uc.pendingLIST) > 0 { func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
copy(uc.pendingLIST, uc.pendingLIST[1:]) switch msg.Command {
uc.pendingLIST = uc.pendingLIST[:len(uc.pendingLIST)-1] case "LIST", "WHO":
// Supported
default:
panic(fmt.Errorf("Unsupported pending command %q", msg.Command))
} }
uc.sendNextPendingLIST() uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{
downstreamID: dc.id,
msg: msg,
})
return dc, cmd if len(uc.pendingCmds[msg.Command]) == 1 {
uc.sendNextPendingCommand(msg.Command)
}
}
func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) {
if len(uc.pendingCmds[cmd]) == 0 {
return nil, nil
}
pendingCmd := uc.pendingCmds[cmd][0]
return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg
}
func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) {
dc, msg := uc.currentPendingCommand(cmd)
if len(uc.pendingCmds[cmd]) > 0 {
copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:])
uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1]
}
uc.sendNextPendingCommand(cmd)
return dc, msg
} }
func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) { func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) {
@ -1095,7 +1127,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
dc, cmd := uc.currentPendingLIST() dc, cmd := uc.currentPendingCommand("LIST")
if cmd == nil { if cmd == nil {
return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
} else if dc == nil { } else if dc == nil {
@ -1108,7 +1140,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic}, Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic},
}) })
case irc.RPL_LISTEND: case irc.RPL_LISTEND:
dc, cmd := uc.dequeueLIST() dc, cmd := uc.dequeueCommand("LIST")
if cmd == nil { if cmd == nil {
return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
} else if dc == nil { } else if dc == nil {
@ -1195,6 +1227,13 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
dc, cmd := uc.currentPendingCommand("WHO")
if cmd == nil {
return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO")
} else if dc == nil {
return nil
}
parts := strings.SplitN(trailing, " ", 2) parts := strings.SplitN(trailing, " ", 2)
if len(parts) != 2 { if len(parts) != 2 {
return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing) return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
@ -1208,35 +1247,46 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
trailing = strconv.Itoa(hops) + " " + realname trailing = strconv.Itoa(hops) + " " + realname
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { if channel != "*" {
channel := channel channel = dc.marshalEntity(uc.network, channel)
if channel != "*" { }
channel = dc.marshalEntity(uc.network, channel) nick = dc.marshalEntity(uc.network, nick)
} dc.SendMessage(&irc.Message{
nick := dc.marshalEntity(uc.network, nick) Prefix: dc.srv.prefix(),
dc.SendMessage(&irc.Message{ Command: irc.RPL_WHOREPLY,
Prefix: dc.srv.prefix(), Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
Command: irc.RPL_WHOREPLY,
Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
})
}) })
case rpl_whospcrpl:
dc, cmd := uc.currentPendingCommand("WHO")
if cmd == nil {
return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO")
} else if dc == nil {
return nil
}
// Only supported in single-upstream mode, so forward as-is
dc.SendMessage(msg)
case irc.RPL_ENDOFWHO: case irc.RPL_ENDOFWHO:
var name string var name string
if err := parseMessageParams(msg, nil, &name); err != nil { if err := parseMessageParams(msg, nil, &name); err != nil {
return err return err
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { dc, cmd := uc.dequeueCommand("WHO")
name := name if cmd == nil {
if name != "*" { return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO")
// TODO: support WHO masks } else if dc == nil {
name = dc.marshalEntity(uc.network, name) return nil
} }
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), mask := "*"
Command: irc.RPL_ENDOFWHO, if len(cmd.Params) > 0 {
Params: []string{dc.nick, name, "End of /WHO list"}, mask = cmd.Params[0]
}) }
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFWHO,
Params: []string{dc.nick, mask, "End of /WHO list"},
}) })
case irc.RPL_WHOISUSER: case irc.RPL_WHOISUSER:
var nick, username, host, realname string var nick, username, host, realname string
@ -1436,8 +1486,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
if command == "LIST" { if command == "LIST" || command == "WHO" {
uc.endPendingLISTs() dc, _ := uc.dequeueCommand(command)
if dc != nil && downstreamID == 0 {
downstreamID = dc.id
}
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
@ -1453,11 +1506,6 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
// Ignore // Ignore
case irc.RPL_YOURHOST, irc.RPL_CREATED: case irc.RPL_YOURHOST, irc.RPL_CREATED:
// Ignore // Ignore
case rpl_whospcrpl:
// Not supported in multi-upstream mode, forward as-is
uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(msg)
})
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
fallthrough fallthrough
case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:

View file

@ -681,7 +681,7 @@ func (u *user) run() {
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.network.conn = nil uc.network.conn = nil
uc.endPendingLISTs() uc.endPendingCommands()
for _, entry := range uc.channels.innerMap { for _, entry := range uc.channels.innerMap {
uch := entry.value.(*upstreamChannel) uch := entry.value.(*upstreamChannel)