diff --git a/upstream.go b/upstream.go index 8dc71d5..dd3afce 100644 --- a/upstream.go +++ b/upstream.go @@ -153,6 +153,10 @@ type upstreamConn struct { // been sent yet. pendingCmds map[string][]pendingUpstreamCommand + pendingRegainNick string + regainNickTimer *time.Timer + regainNickBackoff *backoffer + gotMotd bool } @@ -869,6 +873,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err // Ignore the initial MOTD upon connection, but forward // subsequent MOTD messages downstream uc.gotMotd = true + + // If the server doesn't support MONITOR, periodically try to + // regain our desired nick + if _, ok := uc.isupport["MONITOR"]; !ok { + uc.startRegainNickTimer() + } + return nil } @@ -924,6 +935,11 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) me = true uc.nick = newNick + + if uc.network.equalCasemap(uc.pendingRegainNick, newNick) { + uc.pendingRegainNick = "" + uc.stopRegainNickTimer() + } } uc.channels.ForEach(func(ch *upstreamChannel) { @@ -1773,6 +1789,18 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err }) return nil } + + var failedNick string + if err := parseMessageParams(msg, nil, &failedNick); err != nil { + return err + } + if uc.network.equalCasemap(uc.pendingRegainNick, failedNick) { + // This message comes from our own logic to try to regain our + // desired nick, don't relay to downstream connections + uc.pendingRegainNick = "" + return nil + } + fallthrough case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP: if !uc.registered { @@ -2238,3 +2266,67 @@ func (uc *upstreamConn) updateMonitor() { uc.monitored.Del(target) } } + +func (uc *upstreamConn) stopRegainNickTimer() { + if uc.regainNickTimer != nil { + uc.regainNickTimer.Stop() + // Maybe we're racing with the timer goroutine, so maybe we'll receive + // an eventTryRegainNick later on, but tryRegainNick handles that case + } + uc.regainNickTimer = nil + uc.regainNickBackoff = nil +} + +func (uc *upstreamConn) startRegainNickTimer() { + if uc.regainNickBackoff != nil || uc.regainNickTimer != nil { + panic("startRegainNickTimer called twice") + } + + wantNick := database.GetNick(&uc.user.User, &uc.network.Network) + if uc.isOurNick(wantNick) { + return + } + + const ( + min = 15 * time.Second + max = 10 * time.Minute + jitter = 10 * time.Second + ) + uc.regainNickBackoff = newBackoffer(min, max, jitter) + uc.regainNickTimer = time.AfterFunc(uc.regainNickBackoff.Next(), func() { + e := eventTryRegainNick{uc: uc, nick: wantNick} + select { + case uc.network.user.events <- e: + // ok + default: + uc.logger.Printf("skipping nick regain attempt: event queue is full") + } + }) +} + +func (uc *upstreamConn) tryRegainNick(nick string) { + ctx := context.TODO() + + if uc.regainNickTimer == nil { + return + } + + // Maybe the user has updated their desired nick + wantNick := database.GetNick(&uc.user.User, &uc.network.Network) + if wantNick != nick || uc.isOurNick(wantNick) { + uc.stopRegainNickTimer() + return + } + + uc.regainNickTimer.Reset(uc.regainNickBackoff.Next()) + + if uc.pendingRegainNick != "" { + return + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{wantNick}, + }) + uc.pendingRegainNick = wantNick +} diff --git a/user.go b/user.go index 7e22c45..8c072f5 100644 --- a/user.go +++ b/user.go @@ -75,6 +75,11 @@ type eventUserUpdate struct { done chan error } +type eventTryRegainNick struct { + uc *upstreamConn + nick string +} + type deliveredClientMap map[string]string // client name -> msg ID type deliveredStore struct { @@ -755,6 +760,8 @@ func (u *user) run() { dc.Close() } } + case eventTryRegainNick: + e.uc.tryRegainNick(e.nick) case eventStop: for _, dc := range u.downstreamConns { dc.Close() @@ -776,6 +783,7 @@ func (u *user) run() { func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { uc.network.conn = nil + uc.stopRegainNickTimer() uc.abortPendingCommands() uc.channels.ForEach(func(uch *upstreamChannel) {