Use capRegistry for downstreamConn

This commit is contained in:
Simon Ser 2022-03-14 19:15:35 +01:00
parent 347a4979da
commit 74fd506fef
4 changed files with 58 additions and 64 deletions

View file

@ -19,7 +19,7 @@ func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel
sendTopic(dc, ch)
}
if dc.caps["soju.im/read"] {
if dc.caps.IsEnabled("soju.im/read") {
channelCM := ch.conn.network.casemap(ch.Name)
r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM)
if err != nil {

View file

@ -307,8 +307,7 @@ type downstreamConn struct {
negotiatingCaps bool
capVersion int
supportedCaps map[string]string
caps map[string]bool
caps capRegistry
sasl *downstreamSASL
lastBatchRef uint64
@ -321,13 +320,12 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
options := connOptions{Logger: logger}
dc := &downstreamConn{
conn: *newConn(srv, ic, &options),
id: id,
nick: "*",
nickCM: "*",
supportedCaps: make(map[string]string),
caps: make(map[string]bool),
monitored: newCasemapMap(0),
conn: *newConn(srv, ic, &options),
id: id,
nick: "*",
nickCM: "*",
caps: newCapRegistry(),
monitored: newCasemapMap(0),
}
dc.monitored.SetCasemapping(casemapASCII)
dc.hostname = remoteAddr
@ -335,14 +333,14 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
dc.hostname = host
}
for k, v := range permanentDownstreamCaps {
dc.supportedCaps[k] = v
dc.caps.Available[k] = v
}
dc.supportedCaps["sasl"] = "PLAIN"
dc.caps.Available["sasl"] = "PLAIN"
// TODO: this is racy, we should only enable chathistory after
// authentication and then check that user.msgStore implements
// chatHistoryMessageStore
if srv.Config().LogPath != "" {
dc.supportedCaps["draft/chathistory"] = ""
dc.caps.Available["draft/chathistory"] = ""
}
return dc
}
@ -527,7 +525,7 @@ func (dc *downstreamConn) readMessages(ch chan<- event) error {
//
// This can only called from the user goroutine.
func (dc *downstreamConn) SendMessage(msg *irc.Message) {
if !dc.caps["message-tags"] {
if !dc.caps.IsEnabled("message-tags") {
if msg.Command == "TAGMSG" {
return
}
@ -536,32 +534,32 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) {
supported := false
switch name {
case "time":
supported = dc.caps["server-time"]
supported = dc.caps.IsEnabled("server-time")
case "account":
supported = dc.caps["account"]
supported = dc.caps.IsEnabled("account")
}
if !supported {
delete(msg.Tags, name)
}
}
}
if !dc.caps["batch"] && msg.Tags["batch"] != "" {
if !dc.caps.IsEnabled("batch") && msg.Tags["batch"] != "" {
msg = msg.Copy()
delete(msg.Tags, "batch")
}
if msg.Command == "JOIN" && !dc.caps["extended-join"] {
if msg.Command == "JOIN" && !dc.caps.IsEnabled("extended-join") {
msg.Params = msg.Params[:1]
}
if msg.Command == "SETNAME" && !dc.caps["setname"] {
if msg.Command == "SETNAME" && !dc.caps.IsEnabled("setname") {
return
}
if msg.Command == "AWAY" && !dc.caps["away-notify"] {
if msg.Command == "AWAY" && !dc.caps.IsEnabled("away-notify") {
return
}
if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] {
if msg.Command == "ACCOUNT" && !dc.caps.IsEnabled("account-notify") {
return
}
if msg.Command == "READ" && !dc.caps["soju.im/read"] {
if msg.Command == "READ" && !dc.caps.IsEnabled("soju.im/read") {
return
}
@ -573,7 +571,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags,
dc.lastBatchRef++
ref := fmt.Sprintf("%v", dc.lastBatchRef)
if dc.caps["batch"] {
if dc.caps.IsEnabled("batch") {
dc.SendMessage(&irc.Message{
Tags: tags,
Prefix: dc.srv.prefix(),
@ -584,7 +582,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags,
f(irc.TagValue(ref))
if dc.caps["batch"] {
if dc.caps.IsEnabled("batch") {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "BATCH",
@ -597,7 +595,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags,
func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
dc.SendMessage(msg)
if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps["draft/chathistory"] {
if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") {
return
}
@ -608,7 +606,7 @@ func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
// sending a message. This is useful e.g. for self-messages when echo-message
// isn't enabled.
func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps["draft/chathistory"] {
if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") {
return
}
@ -829,12 +827,12 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
// down the available capabilities when upstreams are
// known.
for k, v := range needAllDownstreamCaps {
dc.supportedCaps[k] = v
dc.caps.Available[k] = v
}
}
caps := make([]string, 0, len(dc.supportedCaps))
for k, v := range dc.supportedCaps {
caps := make([]string, 0, len(dc.caps.Available))
for k, v := range dc.caps.Available {
if dc.capVersion >= 302 && v != "" {
caps = append(caps, k+"="+v)
} else {
@ -851,7 +849,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
if dc.capVersion >= 302 {
// CAP version 302 implicitly enables cap-notify
dc.caps["cap-notify"] = true
dc.caps.SetEnabled("cap-notify", true)
}
if !dc.registered {
@ -859,10 +857,8 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
}
case "LIST":
var caps []string
for name, enabled := range dc.caps {
if enabled {
caps = append(caps, name)
}
for name := range dc.caps.Enabled {
caps = append(caps, name)
}
// TODO: multi-line replies
@ -889,12 +885,11 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
name = strings.TrimPrefix(name, "-")
}
if enable == dc.caps[name] {
if enable == dc.caps.IsEnabled(name) {
continue
}
_, ok := dc.supportedCaps[name]
if !ok {
if !dc.caps.IsAvailable(name) {
ack = false
break
}
@ -905,7 +900,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
break
}
dc.caps[name] = enable
dc.caps.SetEnabled(name, enable)
}
reply := "NAK"
@ -939,7 +934,7 @@ func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *d
}
}()
if !dc.caps["sasl"] {
if !dc.caps.IsEnabled("sasl") {
return nil, ircError{&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_SASLFAIL,
@ -1053,11 +1048,11 @@ func (dc *downstreamConn) endSASL(msg *irc.Message) {
}
func (dc *downstreamConn) setSupportedCap(name, value string) {
prevValue, hasPrev := dc.supportedCaps[name]
prevValue, hasPrev := dc.caps.Available[name]
changed := !hasPrev || prevValue != value
dc.supportedCaps[name] = value
dc.caps.Available[name] = value
if !dc.caps["cap-notify"] || !changed {
if !dc.caps.IsEnabled("cap-notify") || !changed {
return
}
@ -1074,11 +1069,10 @@ func (dc *downstreamConn) setSupportedCap(name, value string) {
}
func (dc *downstreamConn) unsetSupportedCap(name string) {
_, hasPrev := dc.supportedCaps[name]
delete(dc.supportedCaps, name)
delete(dc.caps, name)
hasPrev := dc.caps.IsAvailable(name)
dc.caps.Del(name)
if !dc.caps["cap-notify"] || !hasPrev {
if !dc.caps.IsEnabled("cap-notify") || !hasPrev {
return
}
@ -1149,7 +1143,7 @@ func (dc *downstreamConn) updateNick() {
}
func (dc *downstreamConn) updateRealname() {
if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] {
if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps.IsEnabled("setname") {
dc.SendMessage(&irc.Message{
Prefix: dc.prefix(),
Command: "SETNAME",
@ -1169,7 +1163,7 @@ func (dc *downstreamConn) updateAccount() {
return
}
if dc.account == account || !dc.caps["sasl"] {
if dc.account == account || !dc.caps.IsEnabled("sasl") {
return
}
@ -1272,7 +1266,7 @@ func (dc *downstreamConn) register(ctx context.Context) error {
dc.password = ""
if dc.user == nil {
if password == "" {
if dc.caps["sasl"] {
if dc.caps.IsEnabled("sasl") {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"},
@ -1374,7 +1368,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
return err
}
if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream {
if dc.network == nil && !dc.caps.IsEnabled("soju.im/bouncer-networks") && dc.srv.Config().MultiUpstream {
dc.isMultiUpstream = true
}
@ -1462,7 +1456,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
})
}
if dc.caps["soju.im/bouncer-networks-notify"] {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
for _, network := range dc.user.networks {
idStr := fmt.Sprintf("%v", network.ID)
@ -1499,7 +1493,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
})
dc.forEachNetwork(func(net *network) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
if dc.caps.IsEnabled("draft/chathistory") || dc.user.msgStore == nil {
return
}
@ -1549,7 +1543,7 @@ func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
}
func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
if dc.caps.IsEnabled("draft/chathistory") || dc.user.msgStore == nil {
return
}
@ -2375,7 +2369,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
if dc.caps["echo-message"] {
if dc.caps.IsEnabled("echo-message") {
echoTags := tags.Copy()
echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
dc.SendMessage(&irc.Message{
@ -2737,7 +2731,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}}
}
eventPlayback := dc.caps["draft/event-playback"]
eventPlayback := dc.caps.IsEnabled("draft/event-playback")
var history []*irc.Message
switch subcommand {

View file

@ -1497,7 +1497,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
weAreInvited := uc.isOurNick(nick)
uc.forEachDownstream(func(dc *downstreamConn) {
if !weAreInvited && !dc.caps["invite-notify"] {
if !weAreInvited && !dc.caps.IsEnabled("invite-notify") {
return
}
dc.SendMessage(&irc.Message{
@ -2079,7 +2079,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
detached := ch != nil && ch.Detached
uc.forEachDownstream(func(dc *downstreamConn) {
if !detached && (dc != origin || dc.caps["echo-message"]) {
if !detached && (dc != origin || dc.caps.IsEnabled("echo-message")) {
dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
} else {
dc.advanceMessageWithID(msg, msgID)

14
user.go
View file

@ -562,7 +562,7 @@ func (u *user) run() {
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
if !dc.caps["soju.im/bouncer-networks"] {
if !dc.caps.IsEnabled("soju.im/bouncer-networks") {
sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
}
@ -571,7 +571,7 @@ func (u *user) run() {
dc.updateAccount()
})
u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "BOUNCER",
@ -751,7 +751,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
}
u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "BOUNCER",
@ -762,7 +762,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
if uc.network.lastError == nil {
uc.forEachDownstream(func(dc *downstreamConn) {
if !dc.caps["soju.im/bouncer-networks"] {
if !dc.caps.IsEnabled("soju.im/bouncer-networks") {
sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
}
})
@ -872,7 +872,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
idStr := fmt.Sprintf("%v", network.ID)
attrs := getNetworkAttrs(network)
u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "BOUNCER",
@ -953,7 +953,7 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er
idStr := fmt.Sprintf("%v", updatedNetwork.ID)
attrs := getNetworkAttrs(updatedNetwork)
u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "BOUNCER",
@ -979,7 +979,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
idStr := fmt.Sprintf("%v", network.ID)
u.forEachDownstream(func(dc *downstreamConn) {
if dc.caps["soju.im/bouncer-networks-notify"] {
if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: "BOUNCER",