diff --git a/database/database.go b/database/database.go index f98a6d1..b81d646 100644 --- a/database/database.go +++ b/database/database.go @@ -8,6 +8,7 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/crypto/bcrypt" ) type Database interface { @@ -63,6 +64,29 @@ type User struct { Admin bool } +func (u *User) CheckPassword(password string) error { + // Password auth disabled + if u.Password == "" { + return fmt.Errorf("password auth disabled") + } + + err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + if err != nil { + return fmt.Errorf("wrong password: %v", err) + } + + return nil +} + +func (u *User) SetPassword(password string) error { + hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + u.Password = string(hashed) + return nil +} + type SASL struct { Mechanism string diff --git a/downstream.go b/downstream.go index 575c7a2..0819913 100644 --- a/downstream.go +++ b/downstream.go @@ -14,7 +14,6 @@ import ( "time" "github.com/emersion/go-sasl" - "golang.org/x/crypto/bcrypt" "gopkg.in/irc.v3" "git.sr.ht/~emersion/soju/database" @@ -1304,14 +1303,8 @@ func (dc *downstreamConn) authenticate(ctx context.Context, username, password s return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err)) } - // Password auth disabled - if u.Password == "" { - return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled")) - } - - err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) - if err != nil { - return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password")) + if err := u.CheckPassword(password); err != nil { + return newInvalidUsernameOrPasswordError(err) } dc.user = dc.srv.getUser(username) diff --git a/service.go b/service.go index 4acf94d..feba0fa 100644 --- a/service.go +++ b/service.go @@ -830,17 +830,14 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) return fmt.Errorf("flag -password is required") } - hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) - if err != nil { - return fmt.Errorf("failed to hash password: %v", err) - } - user := &database.User{ Username: *username, - Password: string(hashed), Realname: *realname, Admin: *admin, } + if err := user.SetPassword(*password); err != nil { + return err + } if _, err := dc.srv.createUser(ctx, user); err != nil { return fmt.Errorf("could not create user: %v", err) } @@ -872,16 +869,6 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) return fmt.Errorf("unexpected argument") } - var hashed *string - if password != nil { - hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) - if err != nil { - return fmt.Errorf("failed to hash password: %v", err) - } - hashedStr := string(hashedBytes) - hashed = &hashedStr - } - if username != "" && username != dc.user.Username { if !dc.user.Admin { return fmt.Errorf("you must be an admin to update other users") @@ -890,6 +877,16 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) return fmt.Errorf("cannot update -realname of other user") } + var hashed *string + if password != nil { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + hashedStr := string(hashedBytes) + hashed = &hashedStr + } + u := dc.srv.getUser(username) if u == nil { return fmt.Errorf("unknown username %q", username) @@ -916,8 +913,10 @@ func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) // copy the user record because we'll mutate it record := dc.user.User - if hashed != nil { - record.Password = *hashed + if password != nil { + if err := record.SetPassword(*password); err != nil { + return err + } } if realname != nil { record.Realname = *realname