Add network update command

The user.updateNetwork function is a bit involved because we need to
make sure that the upstream connection is closed before re-connecting
(would otherwise cause "Nick already used" errors) and that the
downstream connections' state is kept in sync.

References: https://todo.sr.ht/~emersion/soju/17
This commit is contained in:
Simon Ser 2020-06-02 11:39:53 +02:00
parent bee2001e29
commit c709ebfc91
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
2 changed files with 253 additions and 90 deletions

View file

@ -118,7 +118,7 @@ func init() {
"network": { "network": {
children: serviceCommandSet{ children: serviceCommandSet{
"create": { "create": {
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]", usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
desc: "add a new network", desc: "add a new network",
handle: handleServiceCreateNetwork, handle: handleServiceCreateNetwork,
}, },
@ -126,6 +126,11 @@ func init() {
desc: "show a list of saved networks and their current status", desc: "show a list of saved networks and their current status",
handle: handleServiceNetworkStatus, handle: handleServiceNetworkStatus,
}, },
"update": {
usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
desc: "update a network",
handle: handleServiceNetworkUpdate,
},
"delete": { "delete": {
usage: "<name>", usage: "<name>",
desc: "delete a network", desc: "delete a network",
@ -338,36 +343,57 @@ func newFlagSet() *flag.FlagSet {
return fs return fs
} }
type stringSliceVar []string type stringSliceFlag []string
func (v *stringSliceVar) String() string { func (v *stringSliceFlag) String() string {
return fmt.Sprint([]string(*v)) return fmt.Sprint([]string(*v))
} }
func (v *stringSliceVar) Set(s string) error { func (v *stringSliceFlag) Set(s string) error {
*v = append(*v, s) *v = append(*v, s)
return nil return nil
} }
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { // stringPtrFlag is a flag value populating a string pointer. This allows to
fs := newFlagSet() // disambiguate between a flag that hasn't been set and a flag that has been
addr := fs.String("addr", "", "") // set to an empty string.
name := fs.String("name", "", "") type stringPtrFlag struct {
username := fs.String("username", "", "") ptr **string
pass := fs.String("pass", "", "")
realname := fs.String("realname", "", "")
nick := fs.String("nick", "", "")
var connectCommands stringSliceVar
fs.Var(&connectCommands, "connect-command", "")
if err := fs.Parse(params); err != nil {
return err
}
if *addr == "" {
return fmt.Errorf("flag -addr is required")
} }
if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 { func (f stringPtrFlag) String() string {
if *f.ptr == nil {
return ""
}
return **f.ptr
}
func (f stringPtrFlag) Set(s string) error {
*f.ptr = &s
return nil
}
type networkFlagSet struct {
*flag.FlagSet
Addr, Name, Nick, Username, Pass, Realname *string
ConnectCommands []string
}
func newNetworkFlagSet() *networkFlagSet {
fs := &networkFlagSet{FlagSet: newFlagSet()}
fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
fs.Var(stringPtrFlag{&fs.Name}, "name", "")
fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
fs.Var(stringPtrFlag{&fs.Username}, "username", "")
fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
return fs
}
func (fs *networkFlagSet) update(network *Network) error {
if fs.Addr != nil {
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
scheme := addrParts[0] scheme := addrParts[0]
switch scheme { switch scheme {
case "ircs", "irc+insecure": case "ircs", "irc+insecure":
@ -375,28 +401,57 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme) return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
} }
} }
network.Addr = *fs.Addr
for _, command := range connectCommands { }
if fs.Name != nil {
network.Name = *fs.Name
}
if fs.Nick != nil {
network.Nick = *fs.Nick
}
if fs.Username != nil {
network.Username = *fs.Username
}
if fs.Pass != nil {
network.Pass = *fs.Pass
}
if fs.Realname != nil {
network.Realname = *fs.Realname
}
if fs.ConnectCommands != nil {
if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
network.ConnectCommands = nil
} else {
for _, command := range fs.ConnectCommands {
_, err := irc.ParseMessage(command) _, err := irc.ParseMessage(command)
if err != nil { if err != nil {
return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
} }
} }
network.ConnectCommands = fs.ConnectCommands
if *nick == "" { }
*nick = dc.nick }
return nil
} }
var err error func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
network, err := dc.user.createNetwork(&Network{ fs := newNetworkFlagSet()
Addr: *addr, if err := fs.Parse(params); err != nil {
Name: *name, return err
Username: *username, }
Pass: *pass, if fs.Addr == nil {
Realname: *realname, return fmt.Errorf("flag -addr is required")
Nick: *nick, }
ConnectCommands: connectCommands,
}) record := &Network{
Addr: *fs.Addr,
Nick: dc.nick,
}
if err := fs.update(record); err != nil {
return err
}
network, err := dc.user.createNetwork(record)
if err != nil { if err != nil {
return fmt.Errorf("could not create network: %v", err) return fmt.Errorf("could not create network: %v", err)
} }
@ -441,6 +496,35 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
if len(params) < 1 {
return fmt.Errorf("expected exactly one argument")
}
fs := newNetworkFlagSet()
if err := fs.Parse(params[1:]); err != nil {
return err
}
net := dc.user.getNetwork(params[0])
if net == nil {
return fmt.Errorf("unknown network %q", params[0])
}
record := net.Network // copy network record because we'll mutate it
if err := fs.update(&record); err != nil {
return err
}
network, err := dc.user.updateNetwork(&record)
if err != nil {
return fmt.Errorf("could not update network: %v", err)
}
sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
return nil
}
func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
if len(params) != 1 { if len(params) != 1 {
return fmt.Errorf("expected exactly one argument") return fmt.Errorf("expected exactly one argument")

155
user.go
View file

@ -272,6 +272,15 @@ func (u *user) getNetwork(name string) *network {
return nil return nil
} }
func (u *user) getNetworkByID(id int64) *network {
for _, net := range u.networks {
if net.ID == id {
return net
}
}
return nil
}
func (u *user) run() { func (u *user) run() {
networks, err := u.srv.db.ListNetworks(u.Username) networks, err := u.srv.db.ListNetworks(u.Username)
if err != nil { if err != nil {
@ -309,31 +318,18 @@ func (u *user) run() {
}) })
uc.network.lastError = nil uc.network.lastError = nil
case eventUpstreamDisconnected: case eventUpstreamDisconnected:
uc := e.uc u.handleUpstreamDisconnected(e.uc)
uc.network.conn = nil
for _, ml := range uc.messageLoggers {
if err := ml.Close(); err != nil {
uc.logger.Printf("failed to close message logger: %v", err)
}
}
uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
})
if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
})
}
case eventUpstreamConnectionError: case eventUpstreamConnectionError:
net := e.net net := e.net
if net.lastError == nil || net.lastError.Error() != e.err.Error() { stopped := false
select {
case <-net.stopped:
stopped = true
default:
}
if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
net.forEachDownstream(func(dc *downstreamConn) { net.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err)) sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
}) })
@ -425,47 +421,130 @@ func (u *user) run() {
} }
} }
func (u *user) createNetwork(net *Network) (*network, error) { func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
if net.ID != 0 { uc.network.conn = nil
for _, ml := range uc.messageLoggers {
if err := ml.Close(); err != nil {
uc.logger.Printf("failed to close message logger: %v", err)
}
}
uc.endPendingLISTs(true)
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
})
if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
})
}
}
func (u *user) addNetwork(network *network) {
u.networks = append(u.networks, network)
go network.run()
}
func (u *user) removeNetwork(network *network) {
network.stop()
u.forEachDownstream(func(dc *downstreamConn) {
if dc.network != nil && dc.network == network {
dc.Close()
}
})
for i, net := range u.networks {
if net == network {
u.networks = append(u.networks[:i], u.networks[i+1:]...)
return
}
}
panic("tried to remove a non-existing network")
}
func (u *user) createNetwork(record *Network) (*network, error) {
if record.ID != 0 {
panic("tried creating an already-existing network") panic("tried creating an already-existing network")
} }
network := newNetwork(u, net, nil) network := newNetwork(u, record, nil)
err := u.srv.db.StoreNetwork(u.Username, &network.Network) err := u.srv.db.StoreNetwork(u.Username, &network.Network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u.networks = append(u.networks, network) u.addNetwork(network)
go network.run()
return network, nil return network, nil
} }
func (u *user) deleteNetwork(id int64) error { func (u *user) updateNetwork(record *Network) (*network, error) {
for i, net := range u.networks { if record.ID == 0 {
if net.ID != id { panic("tried updating a new network")
continue
} }
if err := u.srv.db.DeleteNetwork(net.ID); err != nil { network := u.getNetworkByID(record.ID)
return err if network == nil {
panic("tried updating a non-existing network")
} }
if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
return nil, err
}
// Most network changes require us to re-connect to the upstream server
channels := make([]Channel, 0, len(network.channels))
for _, ch := range network.channels {
channels = append(channels, *ch)
}
updatedNetwork := newNetwork(u, record, channels)
// If we're currently connected, disconnect and perform the necessary
// bookkeeping
if network.conn != nil {
network.stop()
// Note: this will set network.conn to nil
u.handleUpstreamDisconnected(network.conn)
}
// Patch downstream connections to use our fresh updated network
u.forEachDownstream(func(dc *downstreamConn) { u.forEachDownstream(func(dc *downstreamConn) {
if dc.network != nil && dc.network == net { if dc.network != nil && dc.network == network {
dc.Close() dc.network = updatedNetwork
} }
}) })
net.stop() // We need to remove the network after patching downstream connections,
u.networks = append(u.networks[:i], u.networks[i+1:]...) // otherwise they'll get closed
return nil u.removeNetwork(network)
// This will re-connect to the upstream server
u.addNetwork(updatedNetwork)
return updatedNetwork, nil
} }
func (u *user) deleteNetwork(id int64) error {
network := u.getNetworkByID(id)
if network == nil {
panic("tried deleting a non-existing network") panic("tried deleting a non-existing network")
} }
if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
return err
}
u.removeNetwork(network)
return nil
}
func (u *user) updatePassword(hashed string) error { func (u *user) updatePassword(hashed string) error {
u.User.Password = hashed u.User.Password = hashed
return u.srv.db.UpdatePassword(&u.User) return u.srv.db.UpdatePassword(&u.User)