Add network.channels, remove DB.GetChannel

Store the list of configured channels in the network data structure.
This removes the need for a database lookup and will be useful for
detached channels.
This commit is contained in:
Simon Ser 2020-04-11 17:00:40 +02:00
parent dbd6cd689e
commit 276ce12e7c
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
3 changed files with 27 additions and 34 deletions

19
db.go
View file

@ -48,8 +48,6 @@ type Channel struct {
Key string
}
var ErrNoSuchChannel = fmt.Errorf("soju: no such channel")
const schema = `
CREATE TABLE User (
username VARCHAR(255) PRIMARY KEY,
@ -371,23 +369,6 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
return channels, nil
}
func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
db.lock.RLock()
defer db.lock.RUnlock()
ch := &Channel{Name: name}
var key *string
row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
return nil, ErrNoSuchChannel
} else if err != nil {
return nil, err
}
ch.Key = fromStringPtr(key)
return ch, nil
}
func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
db.lock.Lock()
defer db.lock.Unlock()

View file

@ -421,13 +421,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.registered = true
uc.logger.Printf("connection registered")
channels, err := uc.srv.db.ListChannels(uc.network.ID)
if err != nil {
uc.logger.Printf("failed to list channels from database: %v", err)
break
}
for _, ch := range channels {
for _, ch := range uc.network.channels {
params := []string{ch.Name}
if ch.Key != "" {
params = append(params, ch.Key)

34
user.go
View file

@ -56,16 +56,23 @@ type network struct {
stopped chan struct{}
conn *upstreamConn
channels map[string]*Channel
history map[string]*networkHistory // indexed by entity
offlineClients map[string]struct{} // indexed by client name
lastError error
}
func newNetwork(user *user, record *Network) *network {
func newNetwork(user *user, record *Network, channels []Channel) *network {
m := make(map[string]*Channel, len(channels))
for _, ch := range channels {
m[ch.Name] = &ch
}
return &network{
Network: *record,
user: user,
stopped: make(chan struct{}),
channels: m,
history: make(map[string]*networkHistory),
offlineClients: make(map[string]struct{}),
}
@ -140,16 +147,22 @@ func (net *network) Stop() {
}
func (net *network) createUpdateChannel(ch *Channel) error {
if dbCh, err := net.user.srv.db.GetChannel(net.ID, ch.Name); err == nil {
ch.ID = dbCh.ID
} else if err != ErrNoSuchChannel {
if current, ok := net.channels[ch.Name]; ok {
ch.ID = current.ID // update channel if it already exists
}
if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
return err
}
return net.user.srv.db.StoreChannel(net.ID, ch)
net.channels[ch.Name] = ch
return nil
}
func (net *network) deleteChannel(name string) error {
return net.user.srv.db.DeleteChannel(net.ID, name)
if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
return err
}
delete(net.channels, name)
return nil
}
type user struct {
@ -221,7 +234,12 @@ func (u *user) run() {
}
for _, record := range networks {
network := newNetwork(u, &record)
channels, err := u.srv.db.ListChannels(record.ID)
if err != nil {
u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
}
network := newNetwork(u, &record, channels)
u.networks = append(u.networks, network)
go network.run()
@ -353,7 +371,7 @@ func (u *user) createNetwork(net *Network) (*network, error) {
panic("tried creating an already-existing network")
}
network := newNetwork(u, net)
network := newNetwork(u, net, nil)
err := u.srv.db.StoreNetwork(u.Username, &network.Network)
if err != nil {
return nil, err