diff --git a/transport/http_test.go b/transport/http_test.go index e8553f1..b492a4b 100644 --- a/transport/http_test.go +++ b/transport/http_test.go @@ -286,4 +286,52 @@ func TestInvites(t *testing.T) { rr.Body.String(), expected) } }) + + t.Run("user can't accept unknown invite", func(t *testing.T) { + m := flow.New() + st.EXPECT().GetInvite(inv.Token).Return(store.Invite{}, sql.ErrNoRows) + handler := transport.New(st, im, mc) + + req, err := http.NewRequest("POST", "/api/v1/invite/foo/accept", nil) + if err != nil { + t.Fatal(err) + } + + ctx := req.Context() + ctx = context.WithValue(ctx, "user", user) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + m.HandleFunc("/api/v1/invite/:id/accept", handler.AcceptInvite, "POST") + m.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusNotFound { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusNotFound) + } + }) + + t.Run("user can't accept invite with no remaining uses", func(t *testing.T) { + m := flow.New() + st.EXPECT().GetInvite(inv.Token).Return(inv, nil) + im.EXPECT().RemainingUses(inv).Return(0, nil) + handler := transport.New(st, im, mc) + + req, err := http.NewRequest("POST", "/api/v1/invite/foo/accept", nil) + if err != nil { + t.Fatal(err) + } + + ctx := req.Context() + ctx = context.WithValue(ctx, "user", user) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + m.HandleFunc("/api/v1/invite/:id/accept", handler.AcceptInvite, "POST") + m.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusForbidden { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusForbidden) + } + }) + }