diff --git a/db.go b/db.go index 763be7d..7a51a26 100644 --- a/db.go +++ b/db.go @@ -117,6 +117,32 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { return networks, nil } +func (db *DB) StoreNetwork(username string, network *Network) error { + db.lock.Lock() + defer db.lock.Unlock() + + var netUsername, netRealname *string + if network.Username != "" { + netUsername = &network.Username + } + if network.Realname != "" { + netRealname = &network.Realname + } + + var err error + if network.ID != 0 { + _, err = db.db.Exec("UPDATE Network SET addr = ?, nick = ?, username = ?, realname = ? WHERE id = ?", network.Addr, network.Nick, netUsername, netRealname, network.ID) + } else { + var res sql.Result + res, err = db.db.Exec("INSERT INTO Network(user, addr, nick, username, realname) VALUES (?, ?, ?, ?, ?)", username, network.Addr, network.Nick, netUsername, netRealname) + if err != nil { + return err + } + network.ID, err = res.LastInsertId() + } + return err +} + func (db *DB) ListChannels(networkID int64) ([]Channel, error) { db.lock.RLock() defer db.lock.RUnlock() diff --git a/downstream.go b/downstream.go index 8a70e33..dc0ec90 100644 --- a/downstream.go +++ b/downstream.go @@ -105,6 +105,14 @@ func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string { return name } +func (dc *downstreamConn) forEachNetwork(f func(*network)) { + if dc.network != nil { + f(dc.network) + } else { + dc.user.forEachNetwork(f) + } +} + func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { dc.user.forEachUpstream(func(uc *upstreamConn) { if dc.network != nil && uc.network != dc.network { @@ -458,6 +466,23 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: []string{dc.nick, "You may not reregister"}, }} case "NICK": + var nick string + if err := parseMessageParams(msg, &nick); err != nil { + return err + } + + var err error + dc.forEachNetwork(func(n *network) { + if err != nil { + return + } + n.Nick = nick + err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network) + }) + if err != nil { + return err + } + dc.forEachUpstream(func(uc *upstreamConn) { uc.SendMessage(msg) }) diff --git a/server.go b/server.go index 20790d9..bc390d9 100644 --- a/server.go +++ b/server.go @@ -109,6 +109,14 @@ func newUser(srv *Server, record *User) *user { } } +func (u *user) forEachNetwork(f func(*network)) { + u.lock.Lock() + for _, network := range u.networks { + f(network) + } + u.lock.Unlock() +} + func (u *user) forEachUpstream(f func(uc *upstreamConn)) { u.lock.Lock() for _, network := range u.networks {