diff --git a/service.go b/service.go index 6ba7942..0947e44 100644 --- a/service.go +++ b/service.go @@ -775,19 +775,22 @@ func handleUserUpdate(dc *downstreamConn, params []string) error { return err } + // copy the user record because we'll mutate it + record := dc.user.User + if password != nil { hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("failed to hash password: %v", err) } - if err := dc.user.updatePassword(string(hashed)); err != nil { - return err - } + record.Password = string(hashed) } if realname != nil { - if err := dc.user.updateRealname(*realname); err != nil { - return err - } + record.Realname = *realname + } + + if err := dc.user.updateUser(&record); err != nil { + return err } sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username)) diff --git a/user.go b/user.go index 46aa140..f220490 100644 --- a/user.go +++ b/user.go @@ -856,33 +856,38 @@ func (u *user) deleteNetwork(id int64) error { return nil } -func (u *user) updatePassword(hashed string) error { - u.User.Password = hashed - return u.srv.db.StoreUser(&u.User) -} +func (u *user) updateUser(record *User) error { + if u.ID != record.ID { + panic("ID mismatch when updating user") + } -func (u *user) updateRealname(realname string) error { - u.User.Realname = realname - if err := u.srv.db.StoreUser(&u.User); err != nil { + realnameUpdated := u.Realname != record.Realname + if err := u.srv.db.StoreUser(record); err != nil { return fmt.Errorf("failed to update user %q: %v", u.Username, err) } + u.User = *record - // Re-connect to networks which use the default realname - var needUpdate []Network - u.forEachNetwork(func(net *network) { - if net.Realname == "" { - needUpdate = append(needUpdate, net.Network) + if realnameUpdated { + // Re-connect to networks which use the default realname + var needUpdate []Network + u.forEachNetwork(func(net *network) { + if net.Realname == "" { + needUpdate = append(needUpdate, net.Network) + } + }) + + var netErr error + for _, net := range needUpdate { + if _, err := u.updateNetwork(&net); err != nil { + netErr = err + } } - }) - - var netErr error - for _, net := range needUpdate { - if _, err := u.updateNetwork(&net); err != nil { - netErr = err + if netErr != nil { + return netErr } } - return netErr + return nil } func (u *user) stop() {