211 lines
5.1 KiB
Go
211 lines
5.1 KiB
Go
|
package fileupload
|
||
|
|
||
|
import (
|
||
|
"crypto/rand"
|
||
|
"encoding/hex"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"mime"
|
||
|
"net/http"
|
||
|
"path"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"git.sr.ht/~emersion/soju/auth"
|
||
|
"git.sr.ht/~emersion/soju/database"
|
||
|
)
|
||
|
|
||
|
const maxSize = 50 * 1024 * 1024 // 50 MiB
|
||
|
|
||
|
type Uploader interface {
|
||
|
load(filename string) (basename string, modTime time.Time, content io.ReadSeekCloser, err error)
|
||
|
store(r io.Reader, username, mimeType, basename string) (outFilename string, err error)
|
||
|
}
|
||
|
|
||
|
func New(driver, source string) (Uploader, error) {
|
||
|
switch driver {
|
||
|
case "fs":
|
||
|
return &fs{source}, nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("unknown file upload driver %q", driver)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type Handler struct {
|
||
|
Uploader Uploader
|
||
|
Auth auth.Authenticator
|
||
|
DB database.Database
|
||
|
}
|
||
|
|
||
|
func (h *Handler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||
|
resp.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';")
|
||
|
|
||
|
if h.Uploader == nil {
|
||
|
http.NotFound(resp, req)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
switch req.Method {
|
||
|
case http.MethodOptions:
|
||
|
resp.WriteHeader(http.StatusNoContent)
|
||
|
case http.MethodHead, http.MethodGet:
|
||
|
h.fetch(resp, req)
|
||
|
case http.MethodPost:
|
||
|
h.store(resp, req)
|
||
|
default:
|
||
|
http.Error(resp, "only OPTIONS, HEAD, GET and POST are allowed", http.StatusMethodNotAllowed)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *Handler) fetch(resp http.ResponseWriter, req *http.Request) {
|
||
|
prefix := "/uploads/"
|
||
|
if !strings.HasPrefix(req.URL.Path, prefix) {
|
||
|
http.Error(resp, "invalid path", http.StatusNotFound)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
filename := strings.TrimPrefix(req.URL.Path, prefix)
|
||
|
filename = path.Join("/", filename)[1:] // prevent directory traversal
|
||
|
if filename == "" {
|
||
|
http.Error(resp, "invalid path", http.StatusNotFound)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
basename, modTime, content, err := h.Uploader.load(filename)
|
||
|
if err != nil {
|
||
|
http.Error(resp, "failed to open file", http.StatusNotFound)
|
||
|
return
|
||
|
}
|
||
|
defer content.Close()
|
||
|
|
||
|
contentDisp := mime.FormatMediaType("attachment", map[string]string{
|
||
|
"filename": basename,
|
||
|
})
|
||
|
resp.Header().Set("Content-Disposition", contentDisp)
|
||
|
|
||
|
http.ServeContent(resp, req, basename, modTime, content)
|
||
|
}
|
||
|
|
||
|
func (h *Handler) store(resp http.ResponseWriter, req *http.Request) {
|
||
|
if req.URL.Path != "/uploads" {
|
||
|
http.Error(resp, "invalid path", http.StatusNotFound)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
authz := req.Header.Get("Authorization")
|
||
|
if authz == "" {
|
||
|
http.Error(resp, "missing Authorization header", http.StatusUnauthorized)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
username string
|
||
|
err error
|
||
|
)
|
||
|
scheme, param, _ := strings.Cut(authz, " ")
|
||
|
switch strings.ToLower(scheme) {
|
||
|
case "basic":
|
||
|
plainAuth, ok := h.Auth.(auth.PlainAuthenticator)
|
||
|
if !ok {
|
||
|
http.Error(resp, "Basic scheme in Authorization header not supported", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
var password string
|
||
|
username, password, ok = req.BasicAuth()
|
||
|
if !ok {
|
||
|
http.Error(resp, "invalid Authorization header", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
err = plainAuth.AuthPlain(req.Context(), h.DB, username, password)
|
||
|
case "bearer":
|
||
|
oauthAuth, ok := h.Auth.(auth.OAuthBearerAuthenticator)
|
||
|
if !ok {
|
||
|
http.Error(resp, "Bearer scheme in Authorization header not supported", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
username, err = oauthAuth.AuthOAuthBearer(req.Context(), h.DB, param)
|
||
|
default:
|
||
|
http.Error(resp, "unsupported Authorization header scheme", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
if err != nil {
|
||
|
var msg string
|
||
|
if authErr, ok := err.(*auth.Error); ok {
|
||
|
msg = authErr.ExternalMsg
|
||
|
} else {
|
||
|
msg = "authentication failed"
|
||
|
}
|
||
|
http.Error(resp, msg, http.StatusForbidden)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var mimeType string
|
||
|
if contentType := req.Header.Get("Content-Type"); contentType != "" {
|
||
|
var (
|
||
|
params map[string]string
|
||
|
err error
|
||
|
)
|
||
|
mimeType, params, err = mime.ParseMediaType(contentType)
|
||
|
if err != nil {
|
||
|
http.Error(resp, "failed to parse Content-Type", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
if mimeType == "application/octet-stream" {
|
||
|
mimeType = ""
|
||
|
}
|
||
|
|
||
|
switch strings.ToLower(params["charset"]) {
|
||
|
case "", "utf-8", "us-ascii":
|
||
|
// OK
|
||
|
default:
|
||
|
http.Error(resp, "unsupported charset", http.StatusUnsupportedMediaType)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var basename string
|
||
|
if contentDisp := req.Header.Get("Content-Disposition"); contentDisp != "" {
|
||
|
_, params, err := mime.ParseMediaType(contentDisp)
|
||
|
if err != nil {
|
||
|
http.Error(resp, "failed to parse Content-Disposition", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
basename = path.Base(params["filename"])
|
||
|
}
|
||
|
|
||
|
r := &limitedReader{r: req.Body, n: maxSize}
|
||
|
outFilename, err := h.Uploader.store(r, username, mimeType, basename)
|
||
|
if err != nil {
|
||
|
http.Error(resp, "failed to write file", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
resp.Header().Set("Location", "/uploads/"+outFilename)
|
||
|
resp.WriteHeader(http.StatusCreated)
|
||
|
}
|
||
|
|
||
|
type limitedReader struct {
|
||
|
r io.Reader
|
||
|
n int64
|
||
|
}
|
||
|
|
||
|
func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||
|
if lr.n <= 0 {
|
||
|
return 0, fmt.Errorf("file too large")
|
||
|
}
|
||
|
if int64(len(p)) > lr.n {
|
||
|
p = p[0:lr.n]
|
||
|
}
|
||
|
n, err = lr.r.Read(p)
|
||
|
lr.n -= int64(n)
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
func generateToken(n int) (string, error) {
|
||
|
b := make([]byte, n)
|
||
|
if _, err := rand.Read(b); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return hex.EncodeToString(b), nil
|
||
|
}
|