From 802e82c272312564a9a8a8d61c04c8a2407d1f37 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 8 Nov 2021 19:40:30 +0100 Subject: [PATCH] Add context support to service References: https://todo.sr.ht/~emersion/soju/141 --- downstream.go | 2 +- service.go | 60 +++++++++++++++++++++++++-------------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/downstream.go b/downstream.go index 46e2015..2d80d9d 100644 --- a/downstream.go +++ b/downstream.go @@ -2161,7 +2161,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: []string{name, text}, }) } - handleServicePRIVMSG(dc, text) + handleServicePRIVMSG(ctx, dc, text) continue } diff --git a/service.go b/service.go index 93d2dba..06ed9ec 100644 --- a/service.go +++ b/service.go @@ -39,7 +39,7 @@ type serviceCommandSet map[string]*serviceCommand type serviceCommand struct { usage string desc string - handle func(dc *downstreamConn, params []string) error + handle func(ctx context.Context, dc *downstreamConn, params []string) error children serviceCommandSet admin bool } @@ -113,7 +113,7 @@ func splitWords(s string) ([]string, error) { return words, nil } -func handleServicePRIVMSG(dc *downstreamConn, text string) { +func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) { words, err := splitWords(text) if err != nil { sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err)) @@ -144,7 +144,7 @@ func handleServicePRIVMSG(dc *downstreamConn, text string) { return } - if err := cmd.handle(dc, params); err != nil { + if err := cmd.handle(ctx, dc, params); err != nil { sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err)) } } @@ -322,7 +322,7 @@ func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin } } -func handleServiceHelp(dc *downstreamConn, params []string) error { +func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) > 0 { cmd, rest, err := serviceCommands.Get(params) if err != nil { @@ -473,7 +473,7 @@ func (fs *networkFlagSet) update(network *Network) error { return nil } -func handleServiceNetworkCreate(dc *downstreamConn, params []string) error { +func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error { fs := newNetworkFlagSet() if err := fs.Parse(params); err != nil { return err @@ -490,7 +490,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error { return err } - network, err := dc.user.createNetwork(context.TODO(), record) + network, err := dc.user.createNetwork(ctx, record) if err != nil { return fmt.Errorf("could not create network: %v", err) } @@ -499,7 +499,7 @@ func handleServiceNetworkCreate(dc *downstreamConn, params []string) error { return nil } -func handleServiceNetworkStatus(dc *downstreamConn, params []string) error { +func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error { n := 0 dc.user.forEachNetwork(func(net *network) { var statuses []string @@ -545,7 +545,7 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error { return nil } -func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error { +func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) < 1 { return fmt.Errorf("expected at least one argument") } @@ -565,7 +565,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error { return err } - network, err := dc.user.updateNetwork(context.TODO(), &record) + network, err := dc.user.updateNetwork(ctx, &record) if err != nil { return fmt.Errorf("could not update network: %v", err) } @@ -574,7 +574,7 @@ func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error { return nil } -func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { +func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 1 { return fmt.Errorf("expected exactly one argument") } @@ -584,7 +584,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { return fmt.Errorf("unknown network %q", params[0]) } - if err := dc.user.deleteNetwork(context.TODO(), net.ID); err != nil { + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { return err } @@ -592,7 +592,7 @@ func handleServiceNetworkDelete(dc *downstreamConn, params []string) error { return nil } -func handleServiceNetworkQuote(dc *downstreamConn, params []string) error { +func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 2 { return fmt.Errorf("expected exactly two arguments") } @@ -626,7 +626,7 @@ func sendCertfpFingerprints(dc *downstreamConn, cert []byte) { sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:])) } -func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error { +func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error { fs := newFlagSet() keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)") bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA") @@ -657,7 +657,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error { net.SASL.External.PrivKeyBlob = privKey net.SASL.Mechanism = "EXTERNAL" - if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { return err } @@ -666,7 +666,7 @@ func handleServiceCertFPGenerate(dc *downstreamConn, params []string) error { return nil } -func handleServiceCertFPFingerprints(dc *downstreamConn, params []string) error { +func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 1 { return fmt.Errorf("expected exactly one argument") } @@ -684,7 +684,7 @@ func handleServiceCertFPFingerprints(dc *downstreamConn, params []string) error return nil } -func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error { +func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 3 { return fmt.Errorf("expected exactly 3 arguments") } @@ -698,7 +698,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error { net.SASL.Plain.Password = params[2] net.SASL.Mechanism = "PLAIN" - if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { return err } @@ -706,7 +706,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error { return nil } -func handleServiceSASLReset(dc *downstreamConn, params []string) error { +func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 1 { return fmt.Errorf("expected exactly one argument") } @@ -722,7 +722,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error { net.SASL.External.PrivKeyBlob = nil net.SASL.Mechanism = "" - if err := dc.srv.db.StoreNetwork(context.TODO(), dc.user.ID, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { return err } @@ -730,7 +730,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error { return nil } -func handleUserCreate(dc *downstreamConn, params []string) error { +func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error { fs := newFlagSet() username := fs.String("username", "", "") password := fs.String("password", "", "") @@ -773,7 +773,7 @@ func popArg(params []string) (string, []string) { return "", params } -func handleUserUpdate(dc *downstreamConn, params []string) error { +func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error { var password, realname *string var admin *bool fs := newFlagSet() @@ -837,7 +837,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error { return fmt.Errorf("cannot update -admin of own user") } - if err := dc.user.updateUser(context.TODO(), &record); err != nil { + if err := dc.user.updateUser(ctx, &record); err != nil { return err } @@ -847,7 +847,7 @@ func handleUserUpdate(dc *downstreamConn, params []string) error { return nil } -func handleUserDelete(dc *downstreamConn, params []string) error { +func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 1 { return fmt.Errorf("expected exactly one argument") } @@ -860,7 +860,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error { u.stop() - if err := dc.srv.db.DeleteUser(context.TODO(), u.ID); err != nil { + if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil { return fmt.Errorf("failed to delete user: %v", err) } @@ -868,7 +868,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error { return nil } -func handleServiceChannelStatus(dc *downstreamConn, params []string) error { +func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error { var defaultNetworkName string if dc.network != nil { defaultNetworkName = dc.network.GetName() @@ -988,7 +988,7 @@ func (fs *channelFlagSet) update(channel *Channel) error { return nil } -func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { +func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) < 1 { return fmt.Errorf("expected at least one argument") } @@ -1015,7 +1015,7 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { uc.updateChannelAutoDetach(upstreamName) - if err := dc.srv.db.StoreChannel(context.TODO(), uc.network.ID, ch); err != nil { + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { return fmt.Errorf("failed to update channel: %v", err) } @@ -1023,8 +1023,8 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { return nil } -func handleServiceServerStatus(dc *downstreamConn, params []string) error { - dbStats, err := dc.user.srv.db.Stats(context.TODO()) +func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error { + dbStats, err := dc.user.srv.db.Stats(ctx) if err != nil { return err } @@ -1033,7 +1033,7 @@ func handleServiceServerStatus(dc *downstreamConn, params []string) error { return nil } -func handleServiceServerNotice(dc *downstreamConn, params []string) error { +func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error { if len(params) != 1 { return fmt.Errorf("expected exactly one argument") }