From bbb5e79f59d930d813abe01a5e2cab892bbd890f Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 30 Apr 2020 16:10:39 +0200 Subject: [PATCH] Introduce permanentUpstreamCaps --- upstream.go | 139 ++++++++++++++++++++++++++++------------------------ 1 file changed, 76 insertions(+), 63 deletions(-) diff --git a/upstream.go b/upstream.go index cf49c57..3fd367c 100644 --- a/upstream.go +++ b/upstream.go @@ -15,6 +15,16 @@ import ( "gopkg.in/irc.v3" ) +// permanentUpstreamCaps is the static list of upstream capabilities always +// requested when supported. +var permanentUpstreamCaps = map[string]bool{ + "away-notify": true, + "batch": true, + "labeled-response": true, + "message-tags": true, + "server-time": true, +} + type upstreamChannel struct { Name string conn *upstreamConn @@ -1209,7 +1219,7 @@ func (uc *upstreamConn) handleSupportedCaps(capsStr string) { func (uc *upstreamConn) requestCaps() { var requestCaps []string - for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time", "away-notify"} { + for c := range permanentUpstreamCaps { if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { requestCaps = append(requestCaps, c) } @@ -1219,12 +1229,72 @@ func (uc *upstreamConn) requestCaps() { requestCaps = append(requestCaps, "sasl") } - if len(requestCaps) > 0 { - uc.SendMessage(&irc.Message{ - Command: "CAP", - Params: []string{"REQ", strings.Join(requestCaps, " ")}, - }) + if len(requestCaps) == 0 { + return } + + uc.SendMessage(&irc.Message{ + Command: "CAP", + Params: []string{"REQ", strings.Join(requestCaps, " ")}, + }) +} + +func (uc *upstreamConn) requestSASL() bool { + if uc.network.SASL.Mechanism == "" { + return false + } + + v, ok := uc.supportedCaps["sasl"] + if !ok { + return false + } + if v != "" { + mechanisms := strings.Split(v, ",") + found := false + for _, mech := range mechanisms { + if strings.EqualFold(mech, uc.network.SASL.Mechanism) { + found = true + break + } + } + if !found { + return false + } + } + + return true +} + +func (uc *upstreamConn) handleCapAck(name string, ok bool) error { + uc.caps[name] = ok + + switch name { + case "sasl": + if !ok { + uc.logger.Printf("server refused to acknowledge the SASL capability") + return nil + } + + auth := &uc.network.SASL + switch auth.Mechanism { + case "PLAIN": + uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) + uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) + default: + return fmt.Errorf("unsupported SASL mechanism %q", name) + } + + uc.SendMessage(&irc.Message{ + Command: "AUTHENTICATE", + Params: []string{auth.Mechanism}, + }) + default: + if permanentUpstreamCaps[name] { + break + } + uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) + } + return nil } func splitSpace(s string) []string { @@ -1290,63 +1360,6 @@ func (uc *upstreamConn) runUntilRegistered() error { return nil } -func (uc *upstreamConn) requestSASL() bool { - if uc.network.SASL.Mechanism == "" { - return false - } - - v, ok := uc.supportedCaps["sasl"] - if !ok { - return false - } - if v != "" { - mechanisms := strings.Split(v, ",") - found := false - for _, mech := range mechanisms { - if strings.EqualFold(mech, uc.network.SASL.Mechanism) { - found = true - break - } - } - if !found { - return false - } - } - - return true -} - -func (uc *upstreamConn) handleCapAck(name string, ok bool) error { - uc.caps[name] = ok - - switch name { - case "sasl": - if !ok { - uc.logger.Printf("server refused to acknowledge the SASL capability") - return nil - } - - auth := &uc.network.SASL - switch auth.Mechanism { - case "PLAIN": - uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) - uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) - default: - return fmt.Errorf("unsupported SASL mechanism %q", name) - } - - uc.SendMessage(&irc.Message{ - Command: "AUTHENTICATE", - Params: []string{auth.Mechanism}, - }) - case "message-tags", "labeled-response", "away-notify", "batch", "server-time": - // Nothing to do - default: - uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) - } - return nil -} - func (uc *upstreamConn) readMessages(ch chan<- event) error { for { msg, err := uc.ReadMessage()