diff --git a/downstream.go b/downstream.go index 50b1769..7c0a690 100644 --- a/downstream.go +++ b/downstream.go @@ -309,6 +309,8 @@ type downstreamRegistration struct { networkID int64 negotiatingCaps bool + + authUsername string } func serverSASLMechanisms(srv *Server) []string { @@ -686,13 +688,6 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism)) } - if err == nil { - if username == "" { - panic(fmt.Errorf("username unset after SASL authentication")) - } - err = dc.setUser(username, clientName, networkName) - } - if err != nil { dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err) dc.endSASL(&irc.Message{ @@ -703,6 +698,11 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir break } + if username == "" { + panic(fmt.Errorf("username unset after SASL authentication")) + } + dc.setAuthUsername(username, clientName, networkName) + // Technically we should send RPL_LOGGEDIN here. However we use // RPL_LOGGEDIN to mirror the upstream connection status. Let's // see how many clients that breaks. See: @@ -721,7 +721,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir return err } - if dc.user == nil { + if dc.registration.authUsername == "" { return ircError{&irc.Message{ Command: "FAIL", Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, @@ -1247,28 +1247,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) { return username, client, network } -func (dc *downstreamConn) setUser(username, clientName, networkName string) error { - dc.user = dc.srv.getUser(username) - if dc.user == nil && dc.srv.Config().EnableUsersOnAuth { - ctx := context.TODO() - if _, err := dc.srv.db.GetUser(ctx, username); err != nil { - // Can't find the user in the DB -- try to create it - record := database.User{ - Username: username, - Enabled: true, - } - dc.user, err = dc.srv.createUser(ctx, &record) - if err != nil { - return fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err) - } - } - } - if dc.user == nil { - return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help") - } +func (dc *downstreamConn) setAuthUsername(username, clientName, networkName string) { dc.clientName = clientName + dc.registration.authUsername = username dc.registration.networkName = networkName - return nil } func (dc *downstreamConn) register(ctx context.Context) error { @@ -1286,7 +1268,7 @@ func (dc *downstreamConn) register(ctx context.Context) error { password := dc.registration.password dc.registration.password = "" - if dc.user == nil { + if dc.registration.authUsername == "" { if password == "" { if dc.caps.IsEnabled("sasl") { return ircError{&irc.Message{ @@ -1318,9 +1300,7 @@ func (dc *downstreamConn) register(ctx context.Context) error { }} } - if err := dc.setUser(username, clientName, networkName); err != nil { - return err - } + dc.setAuthUsername(username, clientName, networkName) } _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username) @@ -1343,8 +1323,8 @@ func (dc *downstreamConn) register(ctx context.Context) error { } dc.registered = true - dc.username = dc.user.Username - dc.logger.Printf("registration complete for user %q", dc.user.Username) + dc.username = dc.registration.authUsername + dc.logger.Printf("registration complete for user %q", dc.username) return nil } @@ -1421,10 +1401,15 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error { return nil } -func (dc *downstreamConn) welcome(ctx context.Context) error { - if dc.user == nil || !dc.registered { +func (dc *downstreamConn) welcome(ctx context.Context, user *user) error { + if !dc.registered { panic("tried to welcome an unregistered connection") } + if dc.user != nil { + panic("tried to welcome the same connection twice") + } + + dc.user = user remoteAddr := dc.conn.RemoteAddr().String() dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} diff --git a/server.go b/server.go index d26fab1..092493b 100644 --- a/server.go +++ b/server.go @@ -476,11 +476,46 @@ func (s *Server) Handle(ic ircConn) { return } - dc.user.events <- eventDownstreamConnected{dc} - if err := dc.readMessages(dc.user.events); err != nil { + user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername) + if err != nil { + dc.SendMessage(&irc.Message{ + Command: "ERROR", + Params: []string{"Internal server error"}, + }) + return + } + + user.events <- eventDownstreamConnected{dc} + if err := dc.readMessages(user.events); err != nil { dc.logger.Printf("%v", err) } - dc.user.events <- eventDownstreamDisconnected{dc} + user.events <- eventDownstreamDisconnected{dc} +} + +func (s *Server) getOrCreateUser(ctx context.Context, username string) (*user, error) { + user := s.getUser(username) + if user != nil { + return user, nil + } + + if _, err := s.db.GetUser(ctx, username); err == nil { + return nil, fmt.Errorf("user %q exists in the DB but hasn't been loaded by the bouncer -- a restart may help", username) + } + + if !s.Config().EnableUsersOnAuth { + return nil, fmt.Errorf("cannot find user %q in the DB", username) + } + + // Can't find the user in the DB -- try to create it + record := database.User{ + Username: username, + Enabled: true, + } + user, err := s.createUser(ctx, &record) + if err != nil { + return nil, fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err) + } + return user, nil } func (s *Server) HandleAdmin(ic ircConn) { diff --git a/user.go b/user.go index 771a4ad..b9997cf 100644 --- a/user.go +++ b/user.go @@ -737,7 +737,7 @@ func (u *user) run() { break } - if err := dc.welcome(ctx); err != nil { + if err := dc.welcome(ctx, u); err != nil { if ircErr, ok := err.(ircError); ok { msg := ircErr.Message.Copy() msg.Prefix = dc.srv.prefix()