diff --git a/downstream.go b/downstream.go index f7e3808..4420b40 100644 --- a/downstream.go +++ b/downstream.go @@ -156,16 +156,15 @@ func (c *downstreamConn) handleMessage(msg *irc.Message) error { func (c *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { switch msg.Command { case "NICK": - if len(msg.Params) != 1 { - return newNeedMoreParamsError(msg.Command) + if err := parseMessageParams(msg, &c.nick); err != nil { + return err } - c.nick = msg.Params[0] case "USER": - if len(msg.Params) != 4 { - return newNeedMoreParamsError(msg.Command) + var username string + if err := parseMessageParams(msg, &username, nil, nil, &c.realname); err != nil { + return err } - c.username = "~" + msg.Params[0] - c.realname = msg.Params[3] + c.username = "~" + username default: c.logger.Printf("unhandled message: %v", msg) return newUnknownCommandError(msg.Command) diff --git a/irc.go b/irc.go index 728f973..d377ada 100644 --- a/irc.go +++ b/irc.go @@ -3,6 +3,8 @@ package jounce import ( "fmt" "strings" + + "gopkg.in/irc.v3" ) const ( @@ -90,3 +92,15 @@ func parseMembershipPrefix(s string) (prefix membership, nick string) { return 0, s } } + +func parseMessageParams(msg *irc.Message, out ...*string) error { + if len(msg.Params) < len(out) { + return newNeedMoreParamsError(msg.Command) + } + for i := range out { + if out[i] != nil { + *out[i] = msg.Params[i] + } + } + return nil +} diff --git a/upstream.go b/upstream.go index 4baca85..42f9554 100644 --- a/upstream.go +++ b/upstream.go @@ -106,11 +106,10 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { } return nil case "MODE": - if len(msg.Params) < 2 { - return newNeedMoreParamsError(msg.Command) + var name, modeStr string + if err := parseMessageParams(msg, &name, &modeStr); err != nil { + return err } - name := msg.Params[0] - modeStr := msg.Params[1] if name == msg.Prefix.Name { // user mode change if name != c.nick { @@ -143,20 +142,17 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { } } case irc.RPL_MYINFO: - if len(msg.Params) < 5 { - return newNeedMoreParamsError(msg.Command) + if err := parseMessageParams(msg, nil, &c.serverName, nil, &c.availableUserModes, &c.availableChannelModes); err != nil { + return err } - c.serverName = msg.Params[1] - c.availableUserModes = msg.Params[3] - c.availableChannelModes = msg.Params[4] if len(msg.Params) > 5 { c.channelModesWithParam = msg.Params[5] } case "NICK": - if len(msg.Params) < 1 { - return newNeedMoreParamsError(msg.Command) + var newNick string + if err := parseMessageParams(msg, &newNick); err != nil { + return err } - newNick := msg.Params[0] if msg.Prefix.Name == c.nick { c.logger.Printf("changed nick from %q to %q", c.nick, newNick) @@ -174,11 +170,12 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { dc.messages <- msg }) case "JOIN": - if len(msg.Params) < 1 { - return newNeedMoreParamsError(msg.Command) + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err } - for _, ch := range strings.Split(msg.Params[0], ",") { + for _, ch := range strings.Split(channels, ",") { if msg.Prefix.Name == c.nick { c.logger.Printf("joined channel %q", ch) c.channels[ch] = &upstreamChannel{ @@ -198,11 +195,12 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { dc.messages <- msg }) case "PART": - if len(msg.Params) < 1 { - return newNeedMoreParamsError(msg.Command) + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err } - for _, ch := range strings.Split(msg.Params[0], ",") { + for _, ch := range strings.Split(channels, ",") { if msg.Prefix.Name == c.nick { c.logger.Printf("parted channel %q", ch) delete(c.channels, ch) @@ -219,23 +217,25 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { dc.messages <- msg }) case irc.RPL_TOPIC, irc.RPL_NOTOPIC: - if len(msg.Params) < 3 { - return newNeedMoreParamsError(msg.Command) + var name, topic string + if err := parseMessageParams(msg, nil, &name, &topic); err != nil { + return err } - ch, err := c.getChannel(msg.Params[1]) + ch, err := c.getChannel(name) if err != nil { return err } if msg.Command == irc.RPL_TOPIC { - ch.Topic = msg.Params[2] + ch.Topic = topic } else { ch.Topic = "" } case "TOPIC": - if len(msg.Params) < 1 { - return newNeedMoreParamsError(msg.Command) + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err } - ch, err := c.getChannel(msg.Params[0]) + ch, err := c.getChannel(name) if err != nil { return err } @@ -245,43 +245,46 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { ch.Topic = "" } case rpl_topicwhotime: - if len(msg.Params) < 4 { - return newNeedMoreParamsError(msg.Command) + var name, who, timeStr string + if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil { + return err } - ch, err := c.getChannel(msg.Params[1]) + ch, err := c.getChannel(name) if err != nil { return err } - ch.TopicWho = msg.Params[2] - sec, err := strconv.ParseInt(msg.Params[3], 10, 64) + ch.TopicWho = who + sec, err := strconv.ParseInt(timeStr, 10, 64) if err != nil { return fmt.Errorf("failed to parse topic time: %v", err) } ch.TopicTime = time.Unix(sec, 0) case irc.RPL_NAMREPLY: - if len(msg.Params) < 4 { - return newNeedMoreParamsError(msg.Command) + var name, statusStr, members string + if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { + return err } - ch, err := c.getChannel(msg.Params[2]) + ch, err := c.getChannel(name) if err != nil { return err } - status, err := parseChannelStatus(msg.Params[1]) + status, err := parseChannelStatus(statusStr) if err != nil { return err } ch.Status = status - for _, s := range strings.Split(msg.Params[3], " ") { + for _, s := range strings.Split(members, " ") { membership, nick := parseMembershipPrefix(s) ch.Members[nick] = membership } case irc.RPL_ENDOFNAMES: - if len(msg.Params) < 2 { - return newNeedMoreParamsError(msg.Command) + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err } - ch, err := c.getChannel(msg.Params[1]) + ch, err := c.getChannel(name) if err != nil { return err } @@ -322,7 +325,6 @@ func (c *upstreamConn) readMessages() error { Command: "NICK", Params: []string{c.upstream.Nick}, } - c.messages <- &irc.Message{ Command: "USER", Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname},