Add context support to service

References: https://todo.sr.ht/~emersion/soju/141
This commit is contained in:
Simon Ser 2021-11-08 19:40:30 +01:00
parent c21202160c
commit 802e82c272
2 changed files with 31 additions and 31 deletions

View file

@ -2161,7 +2161,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{name, text}, Params: []string{name, text},
}) })
} }
handleServicePRIVMSG(dc, text) handleServicePRIVMSG(ctx, dc, text)
continue continue
} }

View file

@ -39,7 +39,7 @@ type serviceCommandSet map[string]*serviceCommand
type serviceCommand struct { type serviceCommand struct {
usage string usage string
desc string desc string
handle func(dc *downstreamConn, params []string) error handle func(ctx context.Context, dc *downstreamConn, params []string) error
children serviceCommandSet children serviceCommandSet
admin bool admin bool
} }
@ -113,7 +113,7 @@ func splitWords(s string) ([]string, error) {
return words, nil return words, nil
} }
func handleServicePRIVMSG(dc *downstreamConn, text string) { func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) {
words, err := splitWords(text) words, err := splitWords(text)
if err != nil { if err != nil {
sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err)) sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err))
@ -144,7 +144,7 @@ func handleServicePRIVMSG(dc *downstreamConn, text string) {
return return
} }
if err := cmd.handle(dc, params); err != nil { if err := cmd.handle(ctx, dc, params); err != nil {
sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err)) sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err))
} }
} }
@ -322,7 +322,7 @@ func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin
} }
} }
func handleServiceHelp(dc *downstreamConn, params []string) error { func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) > 0 { if len(params) > 0 {
cmd, rest, err := serviceCommands.Get(params) cmd, rest, err := serviceCommands.Get(params)
if err != nil { if err != nil {
@ -473,7 +473,7 @@ func (fs *networkFlagSet) update(network *Network) error {
return nil return nil
} }
func handleServiceNetworkCreate(dc *downstreamConn, params []string) error { func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error {
fs := newNetworkFlagSet() fs := newNetworkFlagSet()
if err := fs.Parse(params); err != nil { if err := fs.Parse(params); err != nil {
return err return err
@ -490,7 +490,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error {
return err return err
} }
network, err := dc.user.createNetwork(context.TODO(), record) network, err := dc.user.createNetwork(ctx, record)
if err != nil { if err != nil {
return fmt.Errorf("could not create network: %v", err) return fmt.Errorf("could not create network: %v", err)
} }
@ -499,7 +499,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceNetworkStatus(dc *downstreamConn, params []string) error { func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error {
n := 0 n := 0
dc.user.forEachNetwork(func(net *network) { dc.user.forEachNetwork(func(net *network) {
var statuses []string var statuses []string
@ -545,7 +545,7 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error { func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) < 1 { if len(params) < 1 {
return fmt.Errorf("expected at least one argument") return fmt.Errorf("expected at least one argument")
} }
@ -565,7 +565,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
return err return err
} }
network, err := dc.user.updateNetwork(context.TODO(), &record) network, err := dc.user.updateNetwork(ctx, &record)
if err != nil { if err != nil {
return fmt.Errorf("could not update network: %v", err) return fmt.Errorf("could not update network: %v", err)
} }
@ -574,7 +574,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 1 { if len(params) != 1 {
return fmt.Errorf("expected exactly one argument") return fmt.Errorf("expected exactly one argument")
} }
@ -584,7 +584,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
return fmt.Errorf("unknown network %q", params[0]) return fmt.Errorf("unknown network %q", params[0])
} }
if err := dc.user.deleteNetwork(context.TODO(), net.ID); err != nil { if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
return err return err
} }
@ -592,7 +592,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceNetworkQuote(dc *downstreamConn, params []string) error { func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 2 { if len(params) != 2 {
return fmt.Errorf("expected exactly two arguments") return fmt.Errorf("expected exactly two arguments")
} }
@ -626,7 +626,7 @@ func sendCertfpFingerprints(dc *downstreamConn, cert []byte) {
sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:])) sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:]))
} }
func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error { func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error {
fs := newFlagSet() fs := newFlagSet()
keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)") keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA") bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")
@ -657,7 +657,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = privKey net.SASL.External.PrivKeyBlob = privKey
net.SASL.Mechanism = "EXTERNAL" net.SASL.Mechanism = "EXTERNAL"
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -666,7 +666,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceCertFPFingerprints(dc *downstreamConn, params []string) error { func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 1 { if len(params) != 1 {
return fmt.Errorf("expected exactly one argument") return fmt.Errorf("expected exactly one argument")
} }
@ -684,7 +684,7 @@ func handleServiceCertFPFingerprints(dc *downstreamConn, params []string) error
return nil return nil
} }
func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error { func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 3 { if len(params) != 3 {
return fmt.Errorf("expected exactly 3 arguments") return fmt.Errorf("expected exactly 3 arguments")
} }
@ -698,7 +698,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
net.SASL.Plain.Password = params[2] net.SASL.Plain.Password = params[2]
net.SASL.Mechanism = "PLAIN" net.SASL.Mechanism = "PLAIN"
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -706,7 +706,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceSASLReset(dc *downstreamConn, params []string) error { func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 1 { if len(params) != 1 {
return fmt.Errorf("expected exactly one argument") return fmt.Errorf("expected exactly one argument")
} }
@ -722,7 +722,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
net.SASL.External.PrivKeyBlob = nil net.SASL.External.PrivKeyBlob = nil
net.SASL.Mechanism = "" net.SASL.Mechanism = ""
if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
return err return err
} }
@ -730,7 +730,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleUserCreate(dc *downstreamConn, params []string) error { func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error {
fs := newFlagSet() fs := newFlagSet()
username := fs.String("username", "", "") username := fs.String("username", "", "")
password := fs.String("password", "", "") password := fs.String("password", "", "")
@ -773,7 +773,7 @@ func popArg(params []string) (string, []string) {
return "", params return "", params
} }
func handleUserUpdate(dc *downstreamConn, params []string) error { func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
var password, realname *string var password, realname *string
var admin *bool var admin *bool
fs := newFlagSet() fs := newFlagSet()
@ -837,7 +837,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error {
return fmt.Errorf("cannot update -admin of own user") return fmt.Errorf("cannot update -admin of own user")
} }
if err := dc.user.updateUser(context.TODO(), &record); err != nil { if err := dc.user.updateUser(ctx, &record); err != nil {
return err return err
} }
@ -847,7 +847,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleUserDelete(dc *downstreamConn, params []string) error { func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 1 { if len(params) != 1 {
return fmt.Errorf("expected exactly one argument") return fmt.Errorf("expected exactly one argument")
} }
@ -860,7 +860,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
u.stop() u.stop()
if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil { if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil {
return fmt.Errorf("failed to delete user: %v", err) return fmt.Errorf("failed to delete user: %v", err)
} }
@ -868,7 +868,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceChannelStatus(dc *downstreamConn, params []string) error { func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error {
var defaultNetworkName string var defaultNetworkName string
if dc.network != nil { if dc.network != nil {
defaultNetworkName = dc.network.GetName() defaultNetworkName = dc.network.GetName()
@ -988,7 +988,7 @@ func (fs *channelFlagSet) update(channel *Channel) error {
return nil return nil
} }
func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) < 1 { if len(params) < 1 {
return fmt.Errorf("expected at least one argument") return fmt.Errorf("expected at least one argument")
} }
@ -1015,7 +1015,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
uc.updateChannelAutoDetach(upstreamName) uc.updateChannelAutoDetach(upstreamName)
if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
return fmt.Errorf("failed to update channel: %v", err) return fmt.Errorf("failed to update channel: %v", err)
} }
@ -1023,8 +1023,8 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceServerStatus(dc *downstreamConn, params []string) error { func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error {
dbStats, err := dc.user.srv.db.Stats(context.TODO()) dbStats, err := dc.user.srv.db.Stats(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -1033,7 +1033,7 @@ func handleServiceServerStatus(dc *downstreamConn, params []string) error {
return nil return nil
} }
func handleServiceServerNotice(dc *downstreamConn, params []string) error { func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error {
if len(params) != 1 { if len(params) != 1 {
return fmt.Errorf("expected exactly one argument") return fmt.Errorf("expected exactly one argument")
} }