From f6043e5b982cdfd7e156a30776d426162b3a0e55 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 5 Apr 2023 16:54:55 +0200 Subject: [PATCH] Stop setting *user in downstreamConn.register Set it in downstreamConn.welcome instead. Makes it clearer that it must not be accessed before welcome is called (because it can only be accessed from the user goroutine). --- downstream.go | 57 +++++++++++++++++++-------------------------------- server.go | 41 +++++++++++++++++++++++++++++++++--- user.go | 2 +- 3 files changed, 60 insertions(+), 40 deletions(-) diff --git a/downstream.go b/downstream.go index 50b1769..7c0a690 100644 --- a/downstream.go +++ b/downstream.go @@ -309,6 +309,8 @@ type downstreamRegistration struct { networkID int64 negotiatingCaps bool + + authUsername string } func serverSASLMechanisms(srv *Server) []string { @@ -686,13 +688,6 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism)) } - if err == nil { - if username == "" { - panic(fmt.Errorf("username unset after SASL authentication")) - } - err = dc.setUser(username, clientName, networkName) - } - if err != nil { dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err) dc.endSASL(&irc.Message{ @@ -703,6 +698,11 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir break } + if username == "" { + panic(fmt.Errorf("username unset after SASL authentication")) + } + dc.setAuthUsername(username, clientName, networkName) + // Technically we should send RPL_LOGGEDIN here. However we use // RPL_LOGGEDIN to mirror the upstream connection status. Let's // see how many clients that breaks. See: @@ -721,7 +721,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir return err } - if dc.user == nil { + if dc.registration.authUsername == "" { return ircError{&irc.Message{ Command: "FAIL", Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, @@ -1247,28 +1247,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) { return username, client, network } -func (dc *downstreamConn) setUser(username, clientName, networkName string) error { - dc.user = dc.srv.getUser(username) - if dc.user == nil && dc.srv.Config().EnableUsersOnAuth { - ctx := context.TODO() - if _, err := dc.srv.db.GetUser(ctx, username); err != nil { - // Can't find the user in the DB -- try to create it - record := database.User{ - Username: username, - Enabled: true, - } - dc.user, err = dc.srv.createUser(ctx, &record) - if err != nil { - return fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err) - } - } - } - if dc.user == nil { - return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help") - } +func (dc *downstreamConn) setAuthUsername(username, clientName, networkName string) { dc.clientName = clientName + dc.registration.authUsername = username dc.registration.networkName = networkName - return nil } func (dc *downstreamConn) register(ctx context.Context) error { @@ -1286,7 +1268,7 @@ func (dc *downstreamConn) register(ctx context.Context) error { password := dc.registration.password dc.registration.password = "" - if dc.user == nil { + if dc.registration.authUsername == "" { if password == "" { if dc.caps.IsEnabled("sasl") { return ircError{&irc.Message{ @@ -1318,9 +1300,7 @@ func (dc *downstreamConn) register(ctx context.Context) error { }} } - if err := dc.setUser(username, clientName, networkName); err != nil { - return err - } + dc.setAuthUsername(username, clientName, networkName) } _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username) @@ -1343,8 +1323,8 @@ func (dc *downstreamConn) register(ctx context.Context) error { } dc.registered = true - dc.username = dc.user.Username - dc.logger.Printf("registration complete for user %q", dc.user.Username) + dc.username = dc.registration.authUsername + dc.logger.Printf("registration complete for user %q", dc.username) return nil } @@ -1421,10 +1401,15 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error { return nil } -func (dc *downstreamConn) welcome(ctx context.Context) error { - if dc.user == nil || !dc.registered { +func (dc *downstreamConn) welcome(ctx context.Context, user *user) error { + if !dc.registered { panic("tried to welcome an unregistered connection") } + if dc.user != nil { + panic("tried to welcome the same connection twice") + } + + dc.user = user remoteAddr := dc.conn.RemoteAddr().String() dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} diff --git a/server.go b/server.go index d26fab1..092493b 100644 --- a/server.go +++ b/server.go @@ -476,11 +476,46 @@ func (s *Server) Handle(ic ircConn) { return } - dc.user.events <- eventDownstreamConnected{dc} - if err := dc.readMessages(dc.user.events); err != nil { + user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername) + if err != nil { + dc.SendMessage(&irc.Message{ + Command: "ERROR", + Params: []string{"Internal server error"}, + }) + return + } + + user.events <- eventDownstreamConnected{dc} + if err := dc.readMessages(user.events); err != nil { dc.logger.Printf("%v", err) } - dc.user.events <- eventDownstreamDisconnected{dc} + user.events <- eventDownstreamDisconnected{dc} +} + +func (s *Server) getOrCreateUser(ctx context.Context, username string) (*user, error) { + user := s.getUser(username) + if user != nil { + return user, nil + } + + if _, err := s.db.GetUser(ctx, username); err == nil { + return nil, fmt.Errorf("user %q exists in the DB but hasn't been loaded by the bouncer -- a restart may help", username) + } + + if !s.Config().EnableUsersOnAuth { + return nil, fmt.Errorf("cannot find user %q in the DB", username) + } + + // Can't find the user in the DB -- try to create it + record := database.User{ + Username: username, + Enabled: true, + } + user, err := s.createUser(ctx, &record) + if err != nil { + return nil, fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err) + } + return user, nil } func (s *Server) HandleAdmin(ic ircConn) { diff --git a/user.go b/user.go index 771a4ad..b9997cf 100644 --- a/user.go +++ b/user.go @@ -737,7 +737,7 @@ func (u *user) run() { break } - if err := dc.welcome(ctx); err != nil { + if err := dc.welcome(ctx, u); err != nil { if ircErr, ok := err.(ircError); ok { msg := ircErr.Message.Copy() msg.Prefix = dc.srv.prefix()