diff --git a/auth/auth.go b/auth/auth.go deleted file mode 100644 index c178d6e..0000000 --- a/auth/auth.go +++ /dev/null @@ -1,277 +0,0 @@ -package auth - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "database/sql" - "encoding/base64" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "log" - "net/http" - "os" - "strings" - - _ "github.com/mattn/go-sqlite3" - - "github.com/gmemstr/nas/common" -) - -const ( - enc = "cookie_session_encryption" - - // This is the key with which each cookie is encrypted, I'll recommend moving it to a env file - cookieName = "NAS_SESSION" - cookieExpiry = 60 * 60 * 24 * 30 // 30 days in seconds -) - -func UserPermissions(username string, permission int) (bool, error) { - - db, err := sql.Open("sqlite3", "assets/config/users.db") - defer db.Close() - isAllowed := false - if err != nil { - return isAllowed, err - } - - statement, err := db.Prepare("SELECT permissions FROM users WHERE username=?") - if err != nil { - return isAllowed, err - } - - rows, err := statement.Query(username) - if err != nil { - return isAllowed, err - } - - var level int - for rows.Next() { - err = rows.Scan(&level) - if err != nil { - return isAllowed, err - } - if level >= permission { - isAllowed = true - } - } - return isAllowed, nil -} - -func RequireAuthorization(permission int) common.Handler { - return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError { - usr, err := DecryptCookie(r) - if err != nil { - fmt.Println(err.Error()) - if strings.Contains(r.Header.Get("Accept"), "html") || r.Method == "GET" { - http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) - return &common.HTTPError{ - Message: "Unauthorized! Redirecting to /login", - StatusCode: http.StatusTemporaryRedirect, - } - } - return &common.HTTPError{ - Message: "Unauthorized!", - StatusCode: http.StatusUnauthorized, - } - } - - rc.User = usr - - username := rc.User.Username - - hasPermission, err := UserPermissions(string(username), permission) - - if !hasPermission { - return &common.HTTPError{ - Message: "Unauthorized! Redirecting to /admin", - StatusCode: http.StatusUnauthorized, - } - } - return nil - } -} - -func CreateSession(u *common.User) (*http.Cookie, error) { - secret := os.Getenv("POGO_SECRET") - - iv, err := generateRandomString(16) - if err != nil { - return nil, err - } - userJSON, err := json.Marshal(u) - if err != nil { - return nil, err - } - - hexedJSON := hex.EncodeToString(userJSON) - - encKey := deriveKey(enc, secret) - - block, err := aes.NewCipher(encKey) - if err != nil { - return nil, err - } - - // Fill the block with 0x0e - if remBytes := len(hexedJSON) % aes.BlockSize; remBytes != 0 { - t := []byte(hexedJSON) - - for i := 0; i < aes.BlockSize-remBytes; i++ { - t = append(t, 0x0e) - } - hexedJSON = string(t) - } - - mode := cipher.NewCBCEncrypter(block, iv) - encCipher := make([]byte, len(hexedJSON)+aes.BlockSize) - - mode.CryptBlocks(encCipher, []byte(hexedJSON)) - - cipherbase64 := base64urlencode(encCipher) - ivbase64 := base64urlencode(iv) - - // Cookie format: iv.cipher.created_on.expire_on.HMAC - cookieStr := fmt.Sprintf("%s.%s", ivbase64, cipherbase64) - - c := &http.Cookie{ - Name: cookieName, - Value: cookieStr, - MaxAge: cookieExpiry, - } - - // Insert token into database. - db, err := sql.Open("sqlite3", "assets/config/users.db") - defer db.Close() - if err != nil { - return nil, err - } - - statement, err := db.Prepare("UPDATE users SET token=? WHERE username=?") - if err != nil { - return nil, err - } - - _, err = statement.Exec(cookieStr, u.Username) - if err != nil { - return nil, err - } - - return c, nil -} - -func DecryptCookie(r *http.Request) (*common.User, error) { - secret := os.Getenv("POGO_SECRET") - - c, err := r.Cookie(cookieName) - if err != nil { - if err != http.ErrNoCookie { - log.Printf("error in reading Cookie: %v", err) - } - return nil, err - } - - csplit := strings.Split(c.Value, ".") - if len(csplit) != 2 { - return nil, errors.New("Invalid number of values in cookie") - } - - ivb, cipherb := csplit[0], csplit[1] - - iv, err := base64urldecode(ivb) - if err != nil { - return nil, err - } - dcipher, err := base64urldecode(cipherb) - if err != nil { - return nil, err - } - - if len(iv) != 16 { - return nil, errors.New("IV length is not 16") - } - - encKey := deriveKey(enc, secret) - - if len(dcipher)%aes.BlockSize != 0 { - return nil, errors.New("ciphertext not multiple of blocksize") - } - - block, err := aes.NewCipher(encKey) - if err != nil { - return nil, err - } - buf := make([]byte, len(dcipher)) - - mode := cipher.NewCBCDecrypter(block, iv) - - mode.CryptBlocks(buf, []byte(dcipher)) - - tstr := fmt.Sprintf("%x", buf) - - // Remove aes padding, 0e is used because it was used in encryption to mark padding - padIndex := strings.Index(tstr, "0e") - if padIndex == -1 { - return nil, errors.New("Padding Index is -1") - } - tstr = tstr[:padIndex] - - data, err := hex.DecodeString(tstr) - if err != nil { - return nil, err - } - - data, err = hex.DecodeString(string(data)) - if err != nil { - return nil, err - } - - u := &common.User{} - err = json.Unmarshal(data, u) - if err != nil { - return nil, err - } - - return u, nil -} - -func deriveKey(msg, secret string) []byte { - key := []byte(secret) - sha256hash := hmac.New(sha256.New, key) - sha256hash.Write([]byte(msg)) - - return sha256hash.Sum(nil) -} - -func generateRandomString(l int) ([]byte, error) { - rBytes := make([]byte, l) - - _, err := rand.Read(rBytes) - if err != nil { - return nil, err - } - return rBytes, nil -} - -func base64urldecode(str string) ([]byte, error) { - base64str := strings.Replace(string(str), "-", "+", -1) - base64str = strings.Replace(base64str, "_", "/", -1) - - s, err := base64.RawStdEncoding.DecodeString(base64str) - if err != nil { - return nil, err - } - - return s, nil -} - -func base64urlencode(str []byte) string { - base64str := strings.Replace(string(str), "+", "-", -1) - base64str = strings.Replace(base64str, "/", "_", -1) - - return base64.RawStdEncoding.EncodeToString([]byte(base64str)) -} diff --git a/common/common.go b/common/common.go deleted file mode 100644 index 5cf7d91..0000000 --- a/common/common.go +++ /dev/null @@ -1,66 +0,0 @@ -package common - -import ( - "fmt" - "io" - "log" - "net/http" - "os" - "strconv" -) - -// Handler is the signature of HTTP Handler that is passed to Handle function -type Handler func(rc *RouterContext, w http.ResponseWriter, r *http.Request) *HTTPError - -// HTTPError is any error that occurs in middlewares or the code that handles HTTP Frontend of application -// Message is logged to console and Status Code is sent in response -// If a Middleware sends an HTTPError, No middlewares further up in chain are executed -type HTTPError struct { - // Message to log in console - Message string - // Status code that'll be sent in response - StatusCode int -} - -// RouterContext contains any information to be shared with middlewares. -type RouterContext struct { - User *User -} - -// User struct denotes the data is stored in the cookie -type User struct { - Username string `json:"username"` -} - -// ReadAndServeFile reads the file from specified location and sends it in response -func ReadAndServeFile(name string, w http.ResponseWriter) *HTTPError { - f, err := os.Open(name) - if err != nil { - - if os.IsNotExist(err) { - return &HTTPError{ - Message: fmt.Sprintf("%s not found", name), - StatusCode: http.StatusNotFound, - } - } - - return &HTTPError{ - Message: fmt.Sprintf("error in reading %s: %v\n", name, err), - StatusCode: http.StatusInternalServerError, - } - } - - defer f.Close() - stats, err := f.Stat() - if err != nil { - log.Printf("error in fetching %s's stats: %v\n", name, err) - } else { - w.Header().Add("Content-Length", strconv.FormatInt(stats.Size(), 10)) - } - - _, err = io.Copy(w, f) - if err != nil { - log.Printf("error in copying %s to response: %v\n", name, err) - } - return nil -} \ No newline at end of file diff --git a/router/filerouter.go b/router/filerouter.go index a7b6bcb..55ff8b4 100644 --- a/router/filerouter.go +++ b/router/filerouter.go @@ -3,7 +3,6 @@ package router import ( "encoding/json" "fmt" - "github.com/gmemstr/nas/common" "github.com/gmemstr/nas/files" "github.com/gorilla/mux" "net/http" @@ -11,20 +10,22 @@ import ( "strings" ) -func HandleProvider() common.Handler { - return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError { +func HandleProvider() Handler { + return func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError { vars := mux.Vars(r) - if r.Method == "GET" { - providerCodename := vars["provider"] - providerCodename = strings.Replace(providerCodename, "/", "", -1) - provider := *files.Providers[providerCodename] + providerCodename := vars["provider"] + providerCodename = strings.Replace(providerCodename, "/", "", -1) + provider := *files.Providers[providerCodename] + if r.Method == "GET" { fileList := provider.GetDirectory("") if vars["file"] != "" { fileType := provider.DetermineType(vars["file"]) if fileType == "" { - w.Write([]byte("file not found")) - return nil + return &HTTPError{ + Message: fmt.Sprintf("error determining filetype for %s\n", vars["file"]), + StatusCode: http.StatusInternalServerError, + } } if fileType == "file" { provider.ViewFile(vars["file"], w) @@ -34,28 +35,31 @@ func HandleProvider() common.Handler { } data, err := json.Marshal(fileList) if err != nil { - w.Write([]byte("An error occurred")) - return nil + return &HTTPError{ + Message: fmt.Sprintf("error fetching filelisting for %s\n", vars["file"]), + StatusCode: http.StatusInternalServerError, + } } w.Write(data) } + if r.Method == "POST" { - providerCodename := vars["provider"] - providerCodename = strings.Replace(providerCodename, "/", "", -1) - provider := *files.Providers[providerCodename] err := r.ParseMultipartForm(32 << 20) if err != nil { - w.Write([]byte("unable to parse form")) - fmt.Println(err.Error()) - return nil + return &HTTPError{ + Message: fmt.Sprintf("error parsing form for %s\n", vars["file"]), + StatusCode: http.StatusInternalServerError, + } } file, handler, err := r.FormFile("file") defer file.Close() success := provider.SaveFile(file, handler, vars["file"]) if !success { - w.Write([]byte("unable to save file")) - return nil + return &HTTPError{ + Message: fmt.Sprintf("error saving file %s\n", vars["file"]), + StatusCode: http.StatusInternalServerError, + } } w.Write([]byte("saved file")) } @@ -64,8 +68,8 @@ func HandleProvider() common.Handler { } } -func ListProviders() common.Handler { - return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError { +func ListProviders() Handler { + return func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError { var providers []string for v, _ := range files.ProviderConfig { providers = append(providers, v) @@ -73,7 +77,10 @@ func ListProviders() common.Handler { sort.Strings(providers) data, err := json.Marshal(providers) if err != nil { - return nil + return &HTTPError{ + Message: fmt.Sprintf("error provider listing"), + StatusCode: http.StatusInternalServerError, + } } w.Write(data) return nil diff --git a/router/router.go b/router/router.go index 6b3690e..765f38f 100644 --- a/router/router.go +++ b/router/router.go @@ -1,34 +1,40 @@ package router import ( - "database/sql" "fmt" - "github.com/gmemstr/nas/auth" - "golang.org/x/crypto/bcrypt" + "github.com/gorilla/mux" + "io" "log" "net/http" - - "github.com/gmemstr/nas/common" - "github.com/gorilla/mux" + "os" + "strconv" ) -func Handle(handlers ...common.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +type Handler func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError - rc := &common.RouterContext{} +type HTTPError struct { + Message string + StatusCode int +} + +// Context contains any information to be shared with middlewares. +type Context struct {} + +func Handle(handlers ...Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context := &Context{} for _, handler := range handlers { - err := handler(rc, w, r) + err := handler(context, w, r) if err != nil { log.Printf("%v", err) - w.Write([]byte(http.StatusText(err.StatusCode))) - return } } }) } + // Actual router, define endpoints here. func Init() *mux.Router { @@ -40,134 +46,53 @@ func Init() *mux.Router { // Paths that require specific handlers r.Handle("/", Handle( - //auth.RequireAuthorization(1), rootHandler(), )).Methods("GET") - r.Handle(`/login`, Handle( - loginHandler(), - )).Methods("POST", "GET") - r.Handle("/api/providers", Handle( - //auth.RequireAuthorization(1), ListProviders(), )).Methods("GET") r.Handle(`/api/files/{provider:[a-zA-Z0-9]+\/*}`, Handle( - //auth.RequireAuthorization(1), HandleProvider(), )).Methods("GET", "POST") r.Handle(`/api/files/{provider}/{file:[a-zA-Z0-9=\-\/\s.,&_+]+}`, Handle( - //auth.RequireAuthorization(1), HandleProvider(), )).Methods("GET", "POST") return r } - -func loginHandler() common.Handler { - return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError { - if r.Method == "GET" { - w.Header().Set("Content-Type", "text/html") - file := "assets/web/index.html" - - return common.ReadAndServeFile(file, w) - } - db, err := sql.Open("sqlite3", "assets/config/users.db") - +// Handles serving index page. +func rootHandler() Handler { + return func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError { + f, err := os.Open("assets/web/index.html") if err != nil { - return &common.HTTPError{ - Message: fmt.Sprintf("error in reading user database: %v", err), + return &HTTPError{ + Message: fmt.Sprintf("error serving index page from assets/web"), StatusCode: http.StatusInternalServerError, } } - statement, err := db.Prepare("SELECT * FROM users WHERE username=?") - - if _, err := auth.DecryptCookie(r); err == nil { - http.Redirect(w, r, "/admin", http.StatusTemporaryRedirect) - return nil - } - - err = r.ParseForm() + defer f.Close() + stats, err := f.Stat() if err != nil { - return &common.HTTPError{ - Message: fmt.Sprintf("error in parsing form: %v", err), - StatusCode: http.StatusBadRequest, + return &HTTPError{ + Message: fmt.Sprintf("error serving index page from assets/web"), + StatusCode: http.StatusInternalServerError, + } + } else { + w.Header().Add("Content-Length", strconv.FormatInt(stats.Size(), 10)) + } + + _, err = io.Copy(w, f) + if err != nil { + return &HTTPError{ + Message: fmt.Sprintf("error serving index page from assets/web"), + StatusCode: http.StatusInternalServerError, } } - - username := r.Form.Get("username") - password := r.Form.Get("password") - rows, err := statement.Query(username) - - if username == "" || password == "" || err != nil { - return &common.HTTPError{ - Message: "username or password is invalid", - StatusCode: http.StatusBadRequest, - } - } - var id int - var dbun string - var dbhsh string - var dbtoken sql.NullString - var dbperm int - for rows.Next() { - err := rows.Scan(&id, &dbun, &dbhsh, &dbtoken, &dbperm) - if err != nil { - return &common.HTTPError{ - Message: fmt.Sprintf("error in decoding sql data", err), - StatusCode: http.StatusBadRequest, - } - } - - } - // Create a cookie here because the credentials are correct - if bcrypt.CompareHashAndPassword([]byte(dbhsh), []byte(password)) == nil { - c, err := auth.CreateSession(&common.User{ - Username: username, - }) - if err != nil { - return &common.HTTPError{ - Message: err.Error(), - StatusCode: http.StatusInternalServerError, - } - } - - // r.AddCookie(c) - w.Header().Add("Set-Cookie", c.String()) - // And now redirect the user to admin page - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) - db.Close() - return nil - } - - return &common.HTTPError{ - Message: "Invalid credentials!", - StatusCode: http.StatusUnauthorized, - } - } -} - -// Handles /. -func rootHandler() common.Handler { - return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError { - - var file string - switch r.URL.Path { - case "/": - w.Header().Set("Content-Type", "text/html") - file = "assets/web/index.html" - - default: - return &common.HTTPError{ - Message: fmt.Sprintf("%s: Not Found", r.URL.Path), - StatusCode: http.StatusNotFound, - } - } - - return common.ReadAndServeFile(file, w) + return nil } } \ No newline at end of file diff --git a/webserver.go b/webserver.go index 94145e7..65caa62 100644 --- a/webserver.go +++ b/webserver.go @@ -1,28 +1,17 @@ package main import ( - "crypto/rand" - "database/sql" - "encoding/base64" "fmt" "github.com/gmemstr/nas/files" + "github.com/gmemstr/nas/router" "github.com/go-yaml/yaml" - "golang.org/x/crypto/bcrypt" "io/ioutil" "log" "net/http" - "os" - - "github.com/gmemstr/nas/router" ) // Main function that defines routes func main() { - if _, err := os.Stat(".lock"); os.IsNotExist(err) { - createDatabase() - createLockFile() - } - // Initialize file providers. file, err := ioutil.ReadFile("providers.yml") if err != nil { @@ -38,63 +27,3 @@ func main() { fmt.Println("Your NAS instance is live on port :3000") log.Fatal(http.ListenAndServe(":3000", r)) } - -func createDatabase() { - fmt.Println("Initializing the database") - os.Create("assets/config/users.db") - - db, err := sql.Open("sqlite3", "assets/config/users.db") - if err != nil { - fmt.Println("Problem opening database file! %v", err) - } - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS `users` ( `id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE, `username` TEXT UNIQUE, `hash` TEXT, `token` TEXT, `permissions` INTEGER )") - if err != nil { - fmt.Println("Problem creating database! %v", err) - } - - text, err := GenerateRandomString(12) - if err != nil { - fmt.Println("Error randomly generating password", err) - } - fmt.Println("Admin password: ", text) - hash, err := bcrypt.GenerateFromPassword([]byte(text), 4) - if err != nil { - fmt.Println("Error generating hash", err) - } - if bcrypt.CompareHashAndPassword(hash, []byte(text)) == nil { - fmt.Println("Password hashed") - } - _, err = db.Exec("INSERT INTO users(id,username,hash,permissions) VALUES (0,'admin','" + string(hash) + "',2)") - if err != nil { - fmt.Println("Problem creating database! %v", err) - } - defer db.Close() -} - -func createLockFile() { - lock, err := os.Create(".lock") - if err != nil { - fmt.Println("Error: %v", err) - } - lock.Write([]byte("This file left intentionally empty")) - defer lock.Close() -} - -func GenerateRandomBytes(n int) ([]byte, error) { - b := make([]byte, n) - _, err := rand.Read(b) - if err != nil { - return nil, err - } - - return b, nil -} - - -// GenerateRandomString returns a URL-safe, base64 encoded -// securely generated random string. -func GenerateRandomString(s int) (string, error) { - b, err := GenerateRandomBytes(s) - return base64.URLEncoding.EncodeToString(b), err -} \ No newline at end of file