Introduce UserUpdateFunc

References: https://todo.sr.ht/~emersion/soju/206
This commit is contained in:
Simon Ser 2023-03-01 14:16:33 +01:00
parent 67335130b1
commit aecff32103
3 changed files with 56 additions and 49 deletions

View file

@ -1773,9 +1773,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
record.Nick = nick
err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record)
} else {
record := dc.user.User
record.Nick = nick
err = dc.user.updateUser(ctx, &record)
err = dc.user.updateUser(ctx, func(record *database.User) error {
record.Nick = nick
return nil
})
}
if err != nil {
dc.logger.Printf("failed to update nick: %v", err)
@ -1840,9 +1841,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
_, err = dc.user.updateNetwork(ctx, &record)
}
} else {
record := dc.user.User
record.Realname = realname
err = dc.user.updateUser(ctx, &record)
err = dc.user.updateUser(ctx, func(record *database.User) error {
record.Realname = realname
return nil
})
}
if err != nil {

View file

@ -1066,23 +1066,6 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
ctx.print(fmt.Sprintf("updated user %q", username))
} else {
// copy the user record because we'll mutate it
record := ctx.user.User
if password != nil {
if err := record.SetPassword(*password); err != nil {
return err
}
}
if disablePassword {
record.Password = ""
}
if nick != nil {
record.Nick = *nick
}
if realname != nil {
record.Realname = *realname
}
if admin != nil {
return fmt.Errorf("cannot update -admin of own user")
}
@ -1090,7 +1073,24 @@ func handleUserUpdate(ctx *serviceContext, params []string) error {
return fmt.Errorf("cannot update -enabled of own user")
}
if err := ctx.user.updateUser(ctx, &record); err != nil {
err := ctx.user.updateUser(ctx, func(record *database.User) error {
if password != nil {
if err := record.SetPassword(*password); err != nil {
return err
}
}
if disablePassword {
record.Password = ""
}
if nick != nil {
record.Nick = *nick
}
if realname != nil {
record.Realname = *realname
}
return nil
})
if err != nil {
return err
}

55
user.go
View file

@ -23,6 +23,8 @@ import (
"git.sr.ht/~emersion/soju/msgstore"
)
type UserUpdateFunc func(record *database.User) error
type event interface{}
type eventUpstreamMessage struct {
@ -702,9 +704,11 @@ func (u *user) run() {
}
if !u.Enabled && u.srv.Config().EnableUsersOnAuth {
record := u.User
record.Enabled = true
if err := u.updateUser(ctx, &record); err != nil {
err := u.updateUser(ctx, func(record *database.User) error {
record.Enabled = true
return nil
})
if err != nil {
dc.logger.Printf("failed to enable user after successful authentication: %v", err)
}
}
@ -791,20 +795,18 @@ func (u *user) run() {
dc.SendMessage(msg)
}
case eventUserUpdate:
// copy the user record because we'll mutate it
record := u.User
if e.password != nil {
record.Password = *e.password
}
if e.admin != nil {
record.Admin = *e.admin
}
if e.enabled != nil {
record.Enabled = *e.enabled
}
e.done <- u.updateUser(context.TODO(), &record)
e.done <- u.updateUser(context.TODO(), func(record *database.User) error {
if e.password != nil {
record.Password = *e.password
}
if e.admin != nil {
record.Admin = *e.admin
}
if e.enabled != nil {
record.Enabled = *e.enabled
}
return nil
})
// If the password was updated, kill all downstream connections to
// force them to re-authenticate with the new credentials.
@ -1110,18 +1112,19 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
return nil
}
func (u *user) updateUser(ctx context.Context, record *database.User) error {
if u.ID != record.ID {
panic("ID mismatch when updating user")
func (u *user) updateUser(ctx context.Context, update UserUpdateFunc) error {
record := u.User // copy
if err := update(&record); err != nil {
return err
}
nickUpdated := u.Nick != record.Nick
realnameUpdated := u.Realname != record.Realname
enabledUpdated := u.Enabled != record.Enabled
if err := u.srv.db.StoreUser(ctx, record); err != nil {
if err := u.srv.db.StoreUser(ctx, &record); err != nil {
return fmt.Errorf("failed to update user %q: %v", u.Username, err)
}
u.User = *record
u.User = record
if nickUpdated {
for _, net := range u.networks {
@ -1264,9 +1267,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd
}
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
record := u.User
record.DownstreamInteractedAt = time.Now()
if err := u.updateUser(ctx, &record); err != nil {
err := u.updateUser(ctx, func(record *database.User) error {
record.DownstreamInteractedAt = time.Now()
return nil
})
if err != nil {
u.logger.Printf("failed to bump downstream interaction time: %v", err)
}
}