pogo/auth/auth.go

261 lines
5.4 KiB
Go
Raw Normal View History

2017-10-03 12:08:27 +01:00
package auth
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
2017-11-29 20:56:51 +00:00
"database/sql"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
2017-10-03 12:08:27 +01:00
"net/http"
"os"
2017-11-29 20:56:51 +00:00
"strings"
_ "github.com/mattn/go-sqlite3"
2017-10-03 12:08:27 +01:00
"github.com/gmemstr/pogo/common"
2017-10-03 12:08:27 +01:00
)
const (
	// enc is the fixed label passed to deriveKey together with the value of
	// POGO_SECRET to obtain the AES key for session cookies. The secret itself
	// is read from the environment, not stored here.
	enc = "cookie_session_encryption"
	// cookieName is the name of the cookie that carries the encrypted session.
	cookieName = "POGO_SESSION"
	// cookieExpiry is used as the session cookie's MaxAge: 30 days in seconds.
	cookieExpiry = 60 * 60 * 24 * 30
)
2017-11-29 20:56:51 +00:00
// UserPermissions reports whether username holds at least the given permission
// level according to the users table in assets/config/users.db.
//
// It returns true when any row stored for username has a permissions value
// greater than or equal to permission; a username with no rows yields
// (false, nil). A non-nil error means the database could not be opened,
// queried or read, and the boolean is always false in that case.
func UserPermissions(username string, permission int) (bool, error) {
	db, err := sql.Open("sqlite3", "assets/config/users.db")
	if err != nil {
		// db is nil here, so Close must not be deferred before this check.
		return false, err
	}
	defer db.Close()

	rows, err := db.Query("SELECT permissions FROM users WHERE username=?", username)
	if err != nil {
		return false, err
	}
	// Release the underlying connection even on early return.
	defer rows.Close()

	isAllowed := false
	for rows.Next() {
		var level int
		if err := rows.Scan(&level); err != nil {
			return false, err
		}
		if level >= permission {
			isAllowed = true
		}
	}
	// Next returning false may mean an iteration error rather than end of data.
	if err := rows.Err(); err != nil {
		return false, err
	}
	return isAllowed, nil
}
func RequireAuthorization(permission int) common.Handler {
2017-10-03 12:08:27 +01:00
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
usr, err := DecryptCookie(r)
if err != nil {
fmt.Println(err.Error())
if strings.Contains(r.Header.Get("Accept"), "html") || r.Method == "GET" {
http.Redirect(w, r, "/#/login", http.StatusTemporaryRedirect)
return &common.HTTPError{
Message: "Unauthorized! Redirecting to /login",
StatusCode: http.StatusTemporaryRedirect,
}
}
return &common.HTTPError{
Message: "Unauthorized!",
StatusCode: http.StatusUnauthorized,
}
}
rc.User = usr
2017-11-29 20:56:51 +00:00
username := rc.User.Username
hasPermission, err := UserPermissions(string(username), permission)
if !hasPermission {
return &common.HTTPError{
Message: "Unauthorized! Redirecting to /admin",
StatusCode: http.StatusUnauthorized,
}
}
return nil
2017-10-03 12:08:27 +01:00
}
}
func CreateSession(u *common.User) (*http.Cookie, error) {
2017-10-06 05:22:18 +01:00
secret := os.Getenv("POGO_SECRET")
iv, err := generateRandomString(16)
if err != nil {
return nil, err
}
userJSON, err := json.Marshal(u)
if err != nil {
return nil, err
}
hexedJSON := hex.EncodeToString(userJSON)
encKey := deriveKey(enc, secret)
block, err := aes.NewCipher(encKey)
if err != nil {
return nil, err
}
// Fill the block with 0x0e
if remBytes := len(hexedJSON) % aes.BlockSize; remBytes != 0 {
t := []byte(hexedJSON)
for i := 0; i < aes.BlockSize-remBytes; i++ {
t = append(t, 0x0e)
}
hexedJSON = string(t)
2017-10-03 12:08:27 +01:00
}
mode := cipher.NewCBCEncrypter(block, iv)
encCipher := make([]byte, len(hexedJSON)+aes.BlockSize)
mode.CryptBlocks(encCipher, []byte(hexedJSON))
cipherbase64 := base64urlencode(encCipher)
ivbase64 := base64urlencode(iv)
// Cookie format: iv.cipher.created_on.expire_on.HMAC
cookieStr := fmt.Sprintf("%s.%s", ivbase64, cipherbase64)
c := &http.Cookie{
Name: cookieName,
Value: cookieStr,
MaxAge: cookieExpiry,
}
return c, nil
2017-10-03 12:08:27 +01:00
}
func DecryptCookie(r *http.Request) (*common.User, error) {
2017-10-06 05:22:18 +01:00
secret := os.Getenv("POGO_SECRET")
c, err := r.Cookie(cookieName)
if err != nil {
if err != http.ErrNoCookie {
log.Printf("error in reading Cookie: %v", err)
}
return nil, err
}
csplit := strings.Split(c.Value, ".")
if len(csplit) != 2 {
return nil, errors.New("Invalid number of values in cookie")
}
ivb, cipherb := csplit[0], csplit[1]
iv, err := base64urldecode(ivb)
if err != nil {
return nil, err
}
dcipher, err := base64urldecode(cipherb)
if err != nil {
return nil, err
}
if len(iv) != 16 {
return nil, errors.New("IV length is not 16")
}
encKey := deriveKey(enc, secret)
if len(dcipher)%aes.BlockSize != 0 {
return nil, errors.New("ciphertext not multiple of blocksize")
}
block, err := aes.NewCipher(encKey)
if err != nil {
return nil, err
}
buf := make([]byte, len(dcipher))
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(buf, []byte(dcipher))
tstr := fmt.Sprintf("%x", buf)
// Remove aes padding, 0e is used because it was used in encryption to mark padding
padIndex := strings.Index(tstr, "0e")
if padIndex == -1 {
return nil, errors.New("Padding Index is -1")
}
tstr = tstr[:padIndex]
data, err := hex.DecodeString(tstr)
if err != nil {
return nil, err
}
data, err = hex.DecodeString(string(data))
if err != nil {
return nil, err
}
u := &common.User{}
err = json.Unmarshal(data, u)
if err != nil {
return nil, err
}
return u, nil
}
// deriveKey returns HMAC-SHA256(key=secret, message=msg). The result is always
// 32 bytes long, which makes it usable directly as an AES-256 key.
func deriveKey(msg, secret string) []byte {
	mac := hmac.New(sha256.New, []byte(secret))
	// hash.Hash.Write is documented never to return an error.
	mac.Write([]byte(msg))
	return mac.Sum(nil)
}
// generateRandomString returns l bytes filled by crypto/rand. Despite the
// name the result is raw bytes, not text. On failure it returns a nil slice
// and the error reported by rand.Read.
func generateRandomString(l int) ([]byte, error) {
	buf := make([]byte, l)
	if _, readErr := rand.Read(buf); readErr != nil {
		return nil, readErr
	}
	return buf, nil
}
// base64urldecode decodes unpadded base64 text. URL-safe characters are first
// mapped back to the standard alphabet ('-' to '+', '_' to '/'), so input in
// either alphabet is accepted. On malformed input it returns a nil slice and
// the decoder's error.
func base64urldecode(str string) ([]byte, error) {
	toStd := strings.NewReplacer("-", "+", "_", "/")
	decoded, err := base64.RawStdEncoding.DecodeString(toStd.Replace(str))
	if err != nil {
		return nil, err
	}
	return decoded, nil
}
// base64urlencode returns the unpadded, URL-safe base64 encoding of str
// ('-' and '_' in place of '+' and '/').
//
// The alphabet substitution belongs to the encoded text, never to the input:
// rewriting '+' or '/' bytes in the raw data before encoding would silently
// corrupt any payload byte equal to 0x2B or 0x2F.
func base64urlencode(str []byte) string {
	return base64.RawURLEncoding.EncodeToString(str)
}