diff --git a/auth/auth.go b/auth/auth.go index 2ccf3d9..86137cd 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -7,11 +7,17 @@ import ( "git.sr.ht/~emersion/soju/database" ) +type Authenticator interface{} + type PlainAuthenticator interface { AuthPlain(ctx context.Context, db database.Database, username, password string) error } -func New(driver, source string) (PlainAuthenticator, error) { +type OAuthBearerAuthenticator interface { + AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error) +} + +func New(driver, source string) (Authenticator, error) { switch driver { case "internal": return NewInternal(), nil diff --git a/auth/oauth2.go b/auth/oauth2.go index 29722b2..4485ebe 100644 --- a/auth/oauth2.go +++ b/auth/oauth2.go @@ -19,7 +19,12 @@ type oauth2 struct { clientSecret string } -func newOAuth2(authURL string) (PlainAuthenticator, error) { +var ( + _ PlainAuthenticator = (*oauth2)(nil) + _ OAuthBearerAuthenticator = (*oauth2)(nil) +) + +func newOAuth2(authURL string) (Authenticator, error) { ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) defer cancel() @@ -77,14 +82,27 @@ func newOAuth2(authURL string) (PlainAuthenticator, error) { } func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, username, password string) error { + effectiveUsername, err := auth.AuthOAuthBearer(ctx, db, password) + if err != nil { + return err + } + + if username != effectiveUsername { + return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", effectiveUsername) + } + + return nil +} + +func (auth *oauth2) AuthOAuthBearer(ctx context.Context, db database.Database, token string) (username string, err error) { reqValues := make(url.Values) - reqValues.Set("token", password) + reqValues.Set("token", token) reqBody := strings.NewReader(reqValues.Encode()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, auth.introspectionURL.String(), reqBody) if err != nil { - return fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err) + return "", fmt.Errorf("failed to create OAuth 2.0 introspection request: %v", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") @@ -95,32 +113,29 @@ func (auth *oauth2) AuthPlain(ctx context.Context, db database.Database, usernam resp, err := http.DefaultClient.Do(req) if err != nil { - return fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err) + return "", fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status) + return "", fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status) } var data oauth2Introspection if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { - return fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err) + return "", fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err) } if !data.Active { - return fmt.Errorf("invalid access token") + return "", fmt.Errorf("invalid access token") } if data.Username == "" { // We really need the username here, otherwise an OAuth 2.0 user can // impersonate any other user. - return fmt.Errorf("missing username in OAuth 2.0 introspection response") - } - if username != data.Username { - return fmt.Errorf("username mismatch (OAuth 2.0 server returned %q)", data.Username) + return "", fmt.Errorf("missing username in OAuth 2.0 introspection response") } - return nil + return data.Username, nil } type oauth2Introspection struct { diff --git a/downstream.go b/downstream.go index 579b056..d6983c6 100644 --- a/downstream.go +++ b/downstream.go @@ -17,6 +17,7 @@ import ( "github.com/emersion/go-sasl" "gopkg.in/irc.v4" + "git.sr.ht/~emersion/soju/auth" "git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/msgstore" "git.sr.ht/~emersion/soju/xirc" @@ -310,10 +311,16 @@ var passthroughIsupport = map[string]bool{ "WHOX": true, } +type saslPlain struct { + Username, Password string +} + type downstreamSASL struct { - server sasl.Server - plainUsername, plainPassword string - pendingResp bytes.Buffer + server sasl.Server + mechanism string + plain *saslPlain + oauthBearer *sasl.OAuthBearerOptions + pendingResp bytes.Buffer } type downstreamRegistration struct { @@ -327,6 +334,17 @@ type downstreamRegistration struct { negotiatingCaps bool } +func serverSASLMechanisms(srv *Server) []string { + var l []string + if _, ok := srv.Config().Auth.(auth.PlainAuthenticator); ok { + l = append(l, "PLAIN") + } + if _, ok := srv.Config().Auth.(auth.OAuthBearerAuthenticator); ok { + l = append(l, "OAUTHBEARER") + } + return l +} + type downstreamConn struct { conn @@ -379,7 +397,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { for k, v := range permanentDownstreamCaps { dc.caps.Available[k] = v } - dc.caps.Available["sasl"] = "PLAIN" + dc.caps.Available["sasl"] = strings.Join(serverSASLMechanisms(dc.srv), ",") // TODO: this is racy, we should only enable chathistory after // authentication and then check that user.msgStore implements // chatHistoryMessageStore @@ -659,8 +677,52 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir break } - if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { - dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err) + var username, clientName, networkName string + switch credentials.mechanism { + case "PLAIN": + username, clientName, networkName = unmarshalUsername(credentials.plain.Username) + password := credentials.plain.Password + + auth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator) + if !ok { + err = fmt.Errorf("SASL PLAIN not supported") + break + } + + if authErr := auth.AuthPlain(ctx, dc.srv.db, username, password); authErr != nil { + err = newInvalidUsernameOrPasswordError(authErr) + break + } + case "OAUTHBEARER": + auth, ok := dc.srv.Config().Auth.(auth.OAuthBearerAuthenticator) + if !ok { + err = fmt.Errorf("SASL OAUTHBEARER not supported") + break + } + + var authErr error + username, authErr = auth.AuthOAuthBearer(ctx, dc.srv.db, credentials.oauthBearer.Token) + if authErr != nil { + err = newInvalidUsernameOrPasswordError(authErr) + break + } + + if credentials.oauthBearer.Username != "" && credentials.oauthBearer.Username != username { + err = newInvalidUsernameOrPasswordError(fmt.Errorf("username mismatch (server returned %q)", username)) + } + default: + panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism)) + } + + if err == nil { + if username == "" { + panic(fmt.Errorf("username unset after SASL authentication")) + } + err = dc.setUser(username, clientName, networkName) + } + + if err != nil { + dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err) dc.endSASL(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, @@ -878,8 +940,15 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d switch mech { case "PLAIN": server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { - dc.sasl.plainUsername = username - dc.sasl.plainPassword = password + dc.sasl.plain = &saslPlain{ + Username: username, + Password: password, + } + return nil + })) + case "OAUTHBEARER": + server = sasl.NewOAuthBearerServer(sasl.OAuthBearerAuthenticator(func(options sasl.OAuthBearerOptions) *sasl.OAuthBearerError { + dc.sasl.oauthBearer = &options return nil })) default: @@ -890,7 +959,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d }} } - dc.sasl = &downstreamSASL{server: server} + dc.sasl = &downstreamSASL{server: server, mechanism: mech} } else { chunk := msg.Params[0] if chunk == "+" { @@ -1189,13 +1258,7 @@ func unmarshalUsername(rawUsername string) (username, client, network string) { return username, client, network } -func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { - username, clientName, networkName := unmarshalUsername(username) - - if err := dc.srv.Config().Auth.AuthPlain(ctx, dc.srv.db, username, password); err != nil { - return newInvalidUsernameOrPasswordError(err) - } - +func (dc *downstreamConn) setUser(username, clientName, networkName string) error { dc.user = dc.srv.getUser(username) if dc.user == nil { return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help") @@ -1205,6 +1268,21 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s return nil } +func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { + username, clientName, networkName := unmarshalUsername(username) + + plainAuth, ok := dc.srv.Config().Auth.(auth.PlainAuthenticator) + if !ok { + return fmt.Errorf("PLAIN authentication unsupported") + } + + if err := plainAuth.AuthPlain(ctx, dc.srv.db, username, password); err != nil { + return newInvalidUsernameOrPasswordError(err) + } + + return dc.setUser(username, clientName, networkName) +} + func (dc *downstreamConn) register(ctx context.Context) error { if dc.registered { panic("tried to register twice") @@ -2420,6 +2498,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if credentials != nil { + if credentials.mechanism != "PLAIN" { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Unsupported SASL authentication mechanism"}, + }) + return nil + } + if uc.saslClient != nil { dc.endSASL(&irc.Message{ Prefix: dc.srv.prefix(), @@ -2429,8 +2516,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } - uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) - uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) + uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plain.Username) + uc.saslClient = sasl.NewPlainClient("", credentials.plain.Username, credentials.plain.Password) uc.enqueueCommand(dc, &irc.Message{ Command: "AUTHENTICATE", Params: []string{"PLAIN"}, diff --git a/server.go b/server.go index 03b0489..3fac69b 100644 --- a/server.go +++ b/server.go @@ -144,7 +144,7 @@ type Config struct { UpstreamUserIPs []*net.IPNet DisableInactiveUsersDelay time.Duration EnableUsersOnAuth bool - Auth auth.PlainAuthenticator + Auth auth.Authenticator } type Server struct { diff --git a/upstream.go b/upstream.go index 19e024a..4cb3540 100644 --- a/upstream.go +++ b/upstream.go @@ -862,7 +862,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { if msg.Command == irc.RPL_SASLSUCCESS { - uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword) + uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password) } dc.endSASL(msg)