Allow most config options to be reloaded

Closes: https://todo.sr.ht/~emersion/soju/42
This commit is contained in:
Simon Ser 2021-11-16 00:38:04 +01:00
parent e44f4b2eee
commit 73295e4fa7
7 changed files with 111 additions and 98 deletions

View file

@ -37,19 +37,6 @@ func (v *stringSliceFlag) Set(s string) error {
return nil
}
func loadMOTD(srv *soju.Server, filename string) error {
if filename == "" {
return nil
}
b, err := ioutil.ReadFile(filename)
if err != nil {
return err
}
srv.SetMOTD(strings.TrimSuffix(string(b), "\n"))
return nil
}
func bumpOpenedFileLimit() error {
var rlimit syscall.Rlimit
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
@ -62,24 +49,65 @@ func bumpOpenedFileLimit() error {
return nil
}
var (
configPath string
debug bool
tlsCert atomic.Value // *tls.Certificate
)
func loadConfig() (*config.Server, *soju.Config, error) {
var raw *config.Server
if configPath != "" {
var err error
raw, err = config.Load(configPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load config file: %v", err)
}
} else {
raw = config.Defaults()
}
var motd string
if raw.MOTDPath != "" {
b, err := ioutil.ReadFile(raw.MOTDPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
}
motd = strings.TrimSuffix(string(b), "\n")
}
if raw.TLS != nil {
cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
}
tlsCert.Store(&cert)
}
cfg := &soju.Config{
Hostname: raw.Hostname,
Title: raw.Title,
LogPath: raw.LogPath,
HTTPOrigins: raw.HTTPOrigins,
AcceptProxyIPs: raw.AcceptProxyIPs,
MaxUserNetworks: raw.MaxUserNetworks,
Debug: debug,
MOTD: motd,
}
return raw, cfg, nil
}
func main() {
var listen []string
var configPath string
var debug bool
flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
flag.StringVar(&configPath, "config", "", "path to configuration file")
flag.BoolVar(&debug, "debug", false, "enable debug logging")
flag.Parse()
var cfg *config.Server
if configPath != "" {
var err error
cfg, err = config.Load(configPath)
cfg, serverCfg, err := loadConfig()
if err != nil {
log.Fatalf("failed to load config file: %v", err)
}
} else {
cfg = config.Defaults()
log.Fatal(err)
}
cfg.Listen = append(cfg.Listen, listen...)
@ -97,14 +125,7 @@ func main() {
}
var tlsCfg *tls.Config
var tlsCert atomic.Value
if cfg.TLS != nil {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
if err != nil {
log.Fatalf("failed to load TLS certificate and key: %v", err)
}
tlsCert.Store(&cert)
tlsCfg = &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return tlsCert.Load().(*tls.Certificate), nil
@ -113,17 +134,7 @@ func main() {
}
srv := soju.NewServer(db)
srv.Hostname = cfg.Hostname
srv.Title = cfg.Title
srv.LogPath = cfg.LogPath
srv.HTTPOrigins = cfg.HTTPOrigins
srv.AcceptProxyIPs = cfg.AcceptProxyIPs
srv.MaxUserNetworks = cfg.MaxUserNetworks
srv.Debug = debug
if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
log.Fatalf("failed to load MOTD: %v", err)
}
srv.SetConfig(serverCfg)
for _, listen := range cfg.Listen {
listenURI := listen
@ -258,17 +269,12 @@ func main() {
for sig := range sigCh {
switch sig {
case syscall.SIGHUP:
log.Print("reloading TLS certificate and MOTD")
if cfg.TLS != nil {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
log.Print("reloading configuration")
_, serverCfg, err := loadConfig()
if err != nil {
log.Printf("failed to reload TLS certificate and key: %v", err)
break
}
tlsCert.Store(&cert)
}
if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
log.Printf("failed to reload MOTD: %v", err)
log.Printf("failed to reloading configuration: %v", err)
} else {
srv.SetConfig(serverCfg)
}
case syscall.SIGINT, syscall.SIGTERM:
log.Print("shutting down server")
@ -286,7 +292,7 @@ func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
if !ok {
return proxyproto.IGNORE, nil
}
if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
if srv.Config().AcceptProxyIPs.Contains(tcpAddr.IP) {
return proxyproto.USE, nil
}
return proxyproto.IGNORE, nil

View file

@ -195,7 +195,7 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
<-rl.C
}
if c.srv.Debug {
if c.srv.Config().Debug {
c.logger.Printf("sent: %v", msg)
}
c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
@ -248,7 +248,7 @@ func (c *conn) ReadMessage() (*irc.Message, error) {
return nil, err
}
if c.srv.Debug {
if c.srv.Config().Debug {
c.logger.Printf("received: %v", msg)
}

View file

@ -44,8 +44,9 @@ soju supports two connection modes:
For per-client history to work, clients need to indicate their name. This can
be done by adding a "@<client>" suffix to the username.
soju will reload the TLS certificate/key and the MOTD file when it receives the
HUP signal.
soju will reload the configuration file, the TLS certificate/key and the MOTD
file when it receives the HUP signal. The configuration options _listen_, _db_
and _log_ cannot be reloaded.
Administrators can broadcast a message to all bouncer users via _/notice
$<hostname> <text>_, or via _/notice $\* <text>_ in multi-upstream mode. All

View file

@ -290,7 +290,10 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
for k, v := range permanentDownstreamCaps {
dc.supportedCaps[k] = v
}
if srv.LogPath != "" {
// TODO: this is racy, we should only enable chathistory after
// authentication and then check that user.msgStore implements
// chatHistoryMessageStore
if srv.Config().LogPath != "" {
dc.supportedCaps["draft/chathistory"] = ""
}
return dc
@ -996,7 +999,7 @@ func (dc *downstreamConn) updateSupportedCaps() {
}
}
if dc.srv.LogPath != "" && dc.network != nil {
if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
dc.setSupportedCap("draft/event-playback", "")
} else {
dc.unsetSupportedCap("draft/event-playback")
@ -1175,8 +1178,8 @@ func (dc *downstreamConn) welcome() error {
if dc.network != nil {
isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID))
}
if dc.network == nil && dc.srv.Title != "" {
isupport = append(isupport, "NETWORK="+encodeISUPPORT(dc.srv.Title))
if title := dc.srv.Config().Title; dc.network == nil && title != "" {
isupport = append(isupport, "NETWORK="+encodeISUPPORT(title))
}
if dc.network == nil && dc.caps["soju.im/bouncer-networks"] {
isupport = append(isupport, "WHOX")
@ -1204,12 +1207,12 @@ func (dc *downstreamConn) welcome() error {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_YOURHOST,
Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname},
})
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_MYINFO,
Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
Params: []string{dc.nick, dc.srv.Config().Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
})
for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) {
dc.SendMessage(msg)
@ -1229,7 +1232,7 @@ func (dc *downstreamConn) welcome() error {
})
}
if motd := dc.user.srv.MOTD(); motd != "" && dc.network == nil {
if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil {
for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) {
dc.SendMessage(msg)
}
@ -1420,7 +1423,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
if len(msg.Params) > 1 {
destination = msg.Params[1]
}
if destination != "" && destination != dc.srv.Hostname {
hostname := dc.srv.Config().Hostname
if destination != "" && destination != hostname {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHSERVER,
Params: []string{dc.nick, destination, "No such server"},
@ -1429,7 +1433,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "PONG",
Params: []string{dc.srv.Hostname, source},
Params: []string{hostname, source},
})
return nil
case "PONG":
@ -1946,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Token: whoxToken,
Username: dc.user.Username,
Hostname: dc.hostname,
Server: dc.srv.Hostname,
Server: dc.srv.Config().Hostname,
Nickname: dc.nick,
Flags: flags,
Account: dc.user.Username,
@ -1965,7 +1969,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Token: whoxToken,
Username: servicePrefix.User,
Hostname: servicePrefix.Host,
Server: dc.srv.Hostname,
Server: dc.srv.Config().Hostname,
Nickname: serviceNick,
Flags: "H*",
Account: serviceNick,
@ -2025,7 +2029,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISSERVER,
Params: []string{dc.nick, dc.nick, dc.srv.Hostname, "soju"},
Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "soju"},
})
if dc.user.Admin {
dc.SendMessage(&irc.Message{
@ -2055,7 +2059,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISSERVER,
Params: []string{dc.nick, serviceNick, dc.srv.Hostname, "soju"},
Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "soju"},
})
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
@ -2104,7 +2108,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
tags := copyClientTags(msg.Tags)
for _, name := range strings.Split(targetsStr, ",") {
if name == "$"+dc.srv.Hostname || (name == "$*" && dc.network == nil) {
if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
// "$" means a server mask follows. If it's the bouncer's
// hostname, broadcast the message to all bouncer users.
if !dc.user.Admin {

View file

@ -53,17 +53,22 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) {
l.logger.Printf("%v"+format, v...)
}
type Server struct {
type Config struct {
Hostname string
Title string
Logger Logger
LogPath string
Debug bool
HTTPOrigins []string
AcceptProxyIPs config.IPSet
MaxUserNetworks int
MOTD string
}
type Server struct {
Logger Logger
Identd *Identd // can be nil
config atomic.Value // *Config
db Database
stopWG sync.WaitGroup
connCount int64 // atomic
@ -71,24 +76,29 @@ type Server struct {
lock sync.Mutex
listeners map[net.Listener]struct{}
users map[string]*user
motd atomic.Value // string
}
func NewServer(db Database) *Server {
srv := &Server{
Logger: log.New(log.Writer(), "", log.LstdFlags),
MaxUserNetworks: -1,
db: db,
listeners: make(map[net.Listener]struct{}),
users: make(map[string]*user),
}
srv.motd.Store("")
srv.config.Store(&Config{Hostname: "localhost", MaxUserNetworks: -1})
return srv
}
func (s *Server) prefix() *irc.Prefix {
return &irc.Prefix{Name: s.Hostname}
return &irc.Prefix{Name: s.Config().Hostname}
}
func (s *Server) Config() *Config {
return s.config.Load().(*Config)
}
func (s *Server) SetConfig(cfg *Config) {
s.config.Store(cfg)
}
func (s *Server) Start() error {
@ -239,7 +249,7 @@ func (s *Server) Serve(ln net.Listener) error {
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
OriginPatterns: s.HTTPOrigins,
OriginPatterns: s.Config().HTTPOrigins,
})
if err != nil {
s.Logger.Printf("failed to serve HTTP connection: %v", err)
@ -249,7 +259,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
isProxy := false
if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if ip := net.ParseIP(host); ip != nil {
isProxy = s.AcceptProxyIPs.Contains(ip)
isProxy = s.Config().AcceptProxyIPs.Contains(ip)
}
}
@ -293,11 +303,3 @@ func (s *Server) Stats() *ServerStats {
stats.Downstreams = atomic.LoadInt64(&s.connCount)
return &stats
}
func (s *Server) SetMOTD(motd string) {
s.motd.Store(motd)
}
func (s *Server) MOTD() string {
return s.motd.Load().(string)
}

View file

@ -1050,7 +1050,7 @@ func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params [
broadcastMsg := &irc.Message{
Prefix: servicePrefix,
Command: "NOTICE",
Params: []string{"$" + dc.srv.Hostname, text},
Params: []string{"$" + dc.srv.Config().Hostname, text},
}
var err error
dc.srv.forEachUser(func(u *user) {

View file

@ -415,8 +415,8 @@ func newUser(srv *Server, record *User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore messageStore
if srv.LogPath != "" {
msgStore = newFSMessageStore(srv.LogPath, record.Username)
if logPath := srv.Config().LogPath; logPath != "" {
msgStore = newFSMessageStore(logPath, record.Username)
} else {
msgStore = newMemoryMessageStore()
}
@ -776,7 +776,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
return nil, err
}
if u.srv.MaxUserNetworks >= 0 && len(u.networks) >= u.srv.MaxUserNetworks {
if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
return nil, fmt.Errorf("maximum number of networks reached")
}