Add context to connectToUpstream

This commit is contained in:
Simon Ser 2021-12-02 10:53:43 +01:00
parent 33a639ecf0
commit 73287f242e
2 changed files with 9 additions and 9 deletions

View file

@ -123,7 +123,7 @@ type upstreamConn struct {
gotMotd bool
}
func connectToUpstream(network *network) (*upstreamConn, error) {
func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
dialer := net.Dialer{Timeout: connectTimeout}
@ -143,7 +143,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
addr = u.Host + ":6697"
}
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host)
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
}
@ -171,7 +171,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
}
netConn, err = dialer.Dial("tcp", addr)
netConn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
@ -188,19 +188,19 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
addr = u.Host + ":6667"
}
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host)
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
}
logger.Printf("connecting to plain-text server at address %q", addr)
netConn, err = dialer.Dial("tcp", addr)
netConn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
case "irc+unix", "unix":
logger.Printf("connecting to Unix socket at path %q", u.Path)
netConn, err = dialer.Dial("unix", u.Path)
netConn, err = dialer.DialContext(ctx, "unix", u.Path)
if err != nil {
return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
}

View file

@ -202,7 +202,7 @@ func (net *network) run() {
}
lastTry = time.Now()
uc, err := connectToUpstream(net)
uc, err := connectToUpstream(context.TODO(), net)
if err != nil {
net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
@ -1015,13 +1015,13 @@ func (u *user) hasPersistentMsgStore() bool {
// localAddrForHost returns the local address to use when connecting to host.
// A nil address is returned when the OS should automatically pick one.
func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) {
func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
upstreamUserIPs := u.srv.Config().UpstreamUserIPs
if len(upstreamUserIPs) == 0 {
return nil, nil
}
ips, err := net.LookupIP(host)
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
return nil, err
}