diff --git a/downstream.go b/downstream.go index 12dcaff..6302a5c 100644 --- a/downstream.go +++ b/downstream.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "net" - "net/http" "strconv" "strings" "time" @@ -3060,14 +3059,6 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } - if err := checkWebPushEndpoint(ctx, endpoint); err != nil { - dc.logger.Printf("failed to check Web push endpoint %q: %v", endpoint, err) - return ircError{&irc.Message{ - Command: "FAIL", - Params: []string{"WEBPUSH", "INVALID_PARAMS", subcommand, "Invalid endpoint"}, - }} - } - rawKeys := irc.ParseTags(keysStr) authKey, hasAuthKey := rawKeys["auth"] p256dhKey, hasP256dh := rawKeys["p256dh"] @@ -3112,16 +3103,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. networkID = dc.network.ID } - // TODO: limit max number of subscriptions, prune old ones - - if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.user.ID, networkID, &newSub); err != nil { - dc.logger.Printf("failed to store Web push subscription: %v", err) - return ircError{&irc.Message{ - Command: "FAIL", - Params: []string{"WEBPUSH", "INTERNAL_ERROR", subcommand, "Internal error"}, - }} - } - + // Send a test Web Push message, to make sure the endpoint is valid err = dc.srv.sendWebPush(ctx, &webpush.Subscription{ Endpoint: newSub.Endpoint, Keys: webpush.Keys{ @@ -3134,6 +3116,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) if err != nil { dc.logger.Printf("failed to send Web push notification to endpoint %q: %v", newSub.Endpoint, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"WEBPUSH", "INVALID_PARAMS", subcommand, "Invalid endpoint"}, + }} + } + + // TODO: limit max number of subscriptions, prune old ones + + if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.user.ID, networkID, &newSub); err != nil { + dc.logger.Printf("failed to store Web push subscription: %v", err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"WEBPUSH", "INTERNAL_ERROR", subcommand, "Internal error"}, + }} } dc.SendMessage(&irc.Message{ @@ -3336,35 +3332,3 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) { dc.SendMessage(msg) } } - -func checkWebPushEndpoint(ctx context.Context, endpoint string) error { - req, err := http.NewRequestWithContext(ctx, http.MethodOptions, endpoint, nil) - if err != nil { - return fmt.Errorf("failed to create HTTP request: %v", err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("HTTP request failed: %v", err) - } - resp.Body.Close() - - if resp.StatusCode/100 != 2 { - return fmt.Errorf("HTTP request failed: %v", resp.Status) - } - - allow := strings.Split(resp.Header.Get("Allow"), ",") - found := false - for _, method := range allow { - if strings.EqualFold(strings.TrimSpace(method), http.MethodPost) { - found = true - break - } - } - - if !found { - return fmt.Errorf("POST missing from Allow header in OPTIONS response") - } - - return nil -}