Simplify router, remove common and authentication.

Removed common package, since it was largely useless and instead
integrated what I could into the router package. Should simplify things
going forward, but we can always split this out later. Also completely
removed authentication for the time being - want to consider other
options for this.
This commit is contained in:
Gabriel Simmer 2020-04-02 20:50:39 +01:00
parent 207676d4a7
commit 641e015ca6
No known key found for this signature in database
GPG key ID: 33BA4D83B160A0A9
5 changed files with 70 additions and 552 deletions

View file

@ -1,277 +0,0 @@
package auth
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"os"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/gmemstr/nas/common"
)
const (
enc = "cookie_session_encryption"
// This is the key with which each cookie is encrypted, I'll recommend moving it to a env file
cookieName = "NAS_SESSION"
cookieExpiry = 60 * 60 * 24 * 30 // 30 days in seconds
)
func UserPermissions(username string, permission int) (bool, error) {
db, err := sql.Open("sqlite3", "assets/config/users.db")
defer db.Close()
isAllowed := false
if err != nil {
return isAllowed, err
}
statement, err := db.Prepare("SELECT permissions FROM users WHERE username=?")
if err != nil {
return isAllowed, err
}
rows, err := statement.Query(username)
if err != nil {
return isAllowed, err
}
var level int
for rows.Next() {
err = rows.Scan(&level)
if err != nil {
return isAllowed, err
}
if level >= permission {
isAllowed = true
}
}
return isAllowed, nil
}
func RequireAuthorization(permission int) common.Handler {
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
usr, err := DecryptCookie(r)
if err != nil {
fmt.Println(err.Error())
if strings.Contains(r.Header.Get("Accept"), "html") || r.Method == "GET" {
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
return &common.HTTPError{
Message: "Unauthorized! Redirecting to /login",
StatusCode: http.StatusTemporaryRedirect,
}
}
return &common.HTTPError{
Message: "Unauthorized!",
StatusCode: http.StatusUnauthorized,
}
}
rc.User = usr
username := rc.User.Username
hasPermission, err := UserPermissions(string(username), permission)
if !hasPermission {
return &common.HTTPError{
Message: "Unauthorized! Redirecting to /admin",
StatusCode: http.StatusUnauthorized,
}
}
return nil
}
}
func CreateSession(u *common.User) (*http.Cookie, error) {
secret := os.Getenv("POGO_SECRET")
iv, err := generateRandomString(16)
if err != nil {
return nil, err
}
userJSON, err := json.Marshal(u)
if err != nil {
return nil, err
}
hexedJSON := hex.EncodeToString(userJSON)
encKey := deriveKey(enc, secret)
block, err := aes.NewCipher(encKey)
if err != nil {
return nil, err
}
// Fill the block with 0x0e
if remBytes := len(hexedJSON) % aes.BlockSize; remBytes != 0 {
t := []byte(hexedJSON)
for i := 0; i < aes.BlockSize-remBytes; i++ {
t = append(t, 0x0e)
}
hexedJSON = string(t)
}
mode := cipher.NewCBCEncrypter(block, iv)
encCipher := make([]byte, len(hexedJSON)+aes.BlockSize)
mode.CryptBlocks(encCipher, []byte(hexedJSON))
cipherbase64 := base64urlencode(encCipher)
ivbase64 := base64urlencode(iv)
// Cookie format: iv.cipher.created_on.expire_on.HMAC
cookieStr := fmt.Sprintf("%s.%s", ivbase64, cipherbase64)
c := &http.Cookie{
Name: cookieName,
Value: cookieStr,
MaxAge: cookieExpiry,
}
// Insert token into database.
db, err := sql.Open("sqlite3", "assets/config/users.db")
defer db.Close()
if err != nil {
return nil, err
}
statement, err := db.Prepare("UPDATE users SET token=? WHERE username=?")
if err != nil {
return nil, err
}
_, err = statement.Exec(cookieStr, u.Username)
if err != nil {
return nil, err
}
return c, nil
}
func DecryptCookie(r *http.Request) (*common.User, error) {
secret := os.Getenv("POGO_SECRET")
c, err := r.Cookie(cookieName)
if err != nil {
if err != http.ErrNoCookie {
log.Printf("error in reading Cookie: %v", err)
}
return nil, err
}
csplit := strings.Split(c.Value, ".")
if len(csplit) != 2 {
return nil, errors.New("Invalid number of values in cookie")
}
ivb, cipherb := csplit[0], csplit[1]
iv, err := base64urldecode(ivb)
if err != nil {
return nil, err
}
dcipher, err := base64urldecode(cipherb)
if err != nil {
return nil, err
}
if len(iv) != 16 {
return nil, errors.New("IV length is not 16")
}
encKey := deriveKey(enc, secret)
if len(dcipher)%aes.BlockSize != 0 {
return nil, errors.New("ciphertext not multiple of blocksize")
}
block, err := aes.NewCipher(encKey)
if err != nil {
return nil, err
}
buf := make([]byte, len(dcipher))
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(buf, []byte(dcipher))
tstr := fmt.Sprintf("%x", buf)
// Remove aes padding, 0e is used because it was used in encryption to mark padding
padIndex := strings.Index(tstr, "0e")
if padIndex == -1 {
return nil, errors.New("Padding Index is -1")
}
tstr = tstr[:padIndex]
data, err := hex.DecodeString(tstr)
if err != nil {
return nil, err
}
data, err = hex.DecodeString(string(data))
if err != nil {
return nil, err
}
u := &common.User{}
err = json.Unmarshal(data, u)
if err != nil {
return nil, err
}
return u, nil
}
func deriveKey(msg, secret string) []byte {
key := []byte(secret)
sha256hash := hmac.New(sha256.New, key)
sha256hash.Write([]byte(msg))
return sha256hash.Sum(nil)
}
func generateRandomString(l int) ([]byte, error) {
rBytes := make([]byte, l)
_, err := rand.Read(rBytes)
if err != nil {
return nil, err
}
return rBytes, nil
}
func base64urldecode(str string) ([]byte, error) {
base64str := strings.Replace(string(str), "-", "+", -1)
base64str = strings.Replace(base64str, "_", "/", -1)
s, err := base64.RawStdEncoding.DecodeString(base64str)
if err != nil {
return nil, err
}
return s, nil
}
func base64urlencode(str []byte) string {
base64str := strings.Replace(string(str), "+", "-", -1)
base64str = strings.Replace(base64str, "/", "_", -1)
return base64.RawStdEncoding.EncodeToString([]byte(base64str))
}

