diff --git a/db.go b/db.go index 4ad89d7..d37c79d 100644 --- a/db.go +++ b/db.go @@ -2,7 +2,6 @@ package jounce import ( "database/sql" - "errors" "sync" _ "github.com/mattn/go-sqlite3" @@ -27,7 +26,7 @@ type Channel struct { } type DB struct { - lock sync.Mutex + lock sync.RWMutex db *sql.DB } @@ -46,8 +45,8 @@ func (db *DB) Close() error { } func (db *DB) ListUsers() ([]User, error) { - db.lock.Lock() - defer db.lock.Unlock() + db.lock.RLock() + defer db.lock.RUnlock() rows, err := db.db.Query("SELECT username, password FROM User") if err != nil { @@ -75,8 +74,8 @@ func (db *DB) ListUsers() ([]User, error) { } func (db *DB) ListNetworks(username string) ([]Network, error) { - db.lock.Lock() - defer db.lock.Unlock() + db.lock.RLock() + defer db.lock.RUnlock() rows, err := db.db.Query("SELECT id, addr, nick, username, realname FROM Network WHERE user = ?", username) if err != nil { @@ -107,8 +106,8 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { } func (db *DB) ListChannels(networkID int64) ([]Channel, error) { - db.lock.Lock() - defer db.lock.Unlock() + db.lock.RLock() + defer db.lock.RUnlock() rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID) if err != nil {