Add context to connectToUpstream

This commit is contained in:
Simon Ser 2021-12-02 10:53:43 +01:00
parent 33a639ecf0
commit 73287f242e
2 changed files with 9 additions and 9 deletions

View file

@ -123,7 +123,7 @@ type upstreamConn struct {
gotMotd bool gotMotd bool
} }
func connectToUpstream(network *network) (*upstreamConn, error) { func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())} logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
dialer := net.Dialer{Timeout: connectTimeout} dialer := net.Dialer{Timeout: connectTimeout}
@ -143,7 +143,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
addr = u.Host + ":6697" addr = u.Host + ":6697"
} }
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host) dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
} }
@ -171,7 +171,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
} }
netConn, err = dialer.Dial("tcp", addr) netConn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err) return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
} }
@ -188,19 +188,19 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
addr = u.Host + ":6667" addr = u.Host + ":6667"
} }
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host) dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
} }
logger.Printf("connecting to plain-text server at address %q", addr) logger.Printf("connecting to plain-text server at address %q", addr)
netConn, err = dialer.Dial("tcp", addr) netConn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err) return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
} }
case "irc+unix", "unix": case "irc+unix", "unix":
logger.Printf("connecting to Unix socket at path %q", u.Path) logger.Printf("connecting to Unix socket at path %q", u.Path)
netConn, err = dialer.Dial("unix", u.Path) netConn, err = dialer.DialContext(ctx, "unix", u.Path)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err) return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
} }

View file

@ -202,7 +202,7 @@ func (net *network) run() {
} }
lastTry = time.Now() lastTry = time.Now()
uc, err := connectToUpstream(net) uc, err := connectToUpstream(context.TODO(), net)
if err != nil { if err != nil {
net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
@ -1015,13 +1015,13 @@ func (u *user) hasPersistentMsgStore() bool {
// localAddrForHost returns the local address to use when connecting to host. // localAddrForHost returns the local address to use when connecting to host.
// A nil address is returned when the OS should automatically pick one. // A nil address is returned when the OS should automatically pick one.
func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) { func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
upstreamUserIPs := u.srv.Config().UpstreamUserIPs upstreamUserIPs := u.srv.Config().UpstreamUserIPs
if len(upstreamUserIPs) == 0 { if len(upstreamUserIPs) == 0 {
return nil, nil return nil, nil
} }
ips, err := net.LookupIP(host) ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil { if err != nil {
return nil, err return nil, err
} }