View file

@ -1,66 +0,0 @@
package common
import (
"fmt"
"io"
"log"
"net/http"
"os"
"strconv"
)
// Handler is the signature of HTTP Handler that is passed to Handle function
type Handler func(rc *RouterContext, w http.ResponseWriter, r *http.Request) *HTTPError
// HTTPError is any error that occurs in middlewares or the code that handles HTTP Frontend of application
// Message is logged to console and Status Code is sent in response
// If a Middleware sends an HTTPError, No middlewares further up in chain are executed
type HTTPError struct {
// Message to log in console
Message string
// Status code that'll be sent in response
StatusCode int
}
// RouterContext contains any information to be shared with middlewares.
type RouterContext struct {
User *User
}
// User struct denotes the data is stored in the cookie
type User struct {
Username string `json:"username"`
}
// ReadAndServeFile reads the file from specified location and sends it in response
func ReadAndServeFile(name string, w http.ResponseWriter) *HTTPError {
f, err := os.Open(name)
if err != nil {
if os.IsNotExist(err) {
return &HTTPError{
Message: fmt.Sprintf("%s not found", name),
StatusCode: http.StatusNotFound,
}
}
return &HTTPError{
Message: fmt.Sprintf("error in reading %s: %v\n", name, err),
StatusCode: http.StatusInternalServerError,
}
}
defer f.Close()
stats, err := f.Stat()
if err != nil {
log.Printf("error in fetching %s's stats: %v\n", name, err)
} else {
w.Header().Add("Content-Length", strconv.FormatInt(stats.Size(), 10))
}
_, err = io.Copy(w, f)
if err != nil {
log.Printf("error in copying %s to response: %v\n", name, err)
}
return nil
}

View file

