soju/config/config.go
Simon Ser 706b6e33fb
Make db and log config options more future-proof
Rename the "sql" directive to "db". Rename the "log" directive to
"log fs".

In the future, we may support more databases and more message
stores. Make it easy to integrate these new features into the
config file format.
2021-04-21 18:15:04 +02:00

123 lines
2.4 KiB
Go

package config
import (
"fmt"
"net"
"os"
"git.sr.ht/~emersion/go-scfg"
)
// IPSet is a list of IP networks that can be queried for membership.
type IPSet []*net.IPNet

// Contains reports whether ip falls inside at least one network of the set.
// A nil or empty set contains no addresses.
func (s IPSet) Contains(ip net.IP) bool {
	for i := range s {
		if s[i].Contains(ip) {
			return true
		}
	}
	return false
}

// loopbackIPs contains the loopback networks 127.0.0.0/8 and ::1/128. The
// "localhost" keyword of the accept-proxy-ip directive expands to this set.
var loopbackIPs = IPSet{
	{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)},
	{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)},
}
// TLS holds the certificate and private key file paths given to the "tls"
// directive.
type TLS struct {
	CertPath string // first "tls" parameter
	KeyPath  string // second "tls" parameter
}
// Server is the top-level soju configuration, as returned by Defaults and
// Load.
type Server struct {
	// Listen holds the URI given to each "listen" directive, in file order.
	Listen []string
	// Hostname defaults to the machine's host name (see Defaults) and can be
	// overridden with the "hostname" directive.
	Hostname string
	// TLS is nil unless a "tls" directive is present.
	TLS *TLS

	// SQLDriver and SQLSource are set by the "db <driver> <source>"
	// directive.
	SQLDriver string
	SQLSource string

	// LogPath is set by the "log fs <path>" directive; it is empty when no
	// such directive is present.
	LogPath string

	// HTTPOrigins holds the parameters of the last "http-origin" directive.
	HTTPOrigins []string
	// AcceptProxyIPs holds the networks listed by the last "accept-proxy-ip"
	// directive.
	AcceptProxyIPs IPSet
}
// Defaults returns a Server holding the built-in settings used when the
// configuration file does not override them: the "sqlite3" driver with the
// source "soju.db", and the machine's host name, or "localhost" when
// os.Hostname fails.
func Defaults() *Server {
	srv := &Server{
		SQLDriver: "sqlite3",
		SQLSource: "soju.db",
	}
	srv.Hostname = "localhost"
	if name, err := os.Hostname(); err == nil {
		srv.Hostname = name
	}
	return srv
}
// Load reads the scfg configuration file at path and converts it into a
// Server. Settings absent from the file keep the values returned by Defaults.
// Errors from reading the file and from interpreting its directives are
// returned unchanged.
func Load(path string) (*Server, error) {
	block, loadErr := scfg.Load(path)
	if loadErr != nil {
		return nil, loadErr
	}
	return parse(block)
}
// parse builds a Server from the top-level directives of an already-loaded
// scfg block. It starts from Defaults and applies the directives in file
// order. "listen" directives accumulate; every other directive replaces the
// value set by an earlier occurrence.
//
// Recognized directives:
//
//	listen <uri>
//	hostname <name>
//	tls <cert-path> <key-path>
//	db <driver> <source>
//	log fs <path>
//	http-origin <origin...>
//	accept-proxy-ip <cidr|localhost...>
//
// Any other directive name, a directive with too few parameters, an unknown
// log driver or an unparsable CIDR yields an error and a nil Server.
// Child blocks attached to directives are not inspected.
func parse(cfg scfg.Block) (*Server, error) {
	srv := Defaults()
	for _, d := range cfg {
		switch d.Name {
		case "listen":
			var uri string
			if err := d.ParseParams(&uri); err != nil {
				return nil, err
			}
			srv.Listen = append(srv.Listen, uri)
		case "hostname":
			if err := d.ParseParams(&srv.Hostname); err != nil {
				return nil, err
			}
		case "tls":
			tls := &TLS{}
			if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
				return nil, err
			}
			srv.TLS = tls
		case "db":
			if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
				return nil, err
			}
		case "log":
			var driver string
			if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
				return nil, err
			}
			// "fs" is currently the only supported message store; the driver
			// parameter exists so other stores can be added later.
			if driver != "fs" {
				return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
			}
		case "http-origin":
			srv.HTTPOrigins = d.Params
		case "accept-proxy-ip":
			// Each occurrence replaces any list from an earlier one.
			srv.AcceptProxyIPs = nil
			for _, s := range d.Params {
				if s == "localhost" {
					srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, loopbackIPs...)
					continue
				}
				_, n, err := net.ParseCIDR(s)
				if err != nil {
					// Wrap with %w so callers can still reach the underlying
					// net error via errors.Is/errors.As.
					return nil, fmt.Errorf("directive %q: failed to parse CIDR: %w", d.Name, err)
				}
				srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n)
			}
		default:
			return nil, fmt.Errorf("unknown directive %q", d.Name)
		}
	}
	return srv, nil
}