From a9a066faac661fe17135ac33d1a0bea90b950ac8 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 13 Oct 2021 10:58:34 +0200 Subject: [PATCH] Add bouncer MOTD Closes: https://todo.sr.ht/~emersion/soju/137 --- cmd/soju/main.go | 23 ++++++++++++++++++++++- config/config.go | 5 +++++ doc/soju.1.scd | 7 ++++++- downstream.go | 23 ++++++++++++++++------- irc.go | 25 +++++++++++++++++++++++++ server.go | 14 +++++++++++++- 6 files changed, 87 insertions(+), 10 deletions(-) diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 952ea78..950974d 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "flag" "fmt" + "io/ioutil" "log" "net" "net/http" @@ -36,6 +37,19 @@ func (v *stringSliceFlag) Set(s string) error { return nil } +func loadMOTD(srv *soju.Server, filename string) error { + if filename == "" { + return nil + } + + b, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + srv.SetMOTD(strings.TrimSpace(string(b))) + return nil +} + func main() { var listen []string var configPath string @@ -91,6 +105,10 @@ func main() { srv.MaxUserNetworks = cfg.MaxUserNetworks srv.Debug = debug + if err := loadMOTD(srv, cfg.MOTDPath); err != nil { + log.Fatalf("failed to load MOTD: %v", err) + } + for _, listen := range cfg.Listen { listenURI := listen if !strings.Contains(listenURI, ":/") { @@ -224,8 +242,8 @@ func main() { for sig := range sigCh { switch sig { case syscall.SIGHUP: + log.Print("reloading TLS certificate and MOTD") if cfg.TLS != nil { - log.Print("reloading TLS certificate") cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) if err != nil { log.Printf("failed to reload TLS certificate and key: %v", err) @@ -233,6 +251,9 @@ func main() { } tlsCert.Store(&cert) } + if err := loadMOTD(srv, cfg.MOTDPath); err != nil { + log.Printf("failed to reload MOTD: %v", err) + } case syscall.SIGINT, syscall.SIGTERM: log.Print("shutting down server") srv.Shutdown() diff --git a/config/config.go b/config/config.go index 0c6870a..b8dc1fa 100644 --- a/config/config.go +++ b/config/config.go @@ -40,6 +40,7 @@ type Server struct { Listen []string Hostname string TLS *TLS + MOTDPath string SQLDriver string SQLSource string @@ -128,6 +129,10 @@ func parse(cfg scfg.Block) (*Server, error) { if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil { return nil, fmt.Errorf("directive %q: %v", d.Name, err) } + case "motd": + if err := d.ParseParams(&srv.MOTDPath); err != nil { + return nil, err + } default: return nil, fmt.Errorf("unknown directive %q", d.Name) } diff --git a/doc/soju.1.scd b/doc/soju.1.scd index ceba2e7..6c4a447 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -44,7 +44,8 @@ soju supports two connection modes: For per-client history to work, clients need to indicate their name. This can be done by adding a "@" suffix to the username. -soju will reload the TLS certificate and key when it receives the HUP signal. +soju will reload the TLS certificate/key and the MOTD file when it receives the +HUP signal. Administrators can broadcast a message to all bouncer users via _/notice $ _, or via _/notice $\* _ in multi-upstream mode. All @@ -142,6 +143,10 @@ The following directives are supported: *max-user-networks* Maximum number of networks per user. By default, there is no limit. +*motd* + Path to the MOTD file. The bouncer MOTD is sent to clients which aren't + bound to a specific network. By default, no MOTD is sent. + # IRC SERVICE soju exposes an IRC service called *BouncerServ* to manage the bouncer. diff --git a/downstream.go b/downstream.go index 631fa5c..5493add 100644 --- a/downstream.go +++ b/downstream.go @@ -1119,20 +1119,29 @@ func (dc *downstreamConn) welcome() error { for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { dc.SendMessage(msg) } - motdHint := "No MOTD" if uc := dc.upstream(); uc != nil { - motdHint = "Use /motd to read the message of the day" dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_UMODEIS, Params: []string{dc.nick, string(uc.modes)}, }) } - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.ERR_NOMOTD, - Params: []string{dc.nick, motdHint}, - }) + + if motd := dc.user.srv.MOTD(); motd != "" && dc.network == nil { + for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { + dc.SendMessage(msg) + } + } else { + motdHint := "No MOTD" + if dc.network != nil { + motdHint = "Use /motd to read the message of the day" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOMOTD, + Params: []string{dc.nick, motdHint}, + }) + } dc.updateNick() dc.updateRealname() diff --git a/irc.go b/irc.go index e78d48f..3c1d5cf 100644 --- a/irc.go +++ b/irc.go @@ -379,6 +379,31 @@ func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.M return msgs } +func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message { + var msgs []*irc.Message + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTDSTART, + Params: []string{nick, fmt.Sprintf("- Message of the Day -")}, + }) + + for _, l := range strings.Split(motd, "\n") { + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTD, + Params: []string{nick, l}, + }) + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ENDOFMOTD, + Params: []string{nick, "End of /MOTD command."}, + }) + + return msgs +} + type joinSorter struct { channels []string keys []string diff --git a/server.go b/server.go index 9e7760c..bda5cae 100644 --- a/server.go +++ b/server.go @@ -63,10 +63,12 @@ type Server struct { lock sync.Mutex listeners map[net.Listener]struct{} users map[string]*user + + motd atomic.Value // string } func NewServer(db Database) *Server { - return &Server{ + srv := &Server{ Logger: log.New(log.Writer(), "", log.LstdFlags), HistoryLimit: 1000, MaxUserNetworks: -1, @@ -74,6 +76,8 @@ func NewServer(db Database) *Server { listeners: make(map[net.Listener]struct{}), users: make(map[string]*user), } + srv.motd.Store("") + return srv } func (s *Server) prefix() *irc.Prefix { @@ -268,3 +272,11 @@ func (s *Server) Stats() *ServerStats { stats.Downstreams = atomic.LoadInt64(&s.connCount) return &stats } + +func (s *Server) SetMOTD(motd string) { + s.motd.Store(motd) +} + +func (s *Server) MOTD() string { + return s.motd.Load().(string) +}