diff --git a/downstream.go b/downstream.go index 9b2e6b1..543559f 100644 --- a/downstream.go +++ b/downstream.go @@ -2,6 +2,7 @@ package soju import ( "crypto/tls" + "encoding/base64" "fmt" "io" "net" @@ -10,6 +11,7 @@ import ( "sync" "time" + "github.com/emersion/go-sasl" "golang.org/x/crypto/bcrypt" "gopkg.in/irc.v3" ) @@ -76,6 +78,8 @@ type downstreamConn struct { capVersion int caps map[string]bool + saslServer sasl.Server + lock sync.Mutex ourMessages map[*irc.Message]struct{} } @@ -342,6 +346,101 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { return err } + case "AUTHENTICATE": + if !dc.caps["sasl"] { + return ircError{&irc.Message{ + Command: err_saslfail, + Params: []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, + }} + } + if len(msg.Params) == 0 { + return ircError{&irc.Message{ + Command: err_saslfail, + Params: []string{"*", "Missing AUTHENTICATE argument"}, + }} + } + if dc.nick == "" { + return ircError{&irc.Message{ + Command: err_saslfail, + Params: []string{"*", "Expected NICK command before AUTHENTICATE"}, + }} + } + + var resp []byte + if dc.saslServer == nil { + mech := strings.ToUpper(msg.Params[0]) + switch mech { + case "PLAIN": + dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { + return dc.authenticate(username, password) + })) + default: + return ircError{&irc.Message{ + Command: err_saslfail, + Params: []string{"*", fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, + }} + } + } else if msg.Params[0] == "*" { + dc.saslServer = nil + return ircError{&irc.Message{ + Command: err_saslaborted, + Params: []string{"*", "SASL authentication aborted"}, + }} + } else if msg.Params[0] == "+" { + resp = nil + } else { + // TODO: multi-line messages + var err error + resp, err = base64.StdEncoding.DecodeString(msg.Params[0]) + if err != nil { + dc.saslServer = nil + return ircError{&irc.Message{ + Command: err_saslfail, + Params: []string{"*", "Invalid base64-encoded response"}, + }} + } + } + + challenge, done, err := dc.saslServer.Next(resp) + if err != nil { + dc.saslServer = nil + if ircErr, ok := err.(ircError); ok && ircErr.Message.Command == irc.ERR_PASSWDMISMATCH { + return ircError{&irc.Message{ + Command: err_saslfail, + Params: []string{"*", ircErr.Message.Params[1]}, + }} + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: err_saslfail, + Params: []string{"*", "SASL error"}, + }) + return fmt.Errorf("SASL authentication failed: %v", err) + } else if done { + dc.saslServer = nil + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_loggedin, + Params: []string{dc.nick, dc.nick, dc.user.Username, "You are now logged in"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_saslsuccess, + Params: []string{dc.nick, "SASL authentication successful"}, + }) + } else { + challengeStr := "+" + if challenge != nil { + challengeStr = base64.StdEncoding.EncodeToString(challenge) + } + + // TODO: multi-line messages + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "AUTHENTICATE", + Params: []string{challengeStr}, + }) + } default: dc.logger.Printf("unhandled message: %v", msg) return newUnknownCommandError(msg.Command) @@ -370,11 +469,11 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { } var caps []string - /*if dc.capVersion >= 302 { + if dc.capVersion >= 302 { caps = append(caps, "sasl=PLAIN") } else { caps = append(caps, "sasl") - }*/ + } // TODO: multi-line replies dc.SendMessage(&irc.Message{ @@ -421,8 +520,8 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { } switch name { - /*case "sasl": - dc.caps[name] = enable*/ + case "sasl": + dc.caps[name] = enable default: ack = false } @@ -457,19 +556,52 @@ func sanityCheckServer(addr string) error { return conn.Close() } -func (dc *downstreamConn) register() error { - username := dc.rawUsername - var networkName string +func unmarshalUsername(rawUsername string) (username, network string) { + username = rawUsername if i := strings.LastIndexAny(username, "/@"); i >= 0 { - networkName = username[i+1:] + network = username[i+1:] } if i := strings.IndexAny(username, "/@"); i >= 0 { username = username[:i] } - dc.username = "~" + username + return username, network +} - password := dc.password - dc.password = "" +func (dc *downstreamConn) setNetwork(networkName string) error { + if networkName == "" { + return nil + } + + network := dc.user.getNetwork(networkName) + if network == nil { + addr := networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)}, + }} + } + + dc.logger.Printf("auto-saving network %q", networkName) + var err error + network, err = dc.user.createNetwork(networkName, dc.nick) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) authenticate(username, password string) error { + username, networkName := unmarshalUsername(username) u := dc.srv.getUser(username) if u == nil { @@ -483,40 +615,32 @@ func (dc *downstreamConn) register() error { return errAuthFailed } - var network *network - if networkName != "" { - network = u.getNetwork(networkName) - if network == nil { - addr := networkName - if !strings.ContainsRune(addr, ':') { - addr = addr + ":6697" - } + dc.user = u - dc.logger.Printf("trying to connect to new network %q", addr) - if err := sanityCheckServer(addr); err != nil { - dc.logger.Printf("failed to connect to %q: %v", addr, err) - return ircError{&irc.Message{ - Command: irc.ERR_PASSWDMISMATCH, - Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)}, - }} - } + return dc.setNetwork(networkName) +} - dc.logger.Printf("auto-saving network %q", networkName) - network, err = u.createNetwork(networkName, dc.nick) - if err != nil { - return err - } +func (dc *downstreamConn) register() error { + password := dc.password + dc.password = "" + if dc.user == nil { + if err := dc.authenticate(dc.rawUsername, password); err != nil { + return err + } + } else if dc.network == nil { + _, networkName := unmarshalUsername(dc.rawUsername) + if err := dc.setNetwork(networkName); err != nil { + return err } } dc.registered = true - dc.user = u - dc.network = network + dc.username = dc.user.Username - u.lock.Lock() - firstDownstream := len(u.downstreamConns) == 0 - u.downstreamConns = append(u.downstreamConns, dc) - u.lock.Unlock() + dc.user.lock.Lock() + firstDownstream := len(dc.user.downstreamConns) == 0 + dc.user.downstreamConns = append(dc.user.downstreamConns, dc) + dc.user.lock.Unlock() dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(),