@ -3,7 +3,6 @@ package router
import (
"encoding/json"
"fmt"
"github.com/gmemstr/nas/common"
"github.com/gmemstr/nas/files"
"github.com/gorilla/mux"
"net/http"
@ -11,20 +10,22 @@ import (
"strings"
)
func HandleProvider() common.Handler {
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
func HandleProvider() Handler {
return func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError {
vars := mux.Vars(r)
if r.Method == "GET" {
providerCodename := vars["provider"]
providerCodename = strings.Replace(providerCodename, "/", "", -1)
provider := *files.Providers[providerCodename]
providerCodename := vars["provider"]
providerCodename = strings.Replace(providerCodename, "/", "", -1)
provider := *files.Providers[providerCodename]
if r.Method == "GET" {
fileList := provider.GetDirectory("")
if vars["file"] != "" {
fileType := provider.DetermineType(vars["file"])
if fileType == "" {
w.Write([]byte("file not found"))
return nil
return &HTTPError{
Message: fmt.Sprintf("error determining filetype for %s\n", vars["file"]),
StatusCode: http.StatusInternalServerError,
}
}
if fileType == "file" {
provider.ViewFile(vars["file"], w)
@ -34,28 +35,31 @@ func HandleProvider() common.Handler {
}
data, err := json.Marshal(fileList)
if err != nil {
w.Write([]byte("An error occurred"))
return nil
return &HTTPError{
Message: fmt.Sprintf("error fetching filelisting for %s\n", vars["file"]),
StatusCode: http.StatusInternalServerError,
}
}
w.Write(data)
}
if r.Method == "POST" {
providerCodename := vars["provider"]
providerCodename = strings.Replace(providerCodename, "/", "", -1)
provider := *files.Providers[providerCodename]
err := r.ParseMultipartForm(32 << 20)
if err != nil {
w.Write([]byte("unable to parse form"))
fmt.Println(err.Error())
return nil
return &HTTPError{
Message: fmt.Sprintf("error parsing form for %s\n", vars["file"]),
StatusCode: http.StatusInternalServerError,
}
}
file, handler, err := r.FormFile("file")
defer file.Close()
success := provider.SaveFile(file, handler, vars["file"])
if !success {
w.Write([]byte("unable to save file"))
return nil
return &HTTPError{
Message: fmt.Sprintf("error saving file %s\n", vars["file"]),
StatusCode: http.StatusInternalServerError,
}
}
w.Write([]byte("saved file"))
}
@ -64,8 +68,8 @@ func HandleProvider() common.Handler {
}
}
func ListProviders() common.Handler {
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
func ListProviders() Handler {
return func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError {
var providers []string
for v, _ := range files.ProviderConfig {
providers = append(providers, v)
@ -73,7 +77,10 @@ func ListProviders() common.Handler {
sort.Strings(providers)
data, err := json.Marshal(providers)
if err != nil {
return nil
return &HTTPError{
Message: fmt.Sprintf("error provider listing"),
StatusCode: http.StatusInternalServerError,
}
}
w.Write(data)
return nil

View file

@ -1,34 +1,40 @@
package router
import (
"database/sql"
"fmt"
"github.com/gmemstr/nas/auth"
"golang.org/x/crypto/bcrypt"
"github.com/gorilla/mux"
"io"
"log"
"net/http"
"github.com/gmemstr/nas/common"
"github.com/gorilla/mux"
"os"
"strconv"
)
func Handle(handlers ...common.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
type Handler func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError
rc := &common.RouterContext{}
type HTTPError struct {
Message string
StatusCode int
}
// Context contains any information to be shared with middlewares.
type Context struct {}
func Handle(handlers ...Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
context := &Context{}
for _, handler := range handlers {
err := handler(rc, w, r)
err := handler(context, w, r)
if err != nil {
log.Printf("%v", err)
w.Write([]byte(http.StatusText(err.StatusCode)))
return
}
}
})
}
// Actual router, define endpoints here.
func Init() *mux.Router {
@ -40,134 +46,53 @@ func Init() *mux.Router {
// Paths that require specific handlers
r.Handle("/", Handle(
//auth.RequireAuthorization(1),
rootHandler(),
)).Methods("GET")
r.Handle(`/login`, Handle(
loginHandler(),
)).Methods("POST", "GET")
r.Handle("/api/providers", Handle(
//auth.RequireAuthorization(1),
ListProviders(),
)).Methods("GET")
r.Handle(`/api/files/{provider:[a-zA-Z0-9]+\/*}`, Handle(
//auth.RequireAuthorization(1),
HandleProvider(),
)).Methods("GET", "POST")
r.Handle(`/api/files/{provider}/{file:[a-zA-Z0-9=\-\/\s.,&_+]+}`, Handle(
//auth.RequireAuthorization(1),
HandleProvider(),
)).Methods("GET", "POST")
return r
}
func loginHandler() common.Handler {
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
if r.Method == "GET" {
w.Header().Set("Content-Type", "text/html")
file := "assets/web/index.html"
return common.ReadAndServeFile(file, w)
}
db, err := sql.Open("sqlite3", "assets/config/users.db")
// Handles serving index page.
func rootHandler() Handler {
return func(context *Context, w http.ResponseWriter, r *http.Request) *HTTPError {
f, err := os.Open("assets/web/index.html")
if err != nil {
return &common.HTTPError{
Message: fmt.Sprintf("error in reading user database: %v", err),
return &HTTPError{
Message: fmt.Sprintf("error serving index page from assets/web"),
StatusCode: http.StatusInternalServerError,
}
}
statement, err := db.Prepare("SELECT * FROM users WHERE username=?")
if _, err := auth.DecryptCookie(r); err == nil {
http.Redirect(w, r, "/admin", http.StatusTemporaryRedirect)
return nil
}
err = r.ParseForm()
defer f.Close()
stats, err := f.Stat()
if err != nil {
return &common.HTTPError{
Message: fmt.Sprintf("error in parsing form: %v", err),
StatusCode: http.StatusBadRequest,
return &HTTPError{
Message: fmt.Sprintf("error serving index page from assets/web"),
StatusCode: http.StatusInternalServerError,
}
} else {
w.Header().Add("Content-Length", strconv.FormatInt(stats.Size(), 10))
}
_, err = io.Copy(w, f)
if err != nil {
return &HTTPError{
Message: fmt.Sprintf("error serving index page from assets/web"),
StatusCode: http.StatusInternalServerError,
}
}
username := r.Form.Get("username")
password := r.Form.Get("password")
rows, err := statement.Query(username)
if username == "" || password == "" || err != nil {
return &common.HTTPError{
Message: "username or password is invalid",
StatusCode: http.StatusBadRequest,
}
}
var id int
var dbun string
var dbhsh string
var dbtoken sql.NullString
var dbperm int
for rows.Next() {
err := rows.Scan(&id, &dbun, &dbhsh, &dbtoken, &dbperm)
if err != nil {
return &common.HTTPError{
Message: fmt.Sprintf("error in decoding sql data", err),
StatusCode: http.StatusBadRequest,
}
}
}
// Create a cookie here because the credentials are correct
if bcrypt.CompareHashAndPassword([]byte(dbhsh), []byte(password)) == nil {
c, err := auth.CreateSession(&common.User{
Username: username,
})
if err != nil {
return &common.HTTPError{
Message: err.Error(),
StatusCode: http.StatusInternalServerError,
}
}
// r.AddCookie(c)
w.Header().Add("Set-Cookie", c.String())
// And now redirect the user to admin page
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
db.Close()
return nil
}
return &common.HTTPError{
Message: "Invalid credentials!",
StatusCode: http.StatusUnauthorized,
}
}
}
// Handles /.
func rootHandler() common.Handler {
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
var file string
switch r.URL.Path {
case "/":
w.Header().Set("Content-Type", "text/html")
file = "assets/web/index.html"
default:
return &common.HTTPError{
Message: fmt.Sprintf("%s: Not Found", r.URL.Path),
StatusCode: http.StatusNotFound,
}
}
return common.ReadAndServeFile(file, w)
return nil
}
}

View file

@ -1,28 +1,17 @@
package main
import (
"crypto/rand"
"database/sql"
"encoding/base64"
"fmt"
"github.com/gmemstr/nas/files"
"github.com/gmemstr/nas/router"
"github.com/go-yaml/yaml"
"golang.org/x/crypto/bcrypt"
"io/ioutil"
"log"
"net/http"
"os"
"github.com/gmemstr/nas/router"
)
// Main function that defines routes
func main() {
if _, err := os.Stat(".lock"); os.IsNotExist(err) {
createDatabase()
createLockFile()
}
// Initialize file providers.
file, err := ioutil.ReadFile("providers.yml")
if err != nil {
@ -38,63 +27,3 @@ func main() {
fmt.Println("Your NAS instance is live on port :3000")
log.Fatal(http.ListenAndServe(":3000", r))
}
func createDatabase() {
fmt.Println("Initializing the database")
os.Create("assets/config/users.db")
db, err := sql.Open("sqlite3", "assets/config/users.db")
if err != nil {
fmt.Println("Problem opening database file! %v", err)
}
_, err = db.Exec("CREATE TABLE IF NOT EXISTS `users` ( `id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE, `username` TEXT UNIQUE, `hash` TEXT, `token` TEXT, `permissions` INTEGER )")
if err != nil {
fmt.Println("Problem creating database! %v", err)
}
text, err := GenerateRandomString(12)
if err != nil {
fmt.Println("Error randomly generating password", err)
}
fmt.Println("Admin password: ", text)
hash, err := bcrypt.GenerateFromPassword([]byte(text), 4)
if err != nil {
fmt.Println("Error generating hash", err)
}
if bcrypt.CompareHashAndPassword(hash, []byte(text)) == nil {
fmt.Println("Password hashed")
}
_, err = db.Exec("INSERT INTO users(id,username,hash,permissions) VALUES (0,'admin','" + string(hash) + "',2)")
if err != nil {
fmt.Println("Problem creating database! %v", err)
}
defer db.Close()
}
func createLockFile() {
lock, err := os.Create(".lock")
if err != nil {
fmt.Println("Error: %v", err)
}
lock.Write([]byte("This file left intentionally empty"))
defer lock.Close()
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateRandomString returns a URL-safe, base64 encoded
// securely generated random string.
func GenerateRandomString(s int) (string, error) {
b, err := GenerateRandomBytes(s)
return base64.URLEncoding.EncodeToString(b), err
}