Add disable-inactive-user config option

This can be used to automatically disable users if they don't
actively use the bouncer for a while.
This commit is contained in:
Simon Ser 2023-01-26 16:57:07 +01:00
parent 57f5ee8d6f
commit 9df9880301
8 changed files with 211 additions and 18 deletions

View file

@ -84,14 +84,15 @@ func loadConfig() (*config.Server, *soju.Config, error) {
}
cfg := &soju.Config{
Hostname: raw.Hostname,
Title: raw.Title,
LogPath: raw.MsgStore.Source,
HTTPOrigins: raw.HTTPOrigins,
AcceptProxyIPs: raw.AcceptProxyIPs,
MaxUserNetworks: raw.MaxUserNetworks,
UpstreamUserIPs: raw.UpstreamUserIPs,
MOTD: motd,
Hostname: raw.Hostname,
Title: raw.Title,
LogPath: raw.MsgStore.Source,
HTTPOrigins: raw.HTTPOrigins,
AcceptProxyIPs: raw.AcceptProxyIPs,
MaxUserNetworks: raw.MaxUserNetworks,
UpstreamUserIPs: raw.UpstreamUserIPs,
DisableInactiveUsersDelay: raw.DisableInactiveUsersDelay,
MOTD: motd,
}
return raw, cfg, nil
}

View file

@ -5,6 +5,8 @@ import (
"net"
"os"
"strconv"
"strings"
"time"
"git.sr.ht/~emersion/go-scfg"
)
@ -32,6 +34,18 @@ var loopbackIPs = IPSet{
},
}
func parseDuration(s string) (time.Duration, error) {
if !strings.HasSuffix(s, "d") {
return 0, fmt.Errorf("missing 'd' suffix in duration")
}
s = strings.TrimSuffix(s, "d")
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return 0, fmt.Errorf("invalid duration: %v", err)
}
return time.Duration(v * 24 * float64(time.Hour)), nil
}
type TLS struct {
CertPath, KeyPath string
}
@ -57,8 +71,9 @@ type Server struct {
HTTPOrigins []string
AcceptProxyIPs IPSet
MaxUserNetworks int
UpstreamUserIPs []*net.IPNet
MaxUserNetworks int
UpstreamUserIPs []*net.IPNet
DisableInactiveUsersDelay time.Duration
}
func Defaults() *Server {
@ -180,6 +195,18 @@ func parse(cfg scfg.Block) (*Server, error) {
}
srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
}
case "disable-inactive-user":
var durStr string
if err := d.ParseParams(&durStr); err != nil {
return nil, err
}
dur, err := parseDuration(durStr)
if err != nil {
return nil, fmt.Errorf("directive %q: %v", d.Name, err)
} else if dur < 0 {
return nil, fmt.Errorf("directive %q: duration must be positive", d.Name)
}
srv.DisableInactiveUsersDelay = dur
default:
return nil, fmt.Errorf("unknown directive %q", d.Name)
}

View file

@ -20,6 +20,7 @@ type Database interface {
GetUser(ctx context.Context, username string) (*User, error)
StoreUser(ctx context.Context, user *User) error
DeleteUser(ctx context.Context, id int64) error
ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error)
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
StoreNetwork(ctx context.Context, userID int64, network *Network) error

View file

@ -354,6 +354,33 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
return user, nil
}
func (db *PostgresDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
`SELECT username FROM "User" WHERE COALESCE(downstream_interacted_at, created_at) < $1`,
limit)
if err != nil {
return nil, err
}
defer rows.Close()
var usernames []string
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, err
}
usernames = append(usernames, username)
}
if err := rows.Err(); err != nil {
return nil, err
}
return usernames, nil
}
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()

View file

@ -445,6 +445,33 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
return user, nil
}
func (db *SqliteDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
"SELECT username FROM User WHERE coalesce(downstream_interacted_at, created_at) < ?",
sqliteTime{limit})
if err != nil {
return nil, err
}
defer rows.Close()
var usernames []string
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, err
}
usernames = append(usernames, username)
}
if err := rows.Err(); err != nil {
return nil, err
}
return usernames, nil
}
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()

View file

@ -170,6 +170,15 @@ The following directives are supported:
This can be useful to avoid having the whole bouncer banned from an upstream
network because of one malicious user.
*disable-inactive-user* <duration>
Disable inactive users after the specified duration.
A user is inactive when the last downstream connection is closed.
The duration is a positive decimal number followed by the unit "d" (days).
For instance, "30d" disables users 30 days after they last disconnect from
the bouncer.
# IRC SERVICE
soju exposes an IRC service called *BouncerServ* to manage the bouncer.

