From f784b42346e1012358b2b6500822dccd5c5f38e9 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 14 Mar 2024 15:44:09 +0100 Subject: [PATCH] upstream: use round-robin DNS resolution when per-user IPs are set up The standard library doesn't distribute connections to different hosts. This causes issues for large deployments: the bouncer always connects to the same IRC server, even if an IRC network has multiple servers. This is disabled when per-user IPs are disabled, because our resolver implementation is very bare-bones and e.g. doesn't fallback to IPv4 when IPv6 is unavailable. Per-user IPs indicate a larger deployment and thus a need to spread the load. Closes: https://todo.sr.ht/~emersion/soju/221 --- upstream.go | 58 ++++++++++++++++++++++++++++++++++++++++++++--------- user.go | 18 +++-------------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/upstream.go b/upstream.go index e3bdf4e..7253453 100644 --- a/upstream.go +++ b/upstream.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "strconv" "strings" @@ -373,17 +374,28 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er } func dialTCP(ctx context.Context, user *user, addr string) (net.Conn, error) { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err + var dialer net.Dialer + upstreamUserIPs := user.srv.Config().UpstreamUserIPs + if len(upstreamUserIPs) > 0 || true { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ipAddr, err := resolveIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve host %q: %v", host, err) + } + + localAddr, err := user.localTCPAddr(ipAddr.IP) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + addr = net.JoinHostPort(ipAddr.String(), port) + dialer.LocalAddr = localAddr } - localAddr, err := user.localTCPAddrForHost(ctx, host) - if err != nil { - return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) - } - - dialer := net.Dialer{LocalAddr: localAddr} return dialer.DialContext(ctx, "tcp", addr) } @@ -2436,3 +2448,31 @@ func (uc *upstreamConn) shouldCacheUserInfo(nick string) bool { }) return found } + +// resolveIPAddr replaces the standard library's DNS resolver to randomize the +// result order instead of always returning the same IP address. The bouncer +// will often have bursts of connections to the same host (e.g. on startup) so +// it's more important for our use-case to distribute the traffic among +// available IP addresses than to find the fastest link. +// +// See: https://todo.sr.ht/~emersion/soju/221 +func resolveIPAddr(ctx context.Context, host string) (*net.IPAddr, error) { + ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + // Prefer IPv6 if available, for per-user local IP addresses + ip6Addrs := make([]net.IPAddr, 0, len(ipAddrs)) + for _, ipAddr := range ipAddrs { + if ipAddr.IP.To4() == nil { + ip6Addrs = append(ip6Addrs, ipAddr) + } + } + if len(ip6Addrs) > 0 { + ipAddrs = ip6Addrs + } + + i := rand.Intn(len(ipAddrs)) + return &ipAddrs[i], nil +} diff --git a/user.go b/user.go index 588e015..448c739 100644 --- a/user.go +++ b/user.go @@ -1251,27 +1251,15 @@ func (u *user) FormatServerTime(t time.Time) string { return xirc.FormatServerTime(t) } -// localAddrForHost returns the local address to use when connecting to host. +// localTCPAddr returns the local address to use when connecting to a host. // A nil address is returned when the OS should automatically pick one. -func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) { +func (u *user) localTCPAddr(remoteIP net.IP) (*net.TCPAddr, error) { upstreamUserIPs := u.srv.Config().UpstreamUserIPs if len(upstreamUserIPs) == 0 { return nil, nil } - ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) - if err != nil { - return nil, err - } - - wantIPv6 := false - for _, ip := range ips { - if ip.To4() == nil { - wantIPv6 = true - break - } - } - + wantIPv6 := remoteIP.To4() == nil var ipNet *net.IPNet for _, in := range upstreamUserIPs { if wantIPv6 == (in.IP.To4() == nil) {