Stop setting *user in downstreamConn.register

Set it in downstreamConn.welcome instead. Makes it clearer that it
must not be accessed before welcome is called (because it can only
be accessed from the user goroutine).
This commit is contained in:
Simon Ser 2023-04-05 16:54:55 +02:00
parent c5079f7ac3
commit f6043e5b98
3 changed files with 60 additions and 40 deletions

View file

@ -309,6 +309,8 @@ type downstreamRegistration struct {
networkID int64 networkID int64
negotiatingCaps bool negotiatingCaps bool
authUsername string
} }
func serverSASLMechanisms(srv *Server) []string { func serverSASLMechanisms(srv *Server) []string {
@ -686,13 +688,6 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism)) panic(fmt.Errorf("unexpected SASL mechanism %q", credentials.mechanism))
} }
if err == nil {
if username == "" {
panic(fmt.Errorf("username unset after SASL authentication"))
}
err = dc.setUser(username, clientName, networkName)
}
if err != nil { if err != nil {
dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err) dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err)
dc.endSASL(&irc.Message{ dc.endSASL(&irc.Message{
@ -703,6 +698,11 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
break break
} }
if username == "" {
panic(fmt.Errorf("username unset after SASL authentication"))
}
dc.setAuthUsername(username, clientName, networkName)
// Technically we should send RPL_LOGGEDIN here. However we use // Technically we should send RPL_LOGGEDIN here. However we use
// RPL_LOGGEDIN to mirror the upstream connection status. Let's // RPL_LOGGEDIN to mirror the upstream connection status. Let's
// see how many clients that breaks. See: // see how many clients that breaks. See:
@ -721,7 +721,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
return err return err
} }
if dc.user == nil { if dc.registration.authUsername == "" {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
@ -1247,28 +1247,10 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
return username, client, network return username, client, network
} }
func (dc *downstreamConn) setUser(username, clientName, networkName string) error { func (dc *downstreamConn) setAuthUsername(username, clientName, networkName string) {
dc.user = dc.srv.getUser(username)
if dc.user == nil && dc.srv.Config().EnableUsersOnAuth {
ctx := context.TODO()
if _, err := dc.srv.db.GetUser(ctx, username); err != nil {
// Can't find the user in the DB -- try to create it
record := database.User{
Username: username,
Enabled: true,
}
dc.user, err = dc.srv.createUser(ctx, &record)
if err != nil {
return fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
}
}
}
if dc.user == nil {
return fmt.Errorf("user exists in the DB but hasn't been loaded by the bouncer -- a restart may help")
}
dc.clientName = clientName dc.clientName = clientName
dc.registration.authUsername = username
dc.registration.networkName = networkName dc.registration.networkName = networkName
return nil
} }
func (dc *downstreamConn) register(ctx context.Context) error { func (dc *downstreamConn) register(ctx context.Context) error {
@ -1286,7 +1268,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
password := dc.registration.password password := dc.registration.password
dc.registration.password = "" dc.registration.password = ""
if dc.user == nil { if dc.registration.authUsername == "" {
if password == "" { if password == "" {
if dc.caps.IsEnabled("sasl") { if dc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{ return ircError{&irc.Message{
@ -1318,9 +1300,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
}} }}
} }
if err := dc.setUser(username, clientName, networkName); err != nil { dc.setAuthUsername(username, clientName, networkName)
return err
}
} }
_, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username) _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.registration.username)
@ -1343,8 +1323,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
} }
dc.registered = true dc.registered = true
dc.username = dc.user.Username dc.username = dc.registration.authUsername
dc.logger.Printf("registration complete for user %q", dc.user.Username) dc.logger.Printf("registration complete for user %q", dc.username)
return nil return nil
} }
@ -1421,10 +1401,15 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
return nil return nil
} }
func (dc *downstreamConn) welcome(ctx context.Context) error { func (dc *downstreamConn) welcome(ctx context.Context, user *user) error {
if dc.user == nil || !dc.registered { if !dc.registered {
panic("tried to welcome an unregistered connection") panic("tried to welcome an unregistered connection")
} }
if dc.user != nil {
panic("tried to welcome the same connection twice")
}
dc.user = user
remoteAddr := dc.conn.RemoteAddr().String() remoteAddr := dc.conn.RemoteAddr().String()
dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}

View file

@ -476,11 +476,46 @@ func (s *Server) Handle(ic ircConn) {
return return
} }
dc.user.events <- eventDownstreamConnected{dc} user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername)
if err := dc.readMessages(dc.user.events); err != nil { if err != nil {
dc.SendMessage(&irc.Message{
Command: "ERROR",
Params: []string{"Internal server error"},
})
return
}
user.events <- eventDownstreamConnected{dc}
if err := dc.readMessages(user.events); err != nil {
dc.logger.Printf("%v", err) dc.logger.Printf("%v", err)
} }
dc.user.events <- eventDownstreamDisconnected{dc} user.events <- eventDownstreamDisconnected{dc}
}
func (s *Server) getOrCreateUser(ctx context.Context, username string) (*user, error) {
user := s.getUser(username)
if user != nil {
return user, nil
}
if _, err := s.db.GetUser(ctx, username); err == nil {
return nil, fmt.Errorf("user %q exists in the DB but hasn't been loaded by the bouncer -- a restart may help", username)
}
if !s.Config().EnableUsersOnAuth {
return nil, fmt.Errorf("cannot find user %q in the DB", username)
}
// Can't find the user in the DB -- try to create it
record := database.User{
Username: username,
Enabled: true,
}
user, err := s.createUser(ctx, &record)
if err != nil {
return nil, fmt.Errorf("failed to automatically create user %q after successful authentication: %v", username, err)
}
return user, nil
} }
func (s *Server) HandleAdmin(ic ircConn) { func (s *Server) HandleAdmin(ic ircConn) {

View file

@ -737,7 +737,7 @@ func (u *user) run() {
break break
} }
if err := dc.welcome(ctx); err != nil { if err := dc.welcome(ctx, u); err != nil {
if ircErr, ok := err.(ircError); ok { if ircErr, ok := err.(ircError); ok {
msg := ircErr.Message.Copy() msg := ircErr.Message.Copy()
msg.Prefix = dc.srv.prefix() msg.Prefix = dc.srv.prefix()