112
server.go
View file

@ -133,14 +133,15 @@ func (ln *retryListener) Accept() (net.Conn, error) {
}
type Config struct {
Hostname string
Title string
LogPath string
HTTPOrigins []string
AcceptProxyIPs config.IPSet
MaxUserNetworks int
MOTD string
UpstreamUserIPs []*net.IPNet
Hostname string
Title string
LogPath string
HTTPOrigins []string
AcceptProxyIPs config.IPSet
MaxUserNetworks int
MOTD string
UpstreamUserIPs []*net.IPNet
DisableInactiveUsersDelay time.Duration
}
type Server struct {
@ -151,6 +152,7 @@ type Server struct {
config atomic.Value // *Config
db database.Database
stopWG sync.WaitGroup
stopCh chan struct{}
lock sync.Mutex
listeners map[net.Listener]struct{}
@ -178,6 +180,7 @@ func NewServer(db database.Database) *Server {
db: db,
listeners: make(map[net.Listener]struct{}),
users: make(map[string]*user),
stopCh: make(chan struct{}),
}
srv.config.Store(&Config{
Hostname: "localhost",
@ -216,6 +219,12 @@ func (s *Server) Start() error {
}
s.lock.Unlock()
s.stopWG.Add(1)
go func() {
defer s.stopWG.Done()
s.disableInactiveUsersLoop()
}()
return nil
}
@ -343,6 +352,8 @@ func (s *Server) sendWebPush(ctx context.Context, sub *webpush.Subscription, vap
func (s *Server) Shutdown() {
s.Logger.Printf("shutting down server")
close(s.stopCh)
s.lock.Lock()
s.shutdown = true
for ln := range s.listeners {
@ -547,3 +558,88 @@ func (s *Server) Stats() *ServerStats {
stats.Upstreams = s.metrics.upstreams.Value()
return &stats
}
func (s *Server) disableInactiveUsersLoop() {
ticker := time.NewTicker(4 * time.Hour)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
}
if err := s.disableInactiveUsers(context.TODO()); err != nil {
s.Logger.Printf("failed to disable inactive users: %v", err)
}
}
}
func (s *Server) disableInactiveUsers(ctx context.Context) error {
delay := s.Config().DisableInactiveUsersDelay
if delay == 0 {
return nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
usernames, err := s.db.ListInactiveUsernames(ctx, time.Now().Add(-delay))
if err != nil {
return fmt.Errorf("failed to list inactive users: %v", err)
} else if len(usernames) == 0 {
return nil
}
// Filter out users with active downstream connections
var users []*user
s.lock.Lock()
for _, username := range usernames {
u := s.users[username]
if u == nil {
// TODO: disable the user in the DB
continue
}
if n := u.numDownstreamConns.Load(); n > 0 {
continue
}
users = append(users, u)
}
s.lock.Unlock()
if len(users) == 0 {
return nil
}
s.Logger.Printf("found %v inactive users", len(users))
for _, u := range users {
done := make(chan error, 1)
enabled := false
event := eventUserUpdate{
enabled: &enabled,
done: done,
}
select {
case <-ctx.Done():
return ctx.Err()
case u.events <- event:
// Event was sent, let's wait for the reply
}
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
if err != nil {
return err
} else {
s.Logger.Printf("deleted inactive user %q", u.Username)
}
}
}
return nil
}

View file

@ -11,6 +11,7 @@ import (
"net"
"sort"
"strings"
"sync/atomic"
"time"
"git.sr.ht/~emersion/soju/xirc"
@ -503,6 +504,8 @@ type user struct {
events chan event
done chan struct{}
numDownstreamConns atomic.Int64
networks []*network
downstreamConns []*downstreamConn
msgStore msgstore.Store
@ -715,6 +718,7 @@ func (u *user) run() {
}
u.downstreamConns = append(u.downstreamConns, dc)
u.numDownstreamConns.Add(1)
dc.forEachNetwork(func(network *network) {
if network.lastError != nil {
@ -734,6 +738,7 @@ func (u *user) run() {
for i := range u.downstreamConns {
if u.downstreamConns[i] == dc {
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
u.numDownstreamConns.Add(-1)
break
}
}