diff --git a/irc.go b/irc.go index 179a089..4b5f4d3 100644 --- a/irc.go +++ b/irc.go @@ -2,6 +2,7 @@ package soju import ( "fmt" + "sort" "strings" "gopkg.in/irc.v3" @@ -18,6 +19,9 @@ const ( const maxMessageLength = 512 +// The server-time layout, as defined in the IRCv3 spec. +const serverTimeLayout = "2006-01-02T15:04:05.000Z" + type userModes string func (ms userModes) Has(c byte) bool { @@ -293,5 +297,75 @@ type batch struct { Label string } -// The server-time layout, as defined in the IRCv3 spec. -const serverTimeLayout = "2006-01-02T15:04:05.000Z" +func join(channels, keys []string) []*irc.Message { + // Put channels with a key first + js := joinSorter{channels, keys} + sort.Sort(&js) + + // Two spaces because there are three words (JOIN, channels and keys) + maxLength := maxMessageLength - (len("JOIN") + 2) + + var msgs []*irc.Message + var channelsBuf, keysBuf strings.Builder + for i, channel := range channels { + key := keys[i] + + n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel) + if key != "" { + n += 1 + len(key) + } + + if channelsBuf.Len() > 0 && n > maxLength { + // No room for the new channel in this message + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + channelsBuf.Reset() + keysBuf.Reset() + } + + if channelsBuf.Len() > 0 { + channelsBuf.WriteByte(',') + } + channelsBuf.WriteString(channel) + if key != "" { + if keysBuf.Len() > 0 { + keysBuf.WriteByte(',') + } + keysBuf.WriteString(key) + } + } + if channelsBuf.Len() > 0 { + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + } + + return msgs +} + +type joinSorter struct { + channels []string + keys []string +} + +func (js *joinSorter) Len() int { + return len(js.channels) +} + +func (js *joinSorter) Less(i, j int) bool { + if (js.keys[i] != "") != (js.keys[j] != "") { + // Only one of the channels has a key + return js.keys[i] != "" + } + return js.channels[i] < js.channels[j] +} + +func (js *joinSorter) Swap(i, j int) { + js.channels[i], js.channels[j] = js.channels[j], js.channels[i] + js.keys[i], js.keys[j] = js.keys[j], js.keys[i] +} diff --git a/upstream.go b/upstream.go index 22b63fe..76479a8 100644 --- a/upstream.go +++ b/upstream.go @@ -553,19 +553,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { }) if len(uc.network.channels) > 0 { - // TODO: split this into multiple messages if need be - var names, keys []string + var channels, keys []string for _, ch := range uc.network.channels { - names = append(names, ch.Name) + channels = append(channels, ch.Name) keys = append(keys, ch.Key) } - uc.SendMessage(&irc.Message{ - Command: "JOIN", - Params: []string{ - strings.Join(names, ","), - strings.Join(keys, ","), - }, - }) + + for _, msg := range join(channels, keys) { + uc.SendMessage(msg) + } } case irc.RPL_MYINFO: if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {