From 394f2853ad472727e3a9ce9d0d7e02cedcd54872 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 29 Apr 2020 19:07:15 +0200 Subject: [PATCH] Add downstream support for cap-notify --- downstream.go | 100 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 88 insertions(+), 12 deletions(-) diff --git a/downstream.go b/downstream.go index 745db51..3221360 100644 --- a/downstream.go +++ b/downstream.go @@ -50,6 +50,17 @@ var errAuthFailed = ircError{&irc.Message{ Params: []string{"*", "Invalid username or password"}, }} +// permanentDownstreamCaps is the list of always-supported downstream +// capabilities. +var permanentDownstreamCaps = map[string]string{ + "batch": "", + "cap-notify": "", + "echo-message": "", + "message-tags": "", + "sasl": "PLAIN", + "server-time": "", +} + type downstreamConn struct { conn @@ -68,6 +79,7 @@ type downstreamConn struct { negociatingCaps bool capVersion int + supportedCaps map[string]string caps map[string]bool saslServer sasl.Server @@ -78,12 +90,16 @@ func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn dc := &downstreamConn{ conn: *newConn(srv, netConn, logger), id: id, + supportedCaps: make(map[string]string), caps: make(map[string]bool), } dc.hostname = netConn.RemoteAddr().String() if host, _, err := net.SplitHostPort(dc.hostname); err == nil { dc.hostname = host } + for k, v := range permanentDownstreamCaps { + dc.supportedCaps[k] = v + } return dc } @@ -439,12 +455,13 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { } } - caps := []string{"message-tags", "server-time", "echo-message", "batch"} - - if dc.capVersion >= 302 { - caps = append(caps, "sasl=PLAIN") - } else { - caps = append(caps, "sasl") + caps := make([]string, 0, len(dc.supportedCaps)) + for k, v := range dc.supportedCaps { + if dc.capVersion >= 302 && v != "" { + caps = append(caps, k + "=" + v) + } else { + caps = append(caps, k) + } } // TODO: multi-line replies @@ -454,6 +471,11 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { Params: []string{replyTo, "LS", strings.Join(caps, " ")}, }) + if dc.capVersion >= 302 { + // CAP version 302 implicitly enables cap-notify + dc.caps["cap-notify"] = true + } + if !dc.registered { dc.negociatingCaps = true } @@ -477,6 +499,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { }} } + // TODO: atomically ack/nak the whole capability set caps := strings.Fields(args[0]) ack := true for _, name := range caps { @@ -486,17 +509,23 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { name = strings.TrimPrefix(name, "-") } - enabled := dc.caps[name] - if enable == enabled { + if enable == dc.caps[name] { continue } - switch name { - case "sasl", "message-tags", "server-time", "echo-message", "batch": - dc.caps[name] = enable - default: + _, ok := dc.supportedCaps[name] + if !ok { ack = false + break } + + if name == "cap-notify" && dc.capVersion >= 302 && !enable { + // cap-notify cannot be disabled with CAP version 302 + ack = false + break + } + + dc.caps[name] = enable } reply := "NAK" @@ -519,6 +548,53 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { return nil } +func (dc *downstreamConn) setSupportedCap(name, value string) { + prevValue, hasPrev := dc.supportedCaps[name] + changed := !hasPrev || prevValue != value + dc.supportedCaps[name] = value + + if !dc.caps["cap-notify"] || !changed { + return + } + + replyTo := dc.nick + if !dc.registered { + replyTo = "*" + } + + cap := name + if value != "" && dc.capVersion >= 302 { + cap = name + "=" + value + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{replyTo, "NEW", cap}, + }) +} + +func (dc *downstreamConn) unsetSupportedCap(name string) { + _, hasPrev := dc.supportedCaps[name] + delete(dc.supportedCaps, name) + delete(dc.caps, name) + + if !dc.caps["cap-notify"] || !hasPrev { + return + } + + replyTo := dc.nick + if !dc.registered { + replyTo = "*" + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{replyTo, "DEL", name}, + }) +} + func sanityCheckServer(addr string) error { dialer := net.Dialer{Timeout: 30 * time.Second} conn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil)