Update DB on JOIN and PART

This commit is contained in:
Simon Ser 2020-03-12 18:33:03 +01:00
parent 461de13ecc
commit 0c4e9b539c
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
2 changed files with 50 additions and 13 deletions

30
db.go
View file

@ -77,22 +77,12 @@ func (db *DB) CreateUser(user *User) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()
tx, err := db.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
var password *string var password *string
if user.Password != "" { if user.Password != "" {
password = &user.Password password = &user.Password
} }
_, err = tx.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password) _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
if err != nil { return err
return err
}
return tx.Commit()
} }
func (db *DB) ListNetworks(username string) ([]Network, error) { func (db *DB) ListNetworks(username string) ([]Network, error) {
@ -151,3 +141,19 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil return channels, nil
} }
func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
db.lock.Lock()
defer db.lock.Unlock()
_, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name)
return err
}
func (db *DB) DeleteChannel(networkID int64, name string) error {
db.lock.Lock()
defer db.lock.Unlock()
_, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
return err
}

View file

@ -114,7 +114,25 @@ func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
}) })
} }
// upstream returns the upstream connection, if any. If there are zero or if
// there are multiple upstream connections, it returns nil.
func (dc *downstreamConn) upstream() *upstreamConn {
if dc.network == nil {
return nil
}
var upstream *upstreamConn
dc.forEachUpstream(func(uc *upstreamConn) {
upstream = uc
})
return upstream
}
func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) { func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
if uc := dc.upstream(); uc != nil {
return uc, name, nil
}
// TODO: extract network name from channel name if dc.upstream == nil // TODO: extract network name from channel name if dc.upstream == nil
var channel *upstreamChannel var channel *upstreamChannel
var err error var err error
@ -461,7 +479,20 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Command: msg.Command, Command: msg.Command,
Params: []string{upstreamName}, Params: []string{upstreamName},
}) })
// TODO: add/remove channel from upstream config
switch msg.Command {
case "JOIN":
err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
Name: upstreamName,
})
if err != nil {
dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
}
case "PART":
if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
}
}
case "MODE": case "MODE":
if msg.Prefix == nil { if msg.Prefix == nil {
return fmt.Errorf("missing prefix") return fmt.Errorf("missing prefix")