mirror of
https://github.com/gmemstr/pogo.git
synced 2024-09-20 01:31:09 +01:00
261 lines
5.3 KiB
Go
261 lines
5.3 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
"github.com/gmemstr/pogo/common"
|
|
)
|
|
|
|
const (
	// cookieName is the name of the session cookie written by CreateSession
	// and read back by DecryptCookie.
	cookieName = "POGO_SESSION"

	// cookieExpiry is the session cookie's Max-Age: 30 days expressed in seconds.
	cookieExpiry = 30 * 24 * 60 * 60

	// enc is not a key itself. It is the message that deriveKey HMACs with
	// the POGO_SECRET environment variable to produce the AES key for
	// session cookies.
	enc = "cookie_session_encryption"
)
|
|
|
|
// UserPermissions reports whether the user named username has a permission
// level of at least permission. Levels come from the users table in
// assets/config/users.db.
//
// If no row matches username, it returns false with a nil error. Any error
// from opening the database, running the query, or reading a row is
// returned together with false.
func UserPermissions(username string, permission int) (bool, error) {
	db, err := sql.Open("sqlite3", "assets/config/users.db")
	if err != nil {
		// sql.Open returns a nil *DB on error, so Close may only be deferred
		// after this check.
		return false, err
	}
	defer db.Close()

	rows, err := db.Query("SELECT permissions FROM users WHERE username=?", username)
	if err != nil {
		return false, err
	}
	// Closing rows releases the statement and its connection.
	defer rows.Close()

	isAllowed := false
	for rows.Next() {
		var level int
		if err := rows.Scan(&level); err != nil {
			return false, err
		}
		if level >= permission {
			isAllowed = true
		}
	}
	// rows.Next also returns false on an iteration error; Err tells the two apart.
	if err := rows.Err(); err != nil {
		return false, err
	}
	return isAllowed, nil
}
|
|
|
|
func RequireAuthorization(permission int) common.Handler {
|
|
return func(rc *common.RouterContext, w http.ResponseWriter, r *http.Request) *common.HTTPError {
|
|
usr, err := DecryptCookie(r)
|
|
if err != nil {
|
|
fmt.Println(err.Error())
|
|
if strings.Contains(r.Header.Get("Accept"), "html") || r.Method == "GET" {
|
|
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
|
return &common.HTTPError{
|
|
Message: "Unauthorized! Redirecting to /login",
|
|
StatusCode: http.StatusTemporaryRedirect,
|
|
}
|
|
}
|
|
return &common.HTTPError{
|
|
Message: "Unauthorized!",
|
|
StatusCode: http.StatusUnauthorized,
|
|
}
|
|
}
|
|
|
|
rc.User = usr
|
|
|
|
username := rc.User.Username
|
|
|
|
hasPermission, err := UserPermissions(string(username), permission)
|
|
|
|
if !hasPermission {
|
|
return &common.HTTPError{
|
|
Message: "Unauthorized! Redirecting to /admin",
|
|
StatusCode: http.StatusUnauthorized,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func CreateSession(u *common.User) (*http.Cookie, error) {
|
|
secret := os.Getenv("POGO_SECRET")
|
|
|
|
iv, err := generateRandomString(16)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userJSON, err := json.Marshal(u)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hexedJSON := hex.EncodeToString(userJSON)
|
|
|
|
encKey := deriveKey(enc, secret)
|
|
|
|
block, err := aes.NewCipher(encKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Fill the block with 0x0e
|
|
if remBytes := len(hexedJSON) % aes.BlockSize; remBytes != 0 {
|
|
t := []byte(hexedJSON)
|
|
|
|
for i := 0; i < aes.BlockSize-remBytes; i++ {
|
|
t = append(t, 0x0e)
|
|
}
|
|
hexedJSON = string(t)
|
|
}
|
|
|
|
mode := cipher.NewCBCEncrypter(block, iv)
|
|
encCipher := make([]byte, len(hexedJSON)+aes.BlockSize)
|
|
|
|
mode.CryptBlocks(encCipher, []byte(hexedJSON))
|
|
|
|
cipherbase64 := base64urlencode(encCipher)
|
|
ivbase64 := base64urlencode(iv)
|
|
|
|
// Cookie format: iv.cipher.created_on.expire_on.HMAC
|
|
cookieStr := fmt.Sprintf("%s.%s", ivbase64, cipherbase64)
|
|
|
|
c := &http.Cookie{
|
|
Name: cookieName,
|
|
Value: cookieStr,
|
|
MaxAge: cookieExpiry,
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func DecryptCookie(r *http.Request) (*common.User, error) {
|
|
secret := os.Getenv("POGO_SECRET")
|
|
|
|
c, err := r.Cookie(cookieName)
|
|
if err != nil {
|
|
if err != http.ErrNoCookie {
|
|
log.Printf("error in reading Cookie: %v", err)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
csplit := strings.Split(c.Value, ".")
|
|
if len(csplit) != 2 {
|
|
return nil, errors.New("Invalid number of values in cookie")
|
|
}
|
|
|
|
ivb, cipherb := csplit[0], csplit[1]
|
|
|
|
iv, err := base64urldecode(ivb)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dcipher, err := base64urldecode(cipherb)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(iv) != 16 {
|
|
return nil, errors.New("IV length is not 16")
|
|
}
|
|
|
|
encKey := deriveKey(enc, secret)
|
|
|
|
if len(dcipher)%aes.BlockSize != 0 {
|
|
return nil, errors.New("ciphertext not multiple of blocksize")
|
|
}
|
|
|
|
block, err := aes.NewCipher(encKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
buf := make([]byte, len(dcipher))
|
|
|
|
mode := cipher.NewCBCDecrypter(block, iv)
|
|
|
|
mode.CryptBlocks(buf, []byte(dcipher))
|
|
|
|
tstr := fmt.Sprintf("%x", buf)
|
|
|
|
// Remove aes padding, 0e is used because it was used in encryption to mark padding
|
|
padIndex := strings.Index(tstr, "0e")
|
|
if padIndex == -1 {
|
|
return nil, errors.New("Padding Index is -1")
|
|
}
|
|
tstr = tstr[:padIndex]
|
|
|
|
data, err := hex.DecodeString(tstr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
data, err = hex.DecodeString(string(data))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
u := &common.User{}
|
|
err = json.Unmarshal(data, u)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
// deriveKey returns HMAC-SHA256 of msg keyed by secret. The 32-byte result is
// passed to aes.NewCipher, which therefore selects AES-256.
func deriveKey(msg, secret string) []byte {
	mac := hmac.New(sha256.New, []byte(secret))
	// Writes to a hash.Hash never fail, so the return values are ignored.
	mac.Write([]byte(msg))
	return mac.Sum(nil)
}
|
|
|
|
// generateRandomString returns l bytes read from crypto/rand. Despite its
// name, the result is raw binary, not printable text. CreateSession uses it
// to generate the CBC IV.
func generateRandomString(l int) ([]byte, error) {
	out := make([]byte, l)
	if _, err := rand.Read(out); err != nil {
		return nil, err
	}
	return out, nil
}
|
|
|
|
// base64urldecode decodes unpadded base64 written in either alphabet. The
// URL-safe characters '-' and '_' are mapped back to '+' and '/', while
// standard-alphabet input passes through unchanged. On failure it returns a
// nil slice together with the decoding error.
func base64urldecode(str string) ([]byte, error) {
	std := strings.NewReplacer("-", "+", "_", "/").Replace(str)
	decoded, err := base64.RawStdEncoding.DecodeString(std)
	if err != nil {
		return nil, err
	}
	return decoded, nil
}
|
|
|
|
// base64urlencode encodes str with the unpadded URL-safe base64 alphabet
// (RFC 4648 section 5), which uses '-' and '_' instead of '+' and '/'.
//
// The input bytes are encoded as-is. base64urldecode maps '-' and '_' back to
// '+' and '/' before decoding, so it reverses this exactly. Previously the
// substitution was applied to the raw input bytes, which corrupted every
// 0x2B/0x2F byte and produced standard-alphabet output.
func base64urlencode(str []byte) string {
	return base64.RawURLEncoding.EncodeToString(str)
}
|