package store import ( "database/sql" "embed" "errors" "log" "os" "strings" "time" "github.com/google/uuid" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) //go:embed database.sql var migrations embed.FS const SessionExpired = "session expired" type Store struct { database *sql.DB } type Storer interface { SaveInvite(invite Invite) error GetInvite(token string) (Invite, error) GetServer(id string) (Server, error) SaveServer(server Server) error GetUserServers(user User) ([]Server, error) DeleteServer(server Server, user User) error LogInviteUse(user User, invite Invite) error InviteLog(invite Invite) ([]InviteLog, error) DeleteInvite(invite Invite) error GetUser(uid string) (User, error) SaveUser(user User) error SaveSession(token string, user User) error SessionUser(token string) (User, error) SaveOauthState(state OauthState) error OauthState(id string) (OauthState, error) Close() error ServerInvites(server Server) ([]Invite, error) } type OauthState struct { Id string Origin string State string } type Invite struct { Token string `json:"token"` Creator User `json:"creator"` Server Server `json:"server"` Uses int `json:"uses"` Unlimited bool `json:"unlimited"` } type Server struct { Id string `json:"id"` Name string `json:"name"` Address string `json:"address"` Rcon Rcon `json:"rcon,omitempty"` Owner User `json:"-"` } type User struct { Id string `json:"id,omitempty"` Token string `json:"-"` DisplayName string `json:"display_name"` RefreshToken string `json:"-"` TokenExpiry time.Time `json:"-"` } type Rcon struct { Address string `json:"address,omitempty"` Password string `json:"password,omitempty"` } type Session struct { Token string UID string Expiry time.Time } type InviteLog struct { EntryID string `json:"entry_id"` Invite Invite `json:"invite"` User User `json:"user"` } func Open() (*Store, error) { database := os.Getenv("WLM_DATABASE_PATH") dbType := "sqlite3" if database == "" { database = "db.sqlite3" } if strings.Contains(database, "postgresql://" ) { dbType = "postgres" } if _, err := os.Stat(database); errors.Is(err, os.ErrNotExist) && !strings.Contains(database, "postgresql://") { log.Printf("No database found at %s, creating", database) _, err := os.Create(database) if err != nil { return nil, err } log.Printf("Database created at %s", database) } db, err := sql.Open(dbType, database) if err != nil { return nil, err } initialSetup, err := migrations.ReadFile("database.sql") if err != nil { return nil, err } _, err = db.Exec(string(initialSetup)) if err != nil { return nil, err } return &Store{database: db}, nil } func (s *Store) SaveInvite(invite Invite) error { q, err := s.database.Prepare("INSERT INTO invites (token, creator, server, uses, unlimited) VALUES ($1, $2, $3, $4, $5)") if err != nil { return err } _, err = q.Exec(invite.Token, invite.Creator.Id, invite.Server.Id, invite.Uses, invite.Unlimited) if err != nil { return err } return nil } func (s *Store) GetInvite(token string) (Invite, error) { if token == "" { return Invite{}, sql.ErrNoRows } q := s.database.QueryRow("SELECT * FROM invites WHERE token=$1", token) var in Invite err := q.Scan(&in.Token, &in.Creator.Id, &in.Server.Id, &in.Uses, &in.Unlimited) if err != nil { return Invite{}, err } in.Server, err = s.GetServer(in.Server.Id) if err != nil { return Invite{}, err } in.Server.Rcon = Rcon{} q = s.database.QueryRow("SELECT display_name FROM users WHERE id=$1", in.Creator.Id) err = q.Scan(&in.Creator.DisplayName) if err != nil { return Invite{}, err } return in, nil } func (s *Store) LogInviteUse(user User, invite Invite) error { q, err := s.database.Prepare("INSERT INTO invite_log (entry_id, invite, uid) VALUES ($1, $2, $3)") if err != nil { return err } entryId := uuid.New().String() _, err = q.Exec(entryId, invite.Token, user.Id) return err } func (s *Store) InviteLog(invite Invite) ([]InviteLog, error) { q, err := s.database.Query("SELECT * FROM invite_log WHERE invite=$1", invite.Token) if err != nil { return []InviteLog{}, nil } var logs []InviteLog for q.Next() { var logEntry InviteLog err := q.Scan(&logEntry.EntryID, &logEntry.Invite.Token, &logEntry.User.Id) if err != nil { log.Println(err.Error()) } user, err := s.GetUser(logEntry.User.Id) if err != nil { log.Println(err.Error()) } logEntry.User = user logs = append(logs, logEntry) } return logs, nil } func (s *Store) GetServer(id string) (Server, error) { if id == "" { return Server{}, sql.ErrNoRows } q := s.database.QueryRow("SELECT * FROM servers WHERE id=$1", id) var serv Server err := q.Scan(&serv.Id, &serv.Name, &serv.Address, &serv.Rcon.Address, &serv.Rcon.Password, &serv.Owner.Id) if err != nil { return Server{}, err } return serv, nil } func (s *Store) GetUserServers(user User) ([]Server, error) { q, err := s.database.Query("SELECT id,address,name,rcon_address FROM servers WHERE owner=$1", user.Id) if err != nil { return []Server{}, nil } var servers []Server for q.Next() { var server Server err := q.Scan(&server.Id, &server.Address, &server.Name, &server.Rcon.Address) if err != nil { continue } servers = append(servers, server) } return servers, nil } func (s *Store) SaveServer(server Server) error { q, err := s.database.Prepare("INSERT INTO servers (id, address, name, rcon_address, rcon_password, owner) VALUES ($1, $2, $3, $4, $5, $6)") if err != nil { return err } _, err = q.Exec(server.Id, server.Address, server.Name, server.Rcon.Address, server.Rcon.Password, server.Owner.Id) if err != nil { return err } return nil } func (s *Store) GetUser(uid string) (User, error) { q := s.database.QueryRow("SELECT * FROM users WHERE id=$1", uid) var user User var tokenExpiry string err := q.Scan(&user.Id, &user.Token, &user.DisplayName, &user.RefreshToken, &tokenExpiry) if err != nil { return User{}, err } user.TokenExpiry, err = time.Parse("2006-01-02 15:04:05.999999999Z", tokenExpiry) if err != nil { return User{}, err } return user, nil } func (s *Store) SaveUser(user User) error { existingUser, err := s.GetUser(user.Id) if err != nil { if errors.Is(err, sql.ErrNoRows) { q, err := s.database.Prepare("INSERT INTO users (id, token, display_name, refresh_token, token_expiry) VALUES ($1, $2, $3, $4, $5)") if err != nil { return err } _, err = q.Exec(user.Id, user.Token, user.DisplayName, user.RefreshToken, user.TokenExpiry) if err != nil { return err } return nil } else { return err } } q, err := s.database.Prepare("UPDATE users SET token=$2, display_name=$3, refresh_token=$4, token_expiry=$5 WHERE id=$1") if err != nil { return err } _, err = q.Exec(existingUser.Id, user.Token, user.DisplayName, user.RefreshToken, user.TokenExpiry) return err } func (s *Store) SaveSession(token string, user User) error { q, err := s.database.Prepare("INSERT INTO sessions (token, uid, expiry) VALUES ($1, $2, $3)") if err != nil { return err } // Expire in 30 days. _, err = q.Exec(token, user.Id, time.Now().Add(720*time.Hour)) return err } func (s *Store) SessionUser(token string) (User, error) { q := s.database.QueryRow("SELECT * FROM sessions WHERE token=$1", token) var sess Session var sessExpiry string err := q.Scan(&sess.Token, &sess.UID, &sessExpiry) if err != nil { return User{}, err } sess.Expiry, err = time.Parse("2006-01-02 15:04:05.999999999Z", sessExpiry) if err != nil { return User{}, err } if sess.Expiry.Before(time.Now()) { return User{}, errors.New(SessionExpired) } return s.GetUser(sess.UID) } func (s *Store) SaveOauthState(state OauthState) error { _, err := s.OauthState(state.Id) if err != nil { if errors.Is(err, sql.ErrNoRows) { q, err := s.database.Prepare("INSERT INTO oauth_states (state_id, origin, state) VALUES ($1, $2, $3)") if err != nil { return err } _, err = q.Exec(state.Id, state.Origin, state.State) if err != nil { return err } return nil } else { return err } } q, err := s.database.Prepare("UPDATE oauth_states SET state=$2 WHERE state_id=$1") if err != nil { return err } _, err = q.Exec(state.Id, state.State) if err != nil { return err } return err } func (s *Store) OauthState(id string) (OauthState, error) { q := s.database.QueryRow("SELECT * FROM oauth_states WHERE state_id=$1", id) var state OauthState err := q.Scan(&state.Id, &state.Origin, &state.State) if err != nil { return OauthState{}, err } return state, nil } func (s *Store) Close() error { return s.database.Close() } func (s *Store) ServerInvites(server Server) ([]Invite, error) { q, err := s.database.Query("SELECT * FROM invites WHERE server=$1", server.Id) if err != nil { return []Invite{}, err } var invites []Invite for q.Next() { var invite Invite err = q.Scan(&invite.Token, &invite.Creator.Id, &invite.Server.Id, &invite.Uses, &invite.Unlimited) if err != nil { continue } invites = append(invites, invite) } return invites, nil } func (s *Store) DeleteInvite(invite Invite) error { inviteDeleteQuery, err := s.database.Prepare("DELETE FROM invites WHERE token=$1") if err != nil { return err } _, err = s.database.Exec("DELETE FROM invite_log WHERE invite=$1", invite.Token) if err != nil { return err } _, err = inviteDeleteQuery.Exec(invite.Token) if err != nil { return err } return nil } func (s *Store) DeleteServer(server Server, user User) error { serverDeleteQuery, err := s.database.Prepare("DELETE FROM servers WHERE id=$1 AND owner=$2") if err != nil { return err } serverInvites, err := s.ServerInvites(server) if err != nil { return err } for _, invite := range serverInvites { err = s.DeleteInvite(invite) if err != nil { continue } } _, err = serverDeleteQuery.Exec(server.Id, user.Id) if err != nil { return err } return nil }