fileupload: add CORS header fields

This commit is contained in:
Alex McGrath 2024-02-09 14:10:19 +00:00 committed by Simon Ser
parent 2a78536eb9
commit 6b89b99784
2 changed files with 49 additions and 6 deletions

View file

@ -162,6 +162,7 @@ func main() {
Uploader: cfg.FileUploader, Uploader: cfg.FileUploader,
DB: db, DB: db,
Auth: cfg.Auth, Auth: cfg.Auth,
HTTPOrigins: cfg.HTTPOrigins,
} }
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
}) })

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"mime" "mime"
"net/http" "net/http"
"net/url"
"path" "path"
"strings" "strings"
"time" "time"
@ -56,11 +57,52 @@ type Handler struct {
Uploader Uploader Uploader Uploader
Auth auth.Authenticator Auth auth.Authenticator
DB database.Database DB database.Database
HTTPOrigins []string
}
func (h *Handler) checkOrigin(reqOrigin string) bool {
for _, origin := range h.HTTPOrigins {
match, err := path.Match(origin, reqOrigin)
if err != nil {
panic(err) // patterns are checked at config load time
} else if match {
return true
}
}
return false
}
func (h *Handler) setCORS(resp http.ResponseWriter, req *http.Request) error {
resp.Header().Set("Access-Control-Allow-Credentials", "true")
resp.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Content-Disposition")
resp.Header().Set("Access-Control-Expose-Headers", "Location, Content-Disposition")
reqOrigin := req.Header.Get("Origin")
if reqOrigin == "" {
return nil
}
u, err := url.Parse(reqOrigin)
if err != nil {
return fmt.Errorf("invalid Origin header field: %v", err)
}
if !strings.EqualFold(u.Host, req.Host) && !h.checkOrigin(reqOrigin) {
return fmt.Errorf("unauthorized Origin")
}
resp.Header().Set("Access-Control-Allow-Origin", reqOrigin)
resp.Header().Set("Vary", "Origin")
return nil
} }
func (h *Handler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { func (h *Handler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
resp.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';") resp.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';")
if err := h.setCORS(resp, req); err != nil {
http.Error(resp, err.Error(), http.StatusForbidden)
return
}
if h.Uploader == nil { if h.Uploader == nil {
http.NotFound(resp, req) http.NotFound(resp, req)
return return