package transport_test

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/alexedwards/flow"
	"github.com/golang/mock/gomock"

	"whitelistmanager/invite"
	mock_invite "whitelistmanager/mocks/invite"
	mock_store "whitelistmanager/mocks/store"
	"whitelistmanager/store"
	"whitelistmanager/transport"
)

// invitePayload is the JSON request body these tests send to the
// create-invite endpoint.
type invitePayload struct {
	Server    string `json:"server"`
	Unlimited bool   `json:"unlimited"`
	Uses      int    `json:"uses"`
}

// TestInvites exercises the invite HTTP handlers returned by transport.New
// against mocked store and invite-manager dependencies.
func TestInvites(t *testing.T) {
	user := store.User{
		Id:           "1",
		Token:        "",
		DisplayName:  "user",
		RefreshToken: "",
		TokenExpiry:  time.Time{},
	}
	server := store.Server{Id: "1"}
	inv := store.Invite{
		Token:     "foo",
		Server:    server,
		Uses:      0,
		Unlimited: false,
	}
	// invNoToken is the invite CreateInvite is expected to pass to the invite
	// manager — presumably the decoded payload before a token is assigned.
	invNoToken := store.Invite{
		Token:     "",
		Server:    store.Server{Id: "1"},
		Uses:      0,
		Unlimited: false,
	}

	payload, err := json.Marshal(invitePayload{Server: "1", Uses: 0, Unlimited: false})
	if err != nil {
		t.Fatal(err)
	}

	// GetInvite responses are compared with the invite's JSON encoding plus a
	// trailing newline (presumably written via json.Encoder, which appends one).
	invJSON, err := json.Marshal(inv)
	if err != nil {
		t.Fatal(err)
	}
	wantInvBody := string(invJSON) + "\n"

	// newMocks creates mocks on a controller bound to the calling subtest's t.
	// Sharing one controller bound to the parent t would make an unexpected
	// call invoke the parent's Fatalf from the subtest goroutine, which the
	// testing package rejects, and would hide which subtest misbehaved.
	newMocks := func(t *testing.T) (*mock_store.MockStorer, *mock_invite.MockInviteManager) {
		ctrl := gomock.NewController(t)
		return mock_store.NewMockStorer(ctrl), mock_invite.NewMockInviteManager(ctrl)
	}

	// newRequest builds a request with the given body (nil for none) and, when
	// u is non-nil, stores *u in its context under the "user" key.
	// NOTE(review): the plain string key must match the key the transport
	// package reads the logged-in user from; staticcheck SA1029 flags it.
	newRequest := func(t *testing.T, method, target string, body io.Reader, u *store.User) *http.Request {
		t.Helper()
		req, err := http.NewRequest(method, target, body)
		if err != nil {
			t.Fatal(err)
		}
		if u != nil {
			req = req.WithContext(context.WithValue(req.Context(), "user", *u))
		}
		return req
	}

	// serve runs req through h and returns the recorded response.
	serve := func(h http.Handler, req *http.Request) *httptest.ResponseRecorder {
		rr := httptest.NewRecorder()
		h.ServeHTTP(rr, req)
		return rr
	}

	// inviteRoute mounts h on a fresh flow mux at /api/v1/invite/:id so that
	// flow populates the :id route parameter before h runs.
	inviteRoute := func(h http.HandlerFunc) http.Handler {
		m := flow.New()
		m.HandleFunc("/api/v1/invite/:id", h)
		return m
	}

	checkStatus := func(t *testing.T, rr *httptest.ResponseRecorder, want int) {
		t.Helper()
		if rr.Code != want {
			t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, want)
		}
	}

	checkBody := func(t *testing.T, rr *httptest.ResponseRecorder, want string) {
		t.Helper()
		if got := rr.Body.String(); got != want {
			t.Errorf("handler returned unexpected body: got %q want %q", got, want)
		}
	}

	t.Run("create invite correctly", func(t *testing.T) {
		st, im := newMocks(t)
		im.EXPECT().Create(invNoToken, user).Return(inv.Token, nil)
		handler := transport.New(st, im)

		req := newRequest(t, http.MethodPost, "/api/v1/invites", bytes.NewReader(payload), &user)
		rr := serve(http.HandlerFunc(handler.CreateInvite), req)

		checkStatus(t, rr, http.StatusOK)
		checkBody(t, rr, inv.Token)
	})

	t.Run("user is not server owner", func(t *testing.T) {
		st, im := newMocks(t)
		im.EXPECT().Create(invNoToken, user).Return("", errors.New(invite.NotOwnerofServer))
		handler := transport.New(st, im)

		req := newRequest(t, http.MethodPost, "/api/v1/invites", bytes.NewReader(payload), &user)
		rr := serve(http.HandlerFunc(handler.CreateInvite), req)

		checkStatus(t, rr, http.StatusForbidden)
		checkBody(t, rr, "user is not owner of server\n")
	})

	t.Run("get existing invite", func(t *testing.T) {
		st, im := newMocks(t)
		st.EXPECT().GetInvite(inv.Token).Return(inv, nil)
		handler := transport.New(st, im)

		req := newRequest(t, http.MethodGet, "/api/v1/invite/foo", nil, &user)
		rr := serve(inviteRoute(handler.GetInvite), req)

		checkStatus(t, rr, http.StatusOK)
		checkBody(t, rr, wantInvBody)
	})

	t.Run("non existent invite", func(t *testing.T) {
		st, im := newMocks(t)
		st.EXPECT().GetInvite(inv.Token).Return(store.Invite{}, sql.ErrNoRows)
		handler := transport.New(st, im)

		req := newRequest(t, http.MethodGet, "/api/v1/invite/foo", nil, &user)
		rr := serve(inviteRoute(handler.GetInvite), req)

		checkStatus(t, rr, http.StatusNotFound)
	})

	t.Run("user not logged in when getting invite", func(t *testing.T) {
		st, im := newMocks(t)
		st.EXPECT().GetInvite(inv.Token).Return(inv, nil)
		handler := transport.New(st, im)

		// No user in the context: fetching an invite must still succeed.
		req := newRequest(t, http.MethodGet, "/api/v1/invite/foo", nil, nil)
		rr := serve(inviteRoute(handler.GetInvite), req)

		checkStatus(t, rr, http.StatusOK)
		checkBody(t, rr, wantInvBody)
	})

	// TODO: re-enable once AcceptInvite is ready to test. The expectations
	// below are carried over from the earlier draft and still need review
	// (it asserts StatusForbidden and an invite JSON body on acceptance).
	//
	// t.Run("user can accept invite", func(t *testing.T) {
	// 	st, im := newMocks(t)
	// 	st.EXPECT().GetInvite(inv.Token).Return(inv, nil)
	// 	im.EXPECT().RemainingUses(inv).Return(1, nil)
	// 	st.EXPECT().GetServer(server.Id).Return(server, nil)
	// 	handler := transport.New(st, im)
	//
	// 	m := flow.New()
	// 	m.HandleFunc("/api/v1/invite/:id/accept", handler.AcceptInvite, "POST")
	// 	req := newRequest(t, http.MethodPost, "/api/v1/invite/foo/accept", nil, &user)
	// 	rr := serve(m, req)
	//
	// 	checkStatus(t, rr, http.StatusForbidden)
	// 	checkBody(t, rr, wantInvBody)
	// })
}