diff --git a/config/config.go b/config/config.go index 6e3707b..8f60791 100644 --- a/config/config.go +++ b/config/config.go @@ -109,145 +109,150 @@ func Defaults() *Server { } func Load(path string) (*Server, error) { - cfg, err := scfg.Load(path) + var raw struct { + Listen []struct { + Addr string `scfg:",param"` + } `scfg:"listen"` + Hostname string `scfg:"hostname"` + Title string `scfg:"title"` + MOTD string `scfg:"motd"` + TLS *[2]string `scfg:"tls"` + DB *[2]string `scfg:"db"` + MessageStore []string `scfg:"message-store"` + Log []string `scfg:"log"` + Auth []string `scfg:"auth"` + HTTPOrigin []string `scfg:"http-origin"` + AcceptProxyIP []string `scfg:"accept-proxy-ip"` + MaxUserNetworks int `scfg:"max-user-networks"` + UpstreamUserIP []string `scfg:"upstream-user-ip"` + DisableInactiveUser string `scfg:"disable-inactive-user"` + EnableUserOnAuth string `scfg:"enable-user-on-auth"` + } + + f, err := os.Open(path) if err != nil { return nil, err } - return parse(cfg) -} + defer f.Close() + + if err := scfg.NewDecoder(f).Decode(&raw); err != nil { + return nil, err + } -func parse(cfg scfg.Block) (*Server, error) { srv := Defaults() - for _, d := range cfg { - switch d.Name { - case "listen": - var uri string - if err := d.ParseParams(&uri); err != nil { - return nil, err - } - srv.Listen = append(srv.Listen, uri) - case "hostname": - if err := d.ParseParams(&srv.Hostname); err != nil { - return nil, err - } - case "title": - if err := d.ParseParams(&srv.Title); err != nil { - return nil, err - } - case "motd": - if err := d.ParseParams(&srv.MOTDPath); err != nil { - return nil, err - } - case "tls": - tls := &TLS{} - if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil { - return nil, err - } - srv.TLS = tls - case "db": - if err := d.ParseParams(&srv.DB.Driver, &srv.DB.Source); err != nil { - return nil, err - } - case "message-store", "log": - if err := d.ParseParams(&srv.MsgStore.Driver); err != nil { - return nil, err - } - switch srv.MsgStore.Driver { - case "memory", "db": - case "fs": - if err := d.ParseParams(nil, &srv.MsgStore.Source); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, srv.MsgStore.Driver) - } - case "auth": - if err := d.ParseParams(&srv.Auth.Driver); err != nil { - return nil, err - } - switch srv.Auth.Driver { - case "internal", "pam": - srv.Auth.Source = "" - case "oauth2": - if err := d.ParseParams(nil, &srv.Auth.Source); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, srv.Auth.Driver) - } - case "http-origin": - srv.HTTPOrigins = d.Params - case "accept-proxy-ip": - srv.AcceptProxyIPs = nil - for _, s := range d.Params { - if s == "localhost" { - srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, loopbackIPs...) - continue - } - _, n, err := net.ParseCIDR(s) - if err != nil { - return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err) - } - srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n) - } - case "max-user-networks": - var max string - if err := d.ParseParams(&max); err != nil { - return nil, err - } - var err error - if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil { - return nil, fmt.Errorf("directive %q: %v", d.Name, err) - } - case "upstream-user-ip": - if len(srv.UpstreamUserIPs) > 0 { - return nil, fmt.Errorf("directive %q: can only be specified once", d.Name) - } - var hasIPv4, hasIPv6 bool - for _, s := range d.Params { - _, n, err := net.ParseCIDR(s) - if err != nil { - return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err) - } - if n.IP.To4() == nil { - if hasIPv6 { - return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name) - } - hasIPv6 = true - } else { - if hasIPv4 { - return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name) - } - hasIPv4 = true - } - srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) - } - case "disable-inactive-user": - var durStr string - if err := d.ParseParams(&durStr); err != nil { - return nil, err - } - dur, err := parseDuration(durStr) - if err != nil { - return nil, fmt.Errorf("directive %q: %v", d.Name, err) - } else if dur < 0 { - return nil, fmt.Errorf("directive %q: duration must be positive", d.Name) - } - srv.DisableInactiveUsersDelay = dur - case "enable-user-on-auth": - var s string - if err := d.ParseParams(&s); err != nil { - return nil, err - } - b, err := strconv.ParseBool(s) - if err != nil { - return nil, fmt.Errorf("directive %q: %v", d.Name, err) - } - srv.EnableUsersOnAuth = b - default: - return nil, fmt.Errorf("unknown directive %q", d.Name) + + for _, listen := range raw.Listen { + srv.Listen = append(srv.Listen, listen.Addr) + } + if raw.Hostname != "" { + srv.Hostname = raw.Hostname + } + srv.Title = raw.Title + srv.MOTDPath = raw.MOTD + if raw.TLS != nil { + srv.TLS = &TLS{CertPath: raw.TLS[0], KeyPath: raw.TLS[1]} + } + if raw.DB != nil { + srv.DB = DB{Driver: raw.DB[0], Source: raw.DB[1]} + } + if raw.MessageStore == nil { + raw.MessageStore = raw.Log + } + if raw.MessageStore != nil { + driver, source, err := parseDriverSource("message-store", raw.MessageStore) + if err != nil { + return nil, err } + switch driver { + case "memory", "db": + // nothing to do + case "fs": + if source == "" { + return nil, fmt.Errorf("directive message-store: driver %q requires a source", driver) + } + default: + return nil, fmt.Errorf("directive message-store: unknown driver %q", driver) + } + srv.MsgStore = MsgStore{driver, source} + } + if raw.Auth != nil { + driver, source, err := parseDriverSource("auth", raw.Auth) + if err != nil { + return nil, err + } + switch driver { + case "internal", "pam": + // nothing to do + case "oauth2": + if source == "" { + return nil, fmt.Errorf("directive auth: driver %q requires a source", driver) + } + default: + return nil, fmt.Errorf("directive auth: unknown driver %q", driver) + } + srv.Auth = Auth{driver, source} + } + srv.HTTPOrigins = raw.HTTPOrigin + for _, s := range raw.AcceptProxyIP { + if s == "localhost" { + srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, loopbackIPs...) + continue + } + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("directive accept-proxy-ip: failed to parse CIDR: %v", err) + } + srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n) + } + srv.MaxUserNetworks = raw.MaxUserNetworks + var hasIPv4, hasIPv6 bool + for _, s := range raw.UpstreamUserIP { + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("directive upstream-user-ip: failed to parse CIDR: %v", err) + } + if n.IP.To4() == nil { + if hasIPv6 { + return nil, fmt.Errorf("directive upstream-user-ip: found two IPv6 CIDRs") + } + hasIPv6 = true + } else { + if hasIPv4 { + return nil, fmt.Errorf("directive upstream-user-ip: found two IPv4 CIDRs") + } + hasIPv4 = true + } + srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) + } + if raw.DisableInactiveUser != "" { + dur, err := parseDuration(raw.DisableInactiveUser) + if err != nil { + return nil, fmt.Errorf("directive disable-inactive-user: %v", err) + } else if dur < 0 { + return nil, fmt.Errorf("directive disable-inactive-user: duration must be positive") + } + srv.DisableInactiveUsersDelay = dur + } + if raw.EnableUserOnAuth != "" { + b, err := strconv.ParseBool(raw.EnableUserOnAuth) + if err != nil { + return nil, fmt.Errorf("directive enable-user-on-auth: %v", err) + } + srv.EnableUsersOnAuth = b } return srv, nil } + +func parseDriverSource(name string, params []string) (driver, source string, err error) { + switch len(params) { + case 2: + source = params[1] + fallthrough + case 1: + driver = params[0] + default: + err = fmt.Errorf("directive %v requires exactly 1 or 2 parameters", name) + } + return driver, source, err +} diff --git a/go.mod b/go.mod index ef04e9f..1f046ae 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module git.sr.ht/~emersion/soju go 1.19 require ( - git.sr.ht/~emersion/go-scfg v0.0.0-20231004133111-9dce55c8d63b + git.sr.ht/~emersion/go-scfg v0.0.0-20231211181832-0b4e72d8ec3c git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 github.com/SherClockHolmes/webpush-go v1.3.0 diff --git a/go.sum b/go.sum index fb97db7..f554b05 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -git.sr.ht/~emersion/go-scfg v0.0.0-20231004133111-9dce55c8d63b h1:Lf4oYBOJVmbYzrfqWfXUvSpXQPNMgnbN0efn5A7bH3M= -git.sr.ht/~emersion/go-scfg v0.0.0-20231004133111-9dce55c8d63b/go.mod h1:ybgvEJTIx5XbaspSviB3KNa6OdPmAZqDoSud7z8fFlw= +git.sr.ht/~emersion/go-scfg v0.0.0-20231211181832-0b4e72d8ec3c h1:Cjy9/qASF8hogbKbWXgEQZxbYHrM9ksl76sGzsP8Zqo= +git.sr.ht/~emersion/go-scfg v0.0.0-20231211181832-0b4e72d8ec3c/go.mod h1:ybgvEJTIx5XbaspSviB3KNa6OdPmAZqDoSud7z8fFlw= git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc h1:+y3OijpLl4rgbFsqMBmYUTCsGCkxQUWpWaqfS8j9Ygc= git.sr.ht/~emersion/go-sqlite3-fts5 v0.0.0-20230217131031-f2c8767594fc/go.mod h1:PCl1xjl7iC6x35TKKubKRyo/3TT0dGI66jyNI6vmYnU= git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=