From d7d9d45b45ee08451052a244ed69514b3d9862a8 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 26 Jan 2023 18:33:55 +0100 Subject: [PATCH] Add a flag to disable users Add a new flag to disable users. This can be useful to temporarily deactivate an account without erasing data. The user goroutine is kept alive for simplicity's sake. Most of the infrastructure assumes that each user always has a running goroutine. A disabled user's goroutine is responsible for sending back an error to downstream connections, and listening for potential events to re-enable the account. --- cmd/sojuctl/main.go | 1 + contrib/znc-import/main.go | 2 +- database/database.go | 1 + database/postgres.go | 27 +++++++++++-------- database/sqlite.go | 28 +++++++++++++------- doc/soju.1.scd | 8 +++++- server_test.go | 6 ++++- service.go | 9 ++++++- user.go | 53 ++++++++++++++++++++++++++++---------- 9 files changed, 98 insertions(+), 37 deletions(-) diff --git a/cmd/sojuctl/main.go b/cmd/sojuctl/main.go index bde1d86..e727c72 100644 --- a/cmd/sojuctl/main.go +++ b/cmd/sojuctl/main.go @@ -78,6 +78,7 @@ func main() { Username: username, Password: string(hashed), Admin: *admin, + Enabled: true, } if err := db.StoreUser(ctx, &user); err != nil { log.Fatalf("failed to create user: %v", err) diff --git a/contrib/znc-import/main.go b/contrib/znc-import/main.go index 51284d6..b5f4adc 100644 --- a/contrib/znc-import/main.go +++ b/contrib/znc-import/main.go @@ -107,7 +107,7 @@ func main() { log.Printf("user %q: updating existing user", username) } else { // "!!" is an invalid crypt format, thus disables password auth - u = &database.User{Username: username, Password: "!!"} + u = &database.User{Username: username, Password: "!!", Enabled: true} usersCreated++ log.Printf("user %q: creating new user", username) } diff --git a/database/database.go b/database/database.go index 061d529..d26d1a7 100644 --- a/database/database.go +++ b/database/database.go @@ -71,6 +71,7 @@ type User struct { Nick string Realname string Admin bool + Enabled bool } func (u *User) CheckPassword(password string) (upgraded bool, err error) { diff --git a/database/postgres.go b/database/postgres.go index 2387229..f209bad 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -32,7 +32,8 @@ CREATE TABLE "User" ( admin BOOLEAN NOT NULL DEFAULT FALSE, nick VARCHAR(255), realname VARCHAR(255), - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now() + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), + enabled BOOLEAN NOT NULL DEFAULT TRUE ); CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); @@ -169,6 +170,7 @@ var postgresMigrations = []string{ `ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`, `ALTER TABLE "Network" ADD COLUMN certfp TEXT`, `ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`, + `ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`, } type PostgresDB struct { @@ -302,7 +304,8 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { defer cancel() rows, err := db.db.QueryContext(ctx, - `SELECT id, username, password, admin, nick, realname FROM "User"`) + `SELECT id, username, password, admin, nick, realname, enabled + FROM "User"`) if err != nil { return nil, err } @@ -312,7 +315,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { for rows.Next() { var user User var password, nick, realname sql.NullString - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil { + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { return nil, err } user.Password = password.String @@ -335,9 +338,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro var password, nick, realname sql.NullString row := db.db.QueryRowContext(ctx, - `SELECT id, password, admin, nick, realname FROM "User" WHERE username = $1`, + `SELECT id, password, admin, nick, realname, enabled + FROM "User" + WHERE username = $1`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { return nil, err } user.Password = password.String @@ -357,16 +362,16 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { var err error if user.ID == 0 { err = db.db.QueryRowContext(ctx, ` - INSERT INTO "User" (username, password, admin, nick, realname) - VALUES ($1, $2, $3, $4, $5) + INSERT INTO "User" (username, password, admin, nick, realname, enabled) + VALUES ($1, $2, $3, $4, $5, $6) RETURNING id`, - user.Username, password, user.Admin, nick, realname).Scan(&user.ID) + user.Username, password, user.Admin, nick, realname, user.Enabled).Scan(&user.ID) } else { _, err = db.db.ExecContext(ctx, ` UPDATE "User" - SET password = $1, admin = $2, nick = $3, realname = $4 - WHERE id = $5`, - password, user.Admin, nick, realname, user.ID) + SET password = $1, admin = $2, nick = $3, realname = $4, enabled = $5 + WHERE id = $6`, + password, user.Admin, nick, realname, user.Enabled, user.ID) } return err } diff --git a/database/sqlite.go b/database/sqlite.go index a2dbbc7..85a647c 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -62,7 +62,8 @@ CREATE TABLE User ( admin INTEGER NOT NULL DEFAULT 0, realname TEXT, nick TEXT, - created_at TEXT NOT NULL + created_at TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1 ); CREATE TABLE Network ( @@ -289,6 +290,7 @@ var sqliteMigrations = []string{ ALTER TABLE User ADD COLUMN created_at TEXT NOT NULL DEFAULT ''; UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now'); `, + "ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", } type SqliteDB struct { @@ -388,7 +390,8 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { defer cancel() rows, err := db.db.QueryContext(ctx, - "SELECT id, username, password, admin, nick, realname FROM User") + `SELECT id, username, password, admin, nick, realname, enabled + FROM User`) if err != nil { return nil, err } @@ -398,7 +401,7 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { for rows.Next() { var user User var password, nick, realname sql.NullString - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname); err != nil { + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { return nil, err } user.Password = password.String @@ -421,9 +424,11 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) var password, nick, realname sql.NullString row := db.db.QueryRowContext(ctx, - "SELECT id, password, admin, nick, realname FROM User WHERE username = ?", + `SELECT id, password, admin, nick, realname, enabled + FROM User + WHERE username = ?`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { return nil, err } user.Password = password.String @@ -442,21 +447,26 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { sql.Named("admin", user.Admin), sql.Named("nick", toNullString(user.Nick)), sql.Named("realname", toNullString(user.Realname)), + sql.Named("enabled", user.Enabled), sql.Named("now", sqliteTime{time.Now()}), } var err error if user.ID != 0 { _, err = db.db.ExecContext(ctx, ` - UPDATE User SET password = :password, admin = :admin, nick = :nick, - realname = :realname WHERE username = :username`, + UPDATE User + SET password = :password, admin = :admin, nick = :nick, + realname = :realname, enabled = :enabled + WHERE username = :username`, args...) } else { var res sql.Result res, err = db.db.ExecContext(ctx, ` INSERT INTO - User(username, password, admin, nick, realname, created_at) - VALUES (:username, :password, :admin, :nick, :realname, :now)`, + User(username, password, admin, nick, realname, created_at, + enabled) + VALUES (:username, :password, :admin, :nick, :realname, :now, + :enabled)`, args...) if err != nil { return err diff --git a/doc/soju.1.scd b/doc/soju.1.scd index 94c0e1e..add5527 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -434,6 +434,11 @@ character. Set the user's realname. This is used as a fallback if there is no realname set for a network. + *-enabled* true|false + Enable or disable the user. If the user is disabled, the bouncer will + not connect to any of their networks, and downstream connections will + be immediately closed. By default, users are enabled. + *user update* [username] [options...] Update a user. The options are the same as the _user create_ command. @@ -445,7 +450,8 @@ character. - The _-username_ flag is never valid, usernames are immutable. - The _-nick_ and _-realname_ flag are only valid when updating the current user. - - The _-admin_ flag is only valid when updating another user. + - The _-admin_ and _-enabled_ flags are only valid when updating another + user. *user delete* [confirmation token] Delete a soju user. diff --git a/server_test.go b/server_test.go index d92eab8..f2d17d0 100644 --- a/server_test.go +++ b/server_test.go @@ -51,7 +51,11 @@ func createTestUser(t *testing.T, db database.Database) *database.User { t.Fatalf("failed to generate bcrypt hash: %v", err) } - record := &database.User{Username: testUsername, Password: string(hashed)} + record := &database.User{ + Username: testUsername, + Password: string(hashed), + Enabled: true, + } if err := db.StoreUser(context.Background(), record); err != nil { t.Fatalf("failed to store test user: %v", err) } diff --git a/service.go b/service.go index 966a02b..115638e 100644 --- a/service.go +++ b/service.go @@ -920,6 +920,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error { nick := fs.String("nick", "", "") realname := fs.String("realname", "", "") admin := fs.Bool("admin", false, "") + enabled := fs.Bool("enabled", true, "") if err := fs.Parse(params); err != nil { return err @@ -939,6 +940,7 @@ func handleUserCreate(ctx *serviceContext, params []string) error { Nick: *nick, Realname: *realname, Admin: *admin, + Enabled: *enabled, } if err := user.SetPassword(*password); err != nil { return err @@ -960,12 +962,13 @@ func popArg(params []string) (string, []string) { func handleUserUpdate(ctx *serviceContext, params []string) error { var password, nick, realname *string - var admin *bool + var admin, enabled *bool fs := newFlagSet() fs.Var(stringPtrFlag{&password}, "password", "") fs.Var(stringPtrFlag{&nick}, "nick", "") fs.Var(stringPtrFlag{&realname}, "realname", "") fs.Var(boolPtrFlag{&admin}, "admin", "") + fs.Var(boolPtrFlag{&enabled}, "enabled", "") username, params := popArg(params) if err := fs.Parse(params); err != nil { @@ -1005,6 +1008,7 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { event := eventUserUpdate{ password: hashed, admin: admin, + enabled: enabled, done: done, } select { @@ -1036,6 +1040,9 @@ func handleUserUpdate(ctx *serviceContext, params []string) error { if admin != nil { return fmt.Errorf("cannot update -admin of own user") } + if enabled != nil { + return fmt.Errorf("cannot update -enabled of own user") + } if err := ctx.user.updateUser(ctx, &record); err != nil { return err diff --git a/user.go b/user.go index 9cd699d..2b651d7 100644 --- a/user.go +++ b/user.go @@ -74,6 +74,7 @@ type eventStop struct{} type eventUserUpdate struct { password *string admin *bool + enabled *bool done chan error } @@ -246,7 +247,7 @@ func (net *network) runConn(ctx context.Context) error { } func (net *network) run() { - if !net.Enabled { + if !net.user.Enabled || !net.Enabled { return } @@ -687,6 +688,15 @@ func (u *user) run() { dc.monitored.SetCasemapping(dc.network.casemap) } + if !u.Enabled { + dc.SendMessage(&irc.Message{ + Command: "ERROR", + Params: []string{"This bouncer account is disabled"}, + }) + // TODO: close dc after the error message is sent + break + } + if err := dc.welcome(context.TODO()); err != nil { if ircErr, ok := err.(ircError); ok { msg := ircErr.Message.Copy() @@ -762,6 +772,9 @@ func (u *user) run() { if e.admin != nil { record.Admin = *e.admin } + if e.enabled != nil { + record.Enabled = *e.enabled + } e.done <- u.updateUser(context.TODO(), &record) @@ -1071,6 +1084,7 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error { nickUpdated := u.Nick != record.Nick realnameUpdated := u.Realname != record.Realname + enabledUpdated := u.Enabled != record.Enabled if err := u.srv.db.StoreUser(ctx, record); err != nil { return fmt.Errorf("failed to update user %q: %v", u.Username, err) } @@ -1091,22 +1105,28 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error { } } - if realnameUpdated { + if realnameUpdated || enabledUpdated { // Re-connect to networks which use the default realname var needUpdate []database.Network for _, net := range u.networks { - if net.Realname != "" { - continue - } + // If only the realname was updated, maybe we can skip the + // re-connect + if realnameUpdated && !enabledUpdated { + // If this network has a custom realname set, no need to + // re-connect: the user-wide realname remains unused + if net.Realname != "" { + continue + } - // We only need to call updateNetwork for upstreams that don't - // support setname - if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") { - uc.SendMessage(ctx, &irc.Message{ - Command: "SETNAME", - Params: []string{database.GetRealname(&u.User, &net.Network)}, - }) - continue + // We only need to call updateNetwork for upstreams that don't + // support setname + if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") { + uc.SendMessage(ctx, &irc.Message{ + Command: "SETNAME", + Params: []string{database.GetRealname(&u.User, &net.Network)}, + }) + continue + } } needUpdate = append(needUpdate, net.Network) @@ -1123,6 +1143,13 @@ func (u *user) updateUser(ctx context.Context, record *database.User) error { } } + if !u.Enabled { + // TODO: send an error message before disconnecting + for _, dc := range u.downstreamConns { + dc.Close() + } + } + return nil }