diff --git a/downstream.go b/downstream.go index cce1305..54772df 100644 --- a/downstream.go +++ b/downstream.go @@ -614,8 +614,8 @@ func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Me return msg } -func (dc *downstreamConn) handleMessage(msg *irc.Message) error { - ctx, cancel := dc.conn.NewContext(context.TODO()) +func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + ctx, cancel := dc.conn.NewContext(ctx) defer cancel() ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) @@ -1401,13 +1401,29 @@ func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) { } func (dc *downstreamConn) runUntilRegistered() error { + ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout) + defer cancel() + + // Close the connection with an error if the deadline is exceeded + go func() { + <-ctx.Done() + if err := ctx.Err(); err == context.DeadlineExceeded { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "ERROR", + Params: []string{"Connection registration timed out"}, + }) + dc.Close() + } + }() + for !dc.registered { msg, err := dc.ReadMessage() if err != nil { return fmt.Errorf("failed to read IRC command: %w", err) } - err = dc.handleMessage(msg) + err = dc.handleMessage(ctx, msg) if ircErr, ok := err.(ircError); ok { ircErr.Message.Prefix = dc.srv.prefix() dc.SendMessage(ircErr.Message) diff --git a/server.go b/server.go index 1a83208..b125def 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ var upstreamMessageDelay = 2 * time.Second var upstreamMessageBurst = 10 var backlogTimeout = 10 * time.Second var handleDownstreamMessageTimeout = 10 * time.Second +var downstreamRegisterTimeout = 30 * time.Second var chatHistoryLimit = 1000 var backlogLimit = 4000 diff --git a/user.go b/user.go index 3422300..b573ff1 100644 --- a/user.go +++ b/user.go @@ -637,7 +637,7 @@ func (u *user) run() { dc.logger.Printf("ignoring message on closed connection: %v", msg) break } - err := dc.handleMessage(msg) + err := dc.handleMessage(context.TODO(), msg) if ircErr, ok := err.(ircError); ok { ircErr.Message.Prefix = dc.srv.prefix() dc.SendMessage(ircErr.Message)