304 lines
7.1 KiB
Go
304 lines
7.1 KiB
Go
|
package store
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"embed"
|
||
|
"errors"
|
||
|
"github.com/google/uuid"
|
||
|
_ "github.com/mattn/go-sqlite3"
|
||
|
"log"
|
||
|
"os"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// migrations embeds database.sql, the schema script that Open executes when
// no database file exists yet at the configured path.
//
//go:embed database.sql
var migrations embed.FS
|
||
|
|
||
|
// SessionExpired is the message of the error SessionUser returns when a
// session's expiry time has passed. SessionUser builds a fresh errors.New value
// each time, so callers must compare err.Error() against this string rather
// than using errors.Is.
const SessionExpired = "session expired"
|
||
|
|
||
|
// Store is the SQLite-backed implementation of the persistence methods in this
// file. Obtain one from Open; the zero value holds a nil database handle.
type Store struct {
	// database is the *sql.DB handle opened by Open and used by every method.
	database *sql.DB
}
|
||
|
|
||
|
type Storer interface {
|
||
|
SaveInvite(invite Invite) error
|
||
|
GetInvite(token string) (Invite, error)
|
||
|
GetServer(id string) (Server, error)
|
||
|
SaveServer(server Server) error
|
||
|
GetUserServers(user User) ([]Server, error)
|
||
|
LogInviteUse(user string, invite Invite) error
|
||
|
InviteLog(invite Invite) ([]InviteLog, error)
|
||
|
GetUser(uid string) (User, error)
|
||
|
SaveUser(user User) error
|
||
|
SaveSession(token string, user User) error
|
||
|
SessionUser(token string) (User, error)
|
||
|
}
|
||
|
|
||
|
// Invite is an invite token tied to a Server. The invites table stores only
// the creator's and server's IDs; GetInvite fills in the server and the
// creator's display name.
type Invite struct {
	Token   string `json:"token"`
	Creator User   `json:"creator"`
	Server  Server `json:"server"`
	// Uses is persisted as-is; presumably a use counter or remaining-use
	// budget — TODO confirm against the code that redeems invites.
	Uses int `json:"uses"`
	// Unlimited presumably lifts any limit expressed by Uses — confirm.
	Unlimited bool `json:"unlimited"`
}
|
||
|
|
||
|
// Server is a server registered by a user. In the servers table the owner is
// referenced by ID only.
type Server struct {
	Id      string `json:"id"`
	Name    string `json:"name"`
	Address string `json:"address"`
	// Rcon holds the remote-console address and password. GetInvite clears it
	// and GetUserServers loads only its address; among this file's methods only
	// GetServer returns the password. Note that encoding/json ignores omitempty
	// on struct-typed fields, so a cleared Rcon is still encoded as "rcon":{}.
	Rcon Rcon `json:"rcon,omitempty"`
	// Owner is never serialised to JSON.
	Owner User `json:"-"`
}
|
||
|
|
||
|
// User is an account as stored in the users table. Only Id and DisplayName are
// serialised to JSON; the token fields are tagged "-".
type User struct {
	Id string `json:"id,omitempty"`
	// Token, RefreshToken and TokenExpiry presumably hold a third-party
	// access token and its refresh data — TODO confirm with the login flow.
	Token        string    `json:"-"`
	DisplayName  string    `json:"display_name"`
	RefreshToken string    `json:"-"`
	TokenExpiry  time.Time `json:"-"`
}
|
||
|
|
||
|
// Rcon holds a Server's remote-console endpoint and password. Each field is
// omitted from JSON when empty.
type Rcon struct {
	Address  string `json:"address,omitempty"`
	Password string `json:"password,omitempty"`
}
|
||
|
|
||
|
// Session is a row of the sessions table. It holds a login token, the ID of
// the user it belongs to, and the time it stops being valid. SaveSession sets
// Expiry to 30 days after creation, and SessionUser rejects the session once
// that time has passed.
type Session struct {
	Token  string
	UID    string
	Expiry time.Time
}
|
||
|
|
||
|
func Open() (*Store, error) {
|
||
|
database := os.Getenv("WLM_DATABASE_PATH")
|
||
|
if database == "" {
|
||
|
database = "db.sqlite3"
|
||
|
}
|
||
|
|
||
|
if _, err := os.Stat(database); errors.Is(err, os.ErrNotExist) {
|
||
|
log.Printf("No database found at %s, creating", database)
|
||
|
_, err := os.Create(database)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
db, err := sql.Open("sqlite3", database)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
initialSetup, err := migrations.ReadFile("database.sql")
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
_, err = db.Exec(string(initialSetup))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
log.Printf("Database created at %s", database)
|
||
|
db.Close()
|
||
|
}
|
||
|
|
||
|
db, err := sql.Open("sqlite3", database)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &Store{database: db}, nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) SaveInvite(invite Invite) error {
|
||
|
q, err := s.database.Prepare("INSERT INTO invites (token, creator, server, uses, unlimited) VALUES ($1, $2, $3, $4, $5)")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, err = q.Exec(invite.Token, invite.Creator.Id, invite.Server.Id, invite.Uses, invite.Unlimited)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) GetInvite(token string) (Invite, error) {
|
||
|
if token == "" {
|
||
|
return Invite{}, sql.ErrNoRows
|
||
|
}
|
||
|
q := s.database.QueryRow("SELECT * FROM invites WHERE token=$1", token)
|
||
|
|
||
|
var in Invite
|
||
|
err := q.Scan(&in.Token, &in.Creator.Id, &in.Server.Id, &in.Uses, &in.Unlimited)
|
||
|
if err != nil {
|
||
|
return Invite{}, err
|
||
|
}
|
||
|
|
||
|
in.Server, err = s.GetServer(in.Server.Id)
|
||
|
if err != nil {
|
||
|
return Invite{}, err
|
||
|
}
|
||
|
in.Server.Rcon = Rcon{}
|
||
|
|
||
|
q = s.database.QueryRow("SELECT display_name FROM users WHERE id=$1", in.Creator.Id)
|
||
|
err = q.Scan(&in.Creator.DisplayName)
|
||
|
if err != nil {
|
||
|
return Invite{}, err
|
||
|
}
|
||
|
|
||
|
return in, nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) LogInviteUse(user string, invite Invite) error {
|
||
|
q, err := s.database.Prepare("INSERT INTO invite_log (entry_id, invite, user) VALUES ($1, $2, $3)")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
entryId := uuid.New().String()
|
||
|
|
||
|
_, err = q.Exec(entryId, invite.Token, user)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// InviteLog is one row of the invite_log table, as written by LogInviteUse.
type InviteLog struct {
	// EntryID is the random UUID generated by LogInviteUse.
	EntryID string
	// Invite is the token of the invite that was used.
	Invite string
	// User is the user string passed to LogInviteUse — presumably a user ID.
	User string
}
|
||
|
|
||
|
func (s *Store) InviteLog(invite Invite) ([]InviteLog, error) {
|
||
|
q, err := s.database.Query("SELECT * FROM invite_log WHERE invite=$1", invite.Token)
|
||
|
if err != nil {
|
||
|
return []InviteLog{}, nil
|
||
|
}
|
||
|
|
||
|
var log []InviteLog
|
||
|
for q.Next() {
|
||
|
var logEntry InviteLog
|
||
|
err := q.Scan(&logEntry.EntryID, &logEntry.Invite, &logEntry.User)
|
||
|
if err != nil {
|
||
|
continue
|
||
|
}
|
||
|
log = append(log, logEntry)
|
||
|
}
|
||
|
|
||
|
return log, nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) GetServer(id string) (Server, error) {
|
||
|
if id == "" {
|
||
|
return Server{}, sql.ErrNoRows
|
||
|
}
|
||
|
q := s.database.QueryRow("SELECT * FROM servers WHERE id=$1", id)
|
||
|
|
||
|
var serv Server
|
||
|
err := q.Scan(&serv.Id, &serv.Name, &serv.Address, &serv.Rcon.Address, &serv.Rcon.Password, &serv.Owner.Id)
|
||
|
if err != nil {
|
||
|
return Server{}, err
|
||
|
}
|
||
|
|
||
|
return serv, nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) GetUserServers(user User) ([]Server, error) {
|
||
|
q, err := s.database.Query("SELECT id,address,name,rcon_address FROM servers WHERE owner=$1", user.Id)
|
||
|
if err != nil {
|
||
|
return []Server{}, nil
|
||
|
}
|
||
|
|
||
|
var servers []Server
|
||
|
for q.Next() {
|
||
|
var server Server
|
||
|
err := q.Scan(&server.Id, &server.Address, &server.Name, &server.Rcon.Address)
|
||
|
if err != nil {
|
||
|
continue
|
||
|
}
|
||
|
servers = append(servers, server)
|
||
|
}
|
||
|
|
||
|
return servers, nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) SaveServer(server Server) error {
|
||
|
q, err := s.database.Prepare("INSERT INTO servers (id, address, name, rcon_address, rcon_password, owner) VALUES ($1, $2, $3, $4, $5, $6)")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = q.Exec(server.Id, server.Address, server.Name, server.Rcon.Address, server.Rcon.Password, server.Owner.Id)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) GetUser(uid string) (User, error) {
|
||
|
q := s.database.QueryRow("SELECT * FROM users WHERE id=$1", uid)
|
||
|
|
||
|
var user User
|
||
|
var tokenExpiry string
|
||
|
err := q.Scan(&user.Id, &user.Token, &user.DisplayName, &user.RefreshToken, &tokenExpiry)
|
||
|
if err != nil {
|
||
|
return User{}, err
|
||
|
}
|
||
|
user.TokenExpiry, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", tokenExpiry)
|
||
|
if err != nil {
|
||
|
return User{}, err
|
||
|
}
|
||
|
|
||
|
return user, nil
|
||
|
}
|
||
|
|
||
|
func (s *Store) SaveUser(user User) error {
|
||
|
existingUser, err := s.GetUser(user.Id)
|
||
|
if err != nil {
|
||
|
if errors.Is(err, sql.ErrNoRows) {
|
||
|
q, err := s.database.Prepare("INSERT INTO users (id, token, display_name, refresh_token, token_expiry) VALUES ($1, $2, $3, $4, $5)")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, err = q.Exec(user.Id, user.Token, user.DisplayName, user.RefreshToken, user.TokenExpiry)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
} else {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
q, err := s.database.Prepare("UPDATE users SET token=$2, display_name=$3, refresh_token=$4, token_expiry=$5 WHERE id=$1")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, err = q.Exec(existingUser.Id, user.Token, user.DisplayName, user.RefreshToken, user.TokenExpiry)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (s *Store) SaveSession(token string, user User) error {
|
||
|
q, err := s.database.Prepare("INSERT INTO sessions (token, uid, expiry) VALUES ($1, $2, $3)")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
// Expire in 30 days.
|
||
|
_, err = q.Exec(token, user.Id, time.Now().Add(720*time.Hour))
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (s *Store) SessionUser(token string) (User, error) {
|
||
|
q := s.database.QueryRow("SELECT * FROM sessions WHERE token=$1", token)
|
||
|
var sess Session
|
||
|
var sessExpiry string
|
||
|
err := q.Scan(&sess.Token, &sess.UID, &sessExpiry)
|
||
|
if err != nil {
|
||
|
return User{}, err
|
||
|
}
|
||
|
sess.Expiry, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", sessExpiry)
|
||
|
if err != nil {
|
||
|
return User{}, err
|
||
|
}
|
||
|
if sess.Expiry.Before(time.Now()) {
|
||
|
return User{}, errors.New(SessionExpired)
|
||
|
}
|
||
|
|
||
|
return s.GetUser(sess.UID)
|
||
|
}
|