diff --git a/Dockerfile b/Dockerfile index 06467cd..70a0aa3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM docker.io/golang:alpine as builder +FROM docker.io/golang:1.22-alpine as builder WORKDIR /build/wlm COPY go.mod go.sum ./ @@ -8,7 +8,7 @@ COPY . . RUN apk add --update gcc musl-dev WORKDIR /build/wlm -RUN go build -o wlm -ldflags "-s -w" cmd/wlm/*.go +RUN CGO_ENABLED=1 go build -o wlm -ldflags "-s -w" cmd/wlm/*.go FROM docker.io/alpine diff --git a/cmd/wlm/main.go b/cmd/wlm/main.go index f1da3b8..feea7a1 100644 --- a/cmd/wlm/main.go +++ b/cmd/wlm/main.go @@ -3,15 +3,24 @@ package main import ( "log" "net/http" + "os" "whitelistmanager/internal/invite" "whitelistmanager/internal/minecraft" "whitelistmanager/internal/store" "whitelistmanager/internal/transport" "github.com/alexedwards/flow" + "github.com/joho/godotenv" ) func main() { + if _, err := os.Stat(".env"); err == nil { + err := godotenv.Load() + if err != nil { + log.Fatal("Error loading .env file") + } + } + mux := flow.New() db, err := store.Open() if err != nil { @@ -47,7 +56,11 @@ func main() { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.Header.Set("Content-Type", "application/json") - w.Write([]byte(`{"duck": "quacks"}`)) + _, err = w.Write([]byte(`{"duck": "quacks"}`)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } }) log.Println("Http listening on 0.0.0.0:8080") diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..e6c5d2d --- /dev/null +++ b/flake.lock @@ -0,0 +1,60 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1716948383, + "narHash": "sha256-SzDKxseEcHR5KzPXLwsemyTR/kaM9whxeiJohbL04rs=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ad57eef4ef0659193044870c731987a6df5cf56b", + "type": "github" + }, + "original": { + "id": "nixpkgs", + "ref": "nixos-unstable", + "type": "indirect" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs", + "utils": "utils" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..d158d80 --- /dev/null +++ b/flake.nix @@ -0,0 +1,71 @@ +{ + description = "Discord-style invites for your Minecraft server"; + + inputs = { + nixpkgs.url = "nixpkgs/nixos-unstable"; + utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, utils }: + utils.lib.eachSystem [ + "x86_64-linux" + "aarch64-linux" + "x86_64-darwin" + "aarch64-darwin" + ] (system: + let + version = builtins.substring 0 8 self.lastModifiedDate; + pkgs = import nixpkgs { inherit system; }; + in { + packages = rec { + whitelistmanager = pkgs.buildGo122Module { + pname = "wlm"; + version = "0.1.0-${version}"; + go = pkgs.go; + src = ./.; + subPackages = "cmd/wlm"; + vendorHash = "sha256-Jww3hNOGpwXSCAvD2THmTlIVf4HL7FHITjjEUbcLRao="; + }; + + docker = pkgs.dockerTools.buildLayeredImage { + name = "git.gmem.ca/arch/whitelistmanager"; + tag = "latest"; + config.Cmd = [ "${whitelistmanager}/bin/wlm" ]; + contents = [ pkgs.cacert ]; + }; + + portable-service = let + web-service = pkgs.substituteAll { + name = "whitelistmanager.service"; + src = ./run/portable-service/whitelistmanager.service.in; + inherit whitelistmanager; + }; + in pkgs.portableService { + inherit (whitelistmanager) version; + pname = "whitelistmanager"; + description = "The whitelistmanager service"; + homepage = "https://git.gmem.ca/arch/whitelistmanager"; + units = [ web-service ]; + symlinks = [{ + object = "${pkgs.cacert}/etc/ssl"; + symlink = "/etc/ssl"; + }]; + }; + + default = docker; + }; + + apps.default = + utils.lib.mkApp { drv = self.packages.${system}.default; }; + + devShells.default = pkgs.mkShell { + buildInputs = with pkgs; [ + go + gopls + gotools + go-tools + sqlite-interactive + ]; + }; + }) // {}; +} diff --git a/go.mod b/go.mod index 95fd48e..e150525 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,15 @@ module whitelistmanager -go 1.18 +go 1.22 require ( github.com/Kelwing/mc-rcon v0.0.0-20220214194105-bec8dcbccc3f github.com/alexedwards/flow v0.0.0-20220607190737-c48a87f2b4c4 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 - github.com/mattn/go-sqlite3 v1.14.14 + github.com/joho/godotenv v1.5.1 + github.com/lib/pq v1.10.9 + github.com/mattn/go-sqlite3 v1.14.22 golang.org/x/oauth2 v0.0.0-20220630143837-2104d58473e0 ) diff --git a/go.sum b/go.sum index 3ec2260..7364bbb 100644 --- a/go.sum +++ b/go.sum @@ -10,10 +10,15 @@ github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= -github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 9398c50..eab4de8 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -7,7 +7,7 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "net/http" "os" "time" @@ -119,7 +119,7 @@ func xblTokenExchange(token *oauth2.Token, client *http.Client) (string, string, if err != nil { return "", "", err } - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { return "", "", err } @@ -162,7 +162,7 @@ func xstsTokenExchange(xblToken string, client *http.Client) (string, error) { if err != nil { return "", err } - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { return "", err } diff --git a/internal/invite/invite.go b/internal/invite/invite.go index 09b6244..2e33603 100644 --- a/internal/invite/invite.go +++ b/internal/invite/invite.go @@ -25,6 +25,9 @@ func NewManager(db store.Storer) *Manager { func (i *Manager) Create(in store.Invite, user store.User) (string, error) { server, err := i.store.GetServer(in.Server.Id) + if err != nil { + return "", err + } if server.Owner.Id != user.Id { return "", errors.New(NotOwnerofServer) } diff --git a/internal/store/database.go b/internal/store/database.go index 38574d6..1214ed7 100644 --- a/internal/store/database.go +++ b/internal/store/database.go @@ -6,9 +6,11 @@ import ( "errors" "log" "os" + "strings" "time" "github.com/google/uuid" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -90,36 +92,35 @@ type InviteLog struct { func Open() (*Store, error) { database := os.Getenv("WLM_DATABASE_PATH") + dbType := "sqlite3" if database == "" { database = "db.sqlite3" } + if strings.Contains(database, "postgresql://" ) { + dbType = "postgres" + } - if _, err := os.Stat(database); errors.Is(err, os.ErrNotExist) { + if _, err := os.Stat(database); errors.Is(err, os.ErrNotExist) && !strings.Contains(database, "postgresql://") { log.Printf("No database found at %s, creating", database) _, err := os.Create(database) if err != nil { return nil, err } - db, err := sql.Open("sqlite3", database) - if err != nil { - return nil, err - } - initialSetup, err := migrations.ReadFile("database.sql") - if err != nil { - return nil, err - } - _, err = db.Exec(string(initialSetup)) - if err != nil { - return nil, err - } log.Printf("Database created at %s", database) - db.Close() } - - db, err := sql.Open("sqlite3", database) + db, err := sql.Open(dbType, database) if err != nil { return nil, err } + initialSetup, err := migrations.ReadFile("database.sql") + if err != nil { + return nil, err + } + _, err = db.Exec(string(initialSetup)) + if err != nil { + return nil, err + } + return &Store{database: db}, nil } @@ -165,7 +166,7 @@ func (s *Store) GetInvite(token string) (Invite, error) { } func (s *Store) LogInviteUse(user User, invite Invite) error { - q, err := s.database.Prepare("INSERT INTO invite_log (entry_id, invite, user) VALUES ($1, $2, $3)") + q, err := s.database.Prepare("INSERT INTO invite_log (entry_id, invite, uid) VALUES ($1, $2, $3)") if err != nil { return err } @@ -255,7 +256,7 @@ func (s *Store) GetUser(uid string) (User, error) { if err != nil { return User{}, err } - user.TokenExpiry, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", tokenExpiry) + user.TokenExpiry, err = time.Parse("2006-01-02 15:04:05.999999999Z", tokenExpiry) if err != nil { return User{}, err } @@ -309,7 +310,7 @@ func (s *Store) SessionUser(token string) (User, error) { if err != nil { return User{}, err } - sess.Expiry, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", sessExpiry) + sess.Expiry, err = time.Parse("2006-01-02 15:04:05.999999999Z", sessExpiry) if err != nil { return User{}, err } diff --git a/internal/store/database.sql b/internal/store/database.sql index d5039a6..406abe2 100644 --- a/internal/store/database.sql +++ b/internal/store/database.sql @@ -1,54 +1,54 @@ -- noinspection SqlNoDataSourceInspectionForFile +CREATE TABLE IF NOT EXISTS servers ( + id TEXT NOT NULL, + name TEXT NOT NULL, + address TEXT NOT NULL, + rcon_address TEXT NOT NULL, + rcon_password TEXT NOT NULL, + owner TEXT NOT NULL, + PRIMARY KEY (id) +); -CREATE TABLE `invites` ( - `token` TEXT NOT NULL, - `creator` TEXT NOT NULL, - `server` TEXT NOT NULL, - `uses` INT NOT NULL, - `unlimited` INTEGER NOT NULL, - PRIMARY KEY (`token`), - FOREIGN KEY (`server`) REFERENCES servers(`id`), - FOREIGN KEY (`creator`) REFERENCES users(`id`) -) STRICT; +CREATE TABLE IF NOT EXISTS users ( + id TEXT NOT NULL, + token TEXT NOT NULL, + display_name TEXT NOT NULL, + refresh_token TEXT NOT NULL, + token_expiry TEXT NOT NULL, + PRIMARY KEY (id) +); -CREATE TABLE `servers` ( - `id` TEXT NOT NULL, - `name` TEXT NOT NULL, - `address` TEXT NOT NULL, - `rcon_address` TEXT NOT NULL, - `rcon_password` TEXT NOT NULL, - `owner` TEXT NOT NULL, - PRIMARY KEY (`id`) -) STRICT; +CREATE TABLE IF NOT EXISTS sessions ( + token TEXT NOT NULL, + uid TEXT NOT NULL, + expiry TEXT NOT NULL, + PRIMARY KEY (token), + FOREIGN KEY (uid) REFERENCES users(id) +); -CREATE TABLE `users` ( - `id` TEXT NOT NULL, - `token` TEXT NOT NULL, - `display_name` TEXT NOT NULL, - `refresh_token` TEXT NOT NULL, - `token_expiry` TEXT NOT NULL, - PRIMARY KEY (`id`) -) STRICT; +CREATE TABLE IF NOT EXISTS oauth_states ( + state_id TEXT NOT NULL, + origin TEXT NOT NULL, + state TEXT NOT NULL, + PRIMARY KEY (state_id) +); -CREATE TABLE `sessions` ( - `token` TEXT NOT NULL, - `uid` TEXT NOT NULL, - `expiry` TEXT NOT NULL, - PRIMARY KEY (`token`), - FOREIGN KEY (`uid`) REFERENCES users(`id`) -) STRICT; +CREATE TABLE IF NOT EXISTS invites ( + token TEXT NOT NULL, + creator TEXT NOT NULL, + server TEXT NOT NULL, + uses INT NOT NULL, + unlimited BOOLEAN NOT NULL, + PRIMARY KEY (token), + FOREIGN KEY (server) REFERENCES servers(id), + FOREIGN KEY (creator) REFERENCES users(id) +); -CREATE TABLE `invite_log` ( - `entry_id` TEXT NOT NULL, - `invite` TEXT NOT NULL, - `user` TEXT NOT NULL, - PRIMARY KEY (`entry_id`), - FOREIGN KEY (`invite`) REFERENCES invites(`token`) -) STRICT; - -CREATE TABLE `oauth_states` ( - `state_id` TEXT NOT NULL, - `origin` TEXT NOT NULL, - `state` TEXT NOT NULL, - PRIMARY KEY (`state_id`) -) STRICT; \ No newline at end of file +CREATE TABLE IF NOT EXISTS invite_log ( + entry_id TEXT NOT NULL, + invite TEXT NOT NULL, + uid TEXT NOT NULL, + PRIMARY KEY (entry_id), + FOREIGN KEY (invite) REFERENCES invites(token), + FOREIGN KEY (uid) REFERENCES users(id) +); diff --git a/internal/transport/http.go b/internal/transport/http.go index 9ed3e52..c7e2feb 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -19,6 +19,8 @@ import ( "github.com/google/uuid" ) +type CONTEXT_USER struct {} + type Handler struct { store store.Storer manager invite.InviteManager @@ -74,7 +76,7 @@ func (h *Handler) SessionAuth(next http.Handler) http.Handler { http.Error(w, err.Error(), http.StatusForbidden) return } - ctx := context.WithValue(r.Context(), "user", user) + ctx := context.WithValue(r.Context(), CONTEXT_USER{}, user) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) @@ -93,7 +95,7 @@ func (h *Handler) CreateInvite(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) in, err := h.manager.Create(store.Invite{ Server: store.Server{Id: i.Server}, Unlimited: i.Unlimited, @@ -108,7 +110,11 @@ func (h *Handler) CreateInvite(w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte(in)) + _, err = w.Write([]byte(in)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func (h *Handler) GetInvite(w http.ResponseWriter, r *http.Request) { @@ -158,7 +164,7 @@ func (h *Handler) AcceptInvite(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) log.Println(user.DisplayName) resp, err := h.mc.Whitelist(user.DisplayName, server) if err != nil { @@ -174,7 +180,11 @@ func (h *Handler) AcceptInvite(w http.ResponseWriter, r *http.Request) { } } - w.Write([]byte(resp)) + _, err = w.Write([]byte(resp)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func (h *Handler) AuthRedirect(w http.ResponseWriter, r *http.Request) { @@ -195,10 +205,13 @@ func (h *Handler) AuthCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query()["code"][0] stateId := r.URL.Query()["state"][0] state, err := h.store.OauthState(stateId) - if err != nil || state.State == "completed" { + if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } + if state.State == "completed" { + return + } user, err := auth.Authenticate(code) if err != nil { @@ -248,7 +261,7 @@ func generateSessionToken() (string, error) { } func (h *Handler) Servers(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) if r.Method == http.MethodGet { servers, err := h.store.GetUserServers(user) if err != nil { @@ -257,10 +270,18 @@ func (h *Handler) Servers(w http.ResponseWriter, r *http.Request) { } if len(servers) == 0 { w.WriteHeader(http.StatusNotFound) - w.Write([]byte("[]")) + _, err = w.Write([]byte("[]")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + return + } + err = json.NewEncoder(w).Encode(servers) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - json.NewEncoder(w).Encode(servers) return } @@ -282,7 +303,7 @@ func (h *Handler) Servers(w http.ResponseWriter, r *http.Request) { } func (h *Handler) Server(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) serverId := flow.Param(r.Context(), "id") if r.Method == http.MethodGet { server, err := h.store.GetServer(serverId) @@ -296,13 +317,17 @@ func (h *Handler) Server(w http.ResponseWriter, r *http.Request) { } server.Rcon.Password = "" - json.NewEncoder(w).Encode(server) + err = json.NewEncoder(w).Encode(server) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } return } if r.Method == http.MethodDelete { serverId := flow.Param(r.Context(), "id") - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) server, err := h.store.GetServer(serverId) if err != nil { http.Error(w, "no such server", http.StatusNotFound) @@ -320,23 +345,31 @@ func (h *Handler) Server(w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte("deleted")) + _, err = w.Write([]byte("deleted")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } } func (h *Handler) CurrentUser(w http.ResponseWriter, r *http.Request) { - value := r.Context().Value("user") + value := r.Context().Value(CONTEXT_USER{}) if value == nil { http.Error(w, "", http.StatusForbidden) return } user := value.(store.User) - w.Write([]byte(user.DisplayName)) + _, err := w.Write([]byte(user.DisplayName)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func (h *Handler) ServerInvites(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) serverId := flow.Param(r.Context(), "id") server, err := h.store.GetServer(serverId) if err != nil { @@ -360,16 +393,23 @@ func (h *Handler) ServerInvites(w http.ResponseWriter, r *http.Request) { if len(serverInvites) == 0 { w.WriteHeader(http.StatusNotFound) - w.Write([]byte("[]")) + _, err := w.Write([]byte("[]")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } return } - json.NewEncoder(w).Encode(serverInvites) - return + err = json.NewEncoder(w).Encode(serverInvites) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func (h *Handler) InviteLog(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) inviteToken := flow.Param(r.Context(), "id") invite, err := h.store.GetInvite(inviteToken) if err != nil { @@ -387,15 +427,23 @@ func (h *Handler) InviteLog(w http.ResponseWriter, r *http.Request) { } if len(logs) == 0 { w.WriteHeader(http.StatusNotFound) - w.Write([]byte("[]")) + _, err = w.Write([]byte("[]")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } return } - json.NewEncoder(w).Encode(logs) + err = json.NewEncoder(w).Encode(logs) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func (h *Handler) DeleteInvite(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(store.User) + user := r.Context().Value(CONTEXT_USER{}).(store.User) inviteToken := flow.Param(r.Context(), "id") invite, err := h.store.GetInvite(inviteToken) if err != nil { @@ -412,6 +460,10 @@ func (h *Handler) DeleteInvite(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - w.Write([]byte("deleted")) + _, err = w.Write([]byte("deleted")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go index 23c933e..47b461f 100644 --- a/internal/transport/http_test.go +++ b/internal/transport/http_test.go @@ -98,7 +98,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() h := http.HandlerFunc(handler.CreateInvite) @@ -133,7 +133,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() h := http.HandlerFunc(handler.CreateInvite) @@ -164,7 +164,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id", handler.GetInvite) @@ -197,7 +197,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id", handler.GetInvite) @@ -230,7 +230,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id", handler.GetInvite) @@ -285,7 +285,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id/accept", handler.AcceptInvite, "POST") @@ -314,7 +314,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id/accept", handler.AcceptInvite, "POST") @@ -338,7 +338,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id/accept", handler.AcceptInvite, "POST") @@ -362,7 +362,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id/log", handler.InviteLog) @@ -397,7 +397,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id/log", handler.InviteLog) @@ -434,7 +434,7 @@ func TestInvites(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/invite/:id", handler.DeleteInvite) @@ -476,7 +476,7 @@ func TestUser(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/me", handler.CurrentUser) @@ -532,7 +532,7 @@ func TestServers(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/servers", handler.Servers) @@ -562,7 +562,7 @@ func TestServers(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/server/:id", handler.Server) @@ -605,7 +605,7 @@ func TestServers(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/servers", handler.Servers) @@ -640,7 +640,7 @@ func TestServers(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/server/:id", handler.Server, "DELETE") @@ -675,7 +675,7 @@ func TestServers(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/server/:id", handler.ServerInvites) @@ -710,7 +710,7 @@ func TestServers(t *testing.T) { } ctx := req.Context() - ctx = context.WithValue(ctx, "user", user) + ctx = context.WithValue(ctx, transport.CONTEXT_USER{}, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() m.HandleFunc("/api/v1/server/:id/invites", handler.ServerInvites) @@ -761,7 +761,10 @@ func TestMiddlewares(t *testing.T) { rr := httptest.NewRecorder() m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.Header.Set("Content-Type", "application/json") - w.Write([]byte(`{"duck": "quacks"}`)) + _, err = w.Write([]byte(`{"duck": "quacks"}`)) + if err != nil { + t.Fatal(err) + } }) m.ServeHTTP(rr, req) @@ -793,7 +796,10 @@ func TestMiddlewares(t *testing.T) { rr := httptest.NewRecorder() m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.Header.Set("Content-Type", "application/json") - w.Write([]byte(`{"duck": "quacks"}`)) + _, err = w.Write([]byte(`{"duck": "quacks"}`)) + if err != nil { + t.Fatal(err) + } }) m.ServeHTTP(rr, req) @@ -822,7 +828,10 @@ func TestMiddlewares(t *testing.T) { rr := httptest.NewRecorder() m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.Header.Set("Content-Type", "application/json") - w.Write([]byte(`{"duck": "quacks"}`)) + _, err = w.Write([]byte(`{"duck": "quacks"}`)) + if err != nil { + t.Fatal(err) + } }) m.ServeHTTP(rr, req) @@ -850,7 +859,10 @@ func TestMiddlewares(t *testing.T) { rr := httptest.NewRecorder() m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.Header.Set("Content-Type", "application/json") - w.Write([]byte(`{"duck": "quacks"}`)) + _, err = w.Write([]byte(`{"duck": "quacks"}`)) + if err != nil { + t.Fatal(err) + } }) m.ServeHTTP(rr, req)