From 6b89b9978443a6b669599b7692c1e776b9ec1ac0 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Fri, 9 Feb 2024 14:10:19 +0000 Subject: [PATCH] fileupload: add CORS header fields --- cmd/soju/main.go | 7 +++--- fileupload/fileupload.go | 48 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 7dfed44..04de78b 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -159,9 +159,10 @@ func main() { fileUploadHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cfg := srv.Config() h := fileupload.Handler{ - Uploader: cfg.FileUploader, - DB: db, - Auth: cfg.Auth, + Uploader: cfg.FileUploader, + DB: db, + Auth: cfg.Auth, + HTTPOrigins: cfg.HTTPOrigins, } h.ServeHTTP(w, r) }) diff --git a/fileupload/fileupload.go b/fileupload/fileupload.go index cceb4a5..5c72608 100644 --- a/fileupload/fileupload.go +++ b/fileupload/fileupload.go @@ -7,6 +7,7 @@ import ( "io" "mime" "net/http" + "net/url" "path" "strings" "time" @@ -53,14 +54,55 @@ func New(driver, source string) (Uploader, error) { } type Handler struct { - Uploader Uploader - Auth auth.Authenticator - DB database.Database + Uploader Uploader + Auth auth.Authenticator + DB database.Database + HTTPOrigins []string +} + +func (h *Handler) checkOrigin(reqOrigin string) bool { + for _, origin := range h.HTTPOrigins { + match, err := path.Match(origin, reqOrigin) + if err != nil { + panic(err) // patterns are checked at config load time + } else if match { + return true + } + } + return false +} + +func (h *Handler) setCORS(resp http.ResponseWriter, req *http.Request) error { + resp.Header().Set("Access-Control-Allow-Credentials", "true") + resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Content-Disposition") + resp.Header().Set("Access-Control-Expose-Headers", "Location, Content-Disposition") + + reqOrigin := req.Header.Get("Origin") + if reqOrigin == "" { + return nil + } + u, err := url.Parse(reqOrigin) + if err != nil { + return fmt.Errorf("invalid Origin header field: %v", err) + } + + if !strings.EqualFold(u.Host, req.Host) && !h.checkOrigin(reqOrigin) { + return fmt.Errorf("unauthorized Origin") + } + + resp.Header().Set("Access-Control-Allow-Origin", reqOrigin) + resp.Header().Set("Vary", "Origin") + return nil } func (h *Handler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { resp.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';") + if err := h.setCORS(resp, req); err != nil { + http.Error(resp, err.Error(), http.StatusForbidden) + return + } + if h.Uploader == nil { http.NotFound(resp, req) return