Add network.channels, remove DB.GetChannel

Store the list of configured channels in the network data structure.
This removes the need for a database lookup and will be useful for
detached channels.
This commit is contained in:
Simon Ser 2020-04-11 17:00:40 +02:00
parent dbd6cd689e
commit 276ce12e7c
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
3 changed files with 27 additions and 34 deletions

19
db.go
View file

@ -48,8 +48,6 @@ type Channel struct {
Key string Key string
} }
var ErrNoSuchChannel = fmt.Errorf("soju: no such channel")
const schema = ` const schema = `
CREATE TABLE User ( CREATE TABLE User (
username VARCHAR(255) PRIMARY KEY, username VARCHAR(255) PRIMARY KEY,
@ -371,23 +369,6 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil return channels, nil
} }
func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ch := &Channel{Name: name}
var key *string
row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
return nil, ErrNoSuchChannel
} else if err != nil {
return nil, err
}
ch.Key = fromStringPtr(key)
return ch, nil
}
func (db *DB) StoreChannel(networkID int64, ch *Channel) error { func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
db.lock.Lock() db.lock.Lock()
defer db.lock.Unlock() defer db.lock.Unlock()

View file

@ -421,13 +421,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.registered = true uc.registered = true
uc.logger.Printf("connection registered") uc.logger.Printf("connection registered")
channels, err := uc.srv.db.ListChannels(uc.network.ID) for _, ch := range uc.network.channels {
if err != nil {
uc.logger.Printf("failed to list channels from database: %v", err)
break
}
for _, ch := range channels {
params := []string{ch.Name} params := []string{ch.Name}
if ch.Key != "" { if ch.Key != "" {
params = append(params, ch.Key) params = append(params, ch.Key)

34
user.go
View file

@ -56,16 +56,23 @@ type network struct {
stopped chan struct{} stopped chan struct{}
conn *upstreamConn conn *upstreamConn
channels map[string]*Channel
history map[string]*networkHistory // indexed by entity history map[string]*networkHistory // indexed by entity
offlineClients map[string]struct{} // indexed by client name offlineClients map[string]struct{} // indexed by client name
lastError error lastError error
} }
func newNetwork(user *user, record *Network) *network { func newNetwork(user *user, record *Network, channels []Channel) *network {
m := make(map[string]*Channel, len(channels))
for _, ch := range channels {
m[ch.Name] = &ch
}
return &network{ return &network{
Network: *record, Network: *record,
user: user, user: user,
stopped: make(chan struct{}), stopped: make(chan struct{}),
channels: m,
history: make(map[string]*networkHistory), history: make(map[string]*networkHistory),
offlineClients: make(map[string]struct{}), offlineClients: make(map[string]struct{}),
} }
@ -140,16 +147,22 @@ func (net *network) Stop() {
} }
func (net *network) createUpdateChannel(ch *Channel) error { func (net *network) createUpdateChannel(ch *Channel) error {
if dbCh, err := net.user.srv.db.GetChannel(net.ID, ch.Name); err == nil { if current, ok := net.channels[ch.Name]; ok {
ch.ID = dbCh.ID ch.ID = current.ID // update channel if it already exists
} else if err != ErrNoSuchChannel { }
if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
return err return err
} }
return net.user.srv.db.StoreChannel(net.ID, ch) net.channels[ch.Name] = ch
return nil
} }
func (net *network) deleteChannel(name string) error { func (net *network) deleteChannel(name string) error {
return net.user.srv.db.DeleteChannel(net.ID, name) if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
return err
}
delete(net.channels, name)
return nil
} }
type user struct { type user struct {
@ -221,7 +234,12 @@ func (u *user) run() {
} }
for _, record := range networks { for _, record := range networks {
network := newNetwork(u, &record) channels, err := u.srv.db.ListChannels(record.ID)
if err != nil {
u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
}
network := newNetwork(u, &record, channels)
u.networks = append(u.networks, network) u.networks = append(u.networks, network)
go network.run() go network.run()
@ -353,7 +371,7 @@ func (u *user) createNetwork(net *Network) (*network, error) {
panic("tried creating an already-existing network") panic("tried creating an already-existing network")
} }
network := newNetwork(u, net) network := newNetwork(u, net, nil)
err := u.srv.db.StoreNetwork(u.Username, &network.Network) err := u.srv.db.StoreNetwork(u.Username, &network.Network)
if err != nil { if err != nil {
return nil, err return nil, err