From aecff32103976fb6c70c501056eee33aaa85505d Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 1 Mar 2023 14:16:33 +0100 Subject: [PATCH] Introduce UserUpdateFunc References: https://todo.sr.ht/~emersion/soju/206 --- downstream.go | 14 +++++++------ service.go | 36 ++++++++++++++++----------------- user.go | 55 ++++++++++++++++++++++++++++----------------------- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/downstream.go b/downstream.go index 515a901..baf5542 100644 --- a/downstream.go +++ b/downstream.go @@ -1773,9 +1773,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. record.Nick = nick err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &record) } else { - record := dc.user.User - record.Nick = nick - err = dc.user.updateUser(ctx, &record) + err = dc.user.updateUser(ctx, func(record *database.User) error { + record.Nick = nick + return nil + }) } if err != nil { dc.logger.Printf("failed to update nick: %v", err) @@ -1840,9 +1841,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. _, err = dc.user.updateNetwork(ctx, &record) } } else { - record := dc.user.User - record.Realname = realname - err = dc.user.updateUser(ctx, &record) + err = dc.user.updateUser(ctx, func(record *database.User) error { + record.Realname = realname + return nil + }) } if err != nil { diff --git a/service.go b/service.go index 45e8e2c..a1f01f1 100644 --- a/service.go +++ b/service.go @@ -1066,23 +1066,6 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { ctx.print(fmt.Sprintf("updated user %q", username)) } else { - // copy the user record because we'll mutate it - record := ctx.user.User - - if password != nil { - if err := record.SetPassword(*password); err != nil { - return err - } - } - if disablePassword { - record.Password = "" - } - if nick != nil { - record.Nick = *nick - } - if realname != nil { - record.Realname = *realname - } if admin != nil { return fmt.Errorf("cannot update -admin of own user") } @@ -1090,7 +1073,24 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { return fmt.Errorf("cannot update -enabled of own user") } - if err := ctx.user.updateUser(ctx, &record); err != nil { + err := ctx.user.updateUser(ctx, func(record *database.User) error { + if password != nil { + if err := record.SetPassword(*password); err != nil { + return err + } + } + if disablePassword { + record.Password = "" + } + if nick != nil { + record.Nick = *nick + } + if realname != nil { + record.Realname = *realname + } + return nil + }) + if err != nil { return err } diff --git a/user.go b/user.go index 5459aab..a619760 100644 --- a/user.go +++ b/user.go @@ -23,6 +23,8 @@ import ( "git.sr.ht/~emersion/soju/msgstore" ) +type UserUpdateFunc func(record *database.User) error + type event interface{} type eventUpstreamMessage struct { @@ -702,9 +704,11 @@ func (u *user) run() { } if !u.Enabled && u.srv.Config().EnableUsersOnAuth { - record := u.User - record.Enabled = true - if err := u.updateUser(ctx, &record); err != nil { + err := u.updateUser(ctx, func(record *database.User) error { + record.Enabled = true + return nil + }) + if err != nil { dc.logger.Printf("failed to enable user after successful authentication: %v", err) } } @@ -791,20 +795,18 @@ func (u *user) run() { dc.SendMessage(msg) } case eventUserUpdate: - // copy the user record because we'll mutate it - record := u.User - - if e.password != nil { - record.Password = *e.password - } - if e.admin != nil { - record.Admin = *e.admin - } - if e.enabled != nil { - record.Enabled = *e.enabled - } - - e.done <- u.updateUser(context.TODO(), &record) + e.done <- u.updateUser(context.TODO(), func(record *database.User) error { + if e.password != nil { + record.Password = *e.password + } + if e.admin != nil { + record.Admin = *e.admin + } + if e.enabled != nil { + record.Enabled = *e.enabled + } + return nil + }) // If the password was updated, kill all downstream connections to // force them to re-authenticate with the new credentials. @@ -1110,18 +1112,19 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error { return nil } -func (u *user) updateUser(ctx context.Context, record *database.User) error { - if u.ID != record.ID { - panic("ID mismatch when updating user") +func (u *user) updateUser(ctx context.Context, update UserUpdateFunc) error { + record := u.User // copy + if err := update(&record); err != nil { + return err } nickUpdated := u.Nick != record.Nick realnameUpdated := u.Realname != record.Realname enabledUpdated := u.Enabled != record.Enabled - if err := u.srv.db.StoreUser(ctx, record); err != nil { + if err := u.srv.db.StoreUser(ctx, &record); err != nil { return fmt.Errorf("failed to update user %q: %v", u.Username, err) } - u.User = *record + u.User = record if nickUpdated { for _, net := range u.networks { @@ -1264,9 +1267,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd } func (u *user) bumpDownstreamInteractionTime(ctx context.Context) { - record := u.User - record.DownstreamInteractedAt = time.Now() - if err := u.updateUser(ctx, &record); err != nil { + err := u.updateUser(ctx, func(record *database.User) error { + record.DownstreamInteractedAt = time.Now() + return nil + }) + if err != nil { u.logger.Printf("failed to bump downstream interaction time: %v", err) } }