Make casemapMap more type-safe

In addition to a type-safe getter, also define type-safe setters
and iterators.

References: https://lists.sr.ht/~emersion/soju-dev/patches/32777
This commit is contained in:
Simon Ser 2022-06-06 09:58:39 +02:00
parent c8f9728ff6
commit 657e25b25c
5 changed files with 183 additions and 137 deletions

View file

@ -1592,14 +1592,13 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
}
dc.forEachUpstream(func(uc *upstreamConn) {
for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
if !ch.complete {
continue
return
}
record := uc.network.channels.Value(ch.Name)
record := uc.network.channels.Get(ch.Name)
if record != nil && record.Detached {
continue
return
}
dc.SendMessage(&irc.Message{
@ -1609,7 +1608,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
})
forwardChannel(ctx, dc, ch)
}
})
})
dc.forEachNetwork(func(net *network) {
@ -1667,7 +1666,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t
return
}
ch := net.channels.Value(target)
ch := net.channels.Get(target)
ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
defer cancel()
@ -1938,7 +1937,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
})
}
ch := uc.network.channels.Value(upstreamName)
ch := uc.network.channels.Get(upstreamName)
if ch != nil {
// Don't clear the channel key if there's one set
// TODO: add a way to unset the channel key
@ -1951,7 +1950,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Name: upstreamName,
Key: key,
}
uc.network.channels.SetValue(upstreamName, ch)
uc.network.channels.Set(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -1975,7 +1974,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
if strings.EqualFold(reason, "detach") {
ch := uc.network.channels.Value(upstreamName)
ch := uc.network.channels.Get(upstreamName)
if ch != nil {
uc.network.detach(ch)
} else {
@ -1983,7 +1982,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Name: name,
Detached: true,
}
uc.network.channels.SetValue(upstreamName, ch)
uc.network.channels.Set(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -2119,7 +2118,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: params,
})
} else {
ch := uc.channels.Value(upstreamName)
ch := uc.channels.Get(upstreamName)
if ch == nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
@ -2168,7 +2167,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: []string{upstreamName, topic},
})
} else { // getting topic
ch := uc.channels.Value(upstreamName)
ch := uc.channels.Get(upstreamName)
if ch == nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
@ -2223,7 +2222,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return err
}
ch := uc.channels.Value(upstreamName)
ch := uc.channels.Get(upstreamName)
if ch != nil {
sendNames(dc, ch)
} else {
@ -2677,7 +2676,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
for _, target := range strings.Split(targets, ",") {
if subcommand == "+" {
// Hard limit, just to avoid having downstreams fill our map
if len(dc.monitored.innerMap) >= 1000 {
if dc.monitored.Len() >= 1000 {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_MONLISTFULL,
@ -2686,7 +2685,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue
}
dc.monitored.SetValue(target, nil)
dc.monitored.set(target, nil)
if uc.network.casemap(target) == serviceNickCM {
// BouncerServ is never tired
@ -2700,7 +2699,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if uc.monitored.Has(target) {
cmd := irc.RPL_MONOFFLINE
if online := uc.monitored.Value(target); online {
if online := uc.monitored.Get(target); online {
cmd = irc.RPL_MONONLINE
}
@ -2711,7 +2710,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
})
}
} else {
dc.monitored.Delete(target)
dc.monitored.Del(target)
}
}
uc.updateMonitor()
@ -2721,7 +2720,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
uc.updateMonitor()
case "L": // list
// TODO: be less lazy and pack the list
for _, entry := range dc.monitored.innerMap {
for _, entry := range dc.monitored.m {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_MONLIST,
@ -2735,11 +2734,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
})
case "S": // status
// TODO: be less lazy and pack the lists
for _, entry := range dc.monitored.innerMap {
for _, entry := range dc.monitored.m {
target := entry.originalKey
cmd := irc.RPL_MONOFFLINE
if online := uc.monitored.Value(target); online {
if online := uc.monitored.Get(target); online {
cmd = irc.RPL_MONONLINE
}
@ -2872,7 +2871,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
for _, target := range targets {
if ch := network.channels.Value(target.Name); ch != nil && ch.Detached {
if ch := network.channels.Get(target.Name); ch != nil && ch.Detached {
continue
}
@ -3329,12 +3328,10 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) {
downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
var members []string
for _, entry := range ch.Members.innerMap {
nick := entry.originalKey
memberships := entry.value.(*xirc.MembershipSet)
ch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) {
s := formatMemberPrefix(*memberships, dc) + dc.marshalEntity(ch.conn.network, nick)
members = append(members, s)
}
})
msgs := xirc.GenerateNamesReply(dc.srv.prefix(), dc.nick, downstreamName, ch.Status, members)
for _, msg := range msgs {

134
irc.go
View file

@ -111,7 +111,7 @@ outer:
return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
}
member := arguments[nextArgument]
m := ch.Members.Value(member)
m := ch.Members.Get(member)
if m != nil {
if plusMinus == '+' {
m.Add(ch.conn.availableMemberships, membership)
@ -304,8 +304,8 @@ func partialCasemap(higher casemapping, name string) string {
}
type casemapMap struct {
innerMap map[string]casemapEntry
casemap casemapping
m map[string]casemapEntry
casemap casemapping
}
type casemapEntry struct {
@ -315,95 +315,153 @@ type casemapEntry struct {
func newCasemapMap() casemapMap {
return casemapMap{
innerMap: make(map[string]casemapEntry),
casemap: casemapNone,
m: make(map[string]casemapEntry),
casemap: casemapNone,
}
}
func (cm *casemapMap) Has(name string) bool {
_, ok := cm.innerMap[cm.casemap(name)]
_, ok := cm.m[cm.casemap(name)]
return ok
}
func (cm *casemapMap) Len() int {
return len(cm.innerMap)
return len(cm.m)
}
func (cm *casemapMap) SetValue(name string, value interface{}) {
nameCM := cm.casemap(name)
entry, ok := cm.innerMap[nameCM]
func (cm *casemapMap) get(name string) interface{} {
entry, ok := cm.m[cm.casemap(name)]
if !ok {
cm.innerMap[nameCM] = casemapEntry{
return nil
}
return entry.value
}
func (cm *casemapMap) set(name string, value interface{}) {
nameCM := cm.casemap(name)
entry, ok := cm.m[nameCM]
if !ok {
cm.m[nameCM] = casemapEntry{
originalKey: name,
value: value,
}
return
}
entry.value = value
cm.innerMap[nameCM] = entry
cm.m[nameCM] = entry
}
func (cm *casemapMap) Delete(name string) {
delete(cm.innerMap, cm.casemap(name))
func (cm *casemapMap) Del(name string) {
delete(cm.m, cm.casemap(name))
}
func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
cm.casemap = newCasemap
newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
for _, entry := range cm.innerMap {
newInnerMap[cm.casemap(entry.originalKey)] = entry
m := make(map[string]casemapEntry, len(cm.m))
for _, entry := range cm.m {
m[cm.casemap(entry.originalKey)] = entry
}
cm.innerMap = newInnerMap
cm.m = m
}
type upstreamChannelCasemapMap struct{ casemapMap }
func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
func (cm *upstreamChannelCasemapMap) Get(name string) *upstreamChannel {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*upstreamChannel)
}
}
func (cm *upstreamChannelCasemapMap) Set(name string, uch *upstreamChannel) {
cm.set(name, uch)
}
func (cm *upstreamChannelCasemapMap) ForEach(f func(string, *upstreamChannel)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*upstreamChannel))
}
return entry.value.(*upstreamChannel)
}
type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Value(name string) *database.Channel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
func (cm *channelCasemapMap) Get(name string) *database.Channel {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*database.Channel)
}
}
func (cm *channelCasemapMap) Set(name string, ch *database.Channel) {
cm.set(name, ch)
}
func (cm *channelCasemapMap) ForEach(f func(string, *database.Channel)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*database.Channel))
}
return entry.value.(*database.Channel)
}
type membershipsCasemapMap struct{ casemapMap }
func (cm *membershipsCasemapMap) Value(name string) *xirc.MembershipSet {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
func (cm *membershipsCasemapMap) Get(name string) *xirc.MembershipSet {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*xirc.MembershipSet)
}
}
func (cm *membershipsCasemapMap) Set(name string, ms *xirc.MembershipSet) {
cm.set(name, ms)
}
func (cm *membershipsCasemapMap) ForEach(f func(string, *xirc.MembershipSet)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*xirc.MembershipSet))
}
return entry.value.(*xirc.MembershipSet)
}
type deliveredCasemapMap struct{ casemapMap }
func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
func (cm *deliveredCasemapMap) Get(name string) deliveredClientMap {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(deliveredClientMap)
}
}
func (cm *deliveredCasemapMap) Set(name string, m deliveredClientMap) {
cm.set(name, m)
}
func (cm *deliveredCasemapMap) ForEach(f func(string, deliveredClientMap)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(deliveredClientMap))
}
return entry.value.(deliveredClientMap)
}
type monitorCasemapMap struct{ casemapMap }
func (cm *monitorCasemapMap) Value(name string) (online bool) {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
func (cm *monitorCasemapMap) Get(name string) (online bool) {
if v := cm.get(name); v == nil {
return false
} else {
return v.(bool)
}
}
func (cm *monitorCasemapMap) Set(name string, online bool) {
cm.set(name, online)
}
func (cm *monitorCasemapMap) ForEach(f func(name string, online bool)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(bool))
}
return entry.value.(bool)
}
func isWordBoundary(r rune) bool {

View file

@ -974,9 +974,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
sendNetwork := func(net *network) {
var channels []*database.Channel
for _, entry := range net.channels.innerMap {
channels = append(channels, entry.value.(*database.Channel))
}
net.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, ch)
})
sort.Slice(channels, func(i, j int) bool {
return strings.ReplaceAll(channels[i].Name, "#", "") <
@ -986,7 +986,7 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
for _, ch := range channels {
var uch *upstreamChannel
if net.conn != nil {
uch = net.conn.channels.Value(ch.Name)
uch = net.conn.channels.Get(ch.Name)
}
name := ch.Name
@ -1109,7 +1109,7 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
return fmt.Errorf("unknown channel %q", name)
}
ch := uc.network.channels.Value(upstreamName)
ch := uc.network.channels.Get(upstreamName)
if ch == nil {
return fmt.Errorf("unknown channel %q", name)
}

View file

@ -292,7 +292,7 @@ func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
}
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch := uc.channels.Value(name)
ch := uc.channels.Get(name)
if ch == nil {
return nil, fmt.Errorf("unknown channel %q", name)
}
@ -513,7 +513,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
self := uc.isOurNick(msg.Prefix.Name)
ch := uc.network.channels.Value(target)
ch := uc.network.channels.Get(target)
if ch != nil && msg.Command != "TAGMSG" && !self {
if ch.Detached {
uc.handleDetachedMessage(ctx, ch, msg)
@ -757,11 +757,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.network.channels.Len() > 0 {
var channels, keys []string
for _, entry := range uc.network.channels.innerMap {
ch := entry.value.(*database.Channel)
uc.network.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, ch.Name)
keys = append(keys, ch.Key)
}
})
for _, msg := range xirc.GenerateJoin(channels, keys) {
uc.SendMessage(ctx, msg)
@ -918,15 +917,14 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.nickCM = uc.network.casemap(uc.nick)
}
for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
memberships := ch.Members.Value(msg.Prefix.Name)
uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
memberships := ch.Members.Get(msg.Prefix.Name)
if memberships != nil {
ch.Members.Delete(msg.Prefix.Name)
ch.Members.SetValue(newNick, memberships)
ch.Members.Del(msg.Prefix.Name)
ch.Members.Set(newNick, memberships)
uc.appendLog(ch.Name, msg)
}
}
})
if !me {
uc.forEachDownstream(func(dc *downstreamConn) {
@ -995,7 +993,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("joined channel %q", ch)
members := membershipsCasemapMap{newCasemapMap()}
members.casemap = uc.network.casemap
uc.channels.SetValue(ch, &upstreamChannel{
uc.channels.Set(ch, &upstreamChannel{
Name: ch,
conn: uc,
Members: members,
@ -1011,7 +1009,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if err != nil {
return err
}
ch.Members.SetValue(msg.Prefix.Name, &xirc.MembershipSet{})
ch.Members.Set(msg.Prefix.Name, &xirc.MembershipSet{})
}
chMsg := msg.Copy()
@ -1027,9 +1025,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
for _, ch := range strings.Split(channels, ",") {
if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("parted channel %q", ch)
uch := uc.channels.Value(ch)
if uch != nil {
uc.channels.Delete(ch)
if uch := uc.channels.Get(ch); uch != nil {
uc.channels.Del(ch)
uch.updateAutoDetach(0)
}
} else {
@ -1037,7 +1034,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if err != nil {
return err
}
ch.Members.Delete(msg.Prefix.Name)
ch.Members.Del(msg.Prefix.Name)
}
chMsg := msg.Copy()
@ -1052,13 +1049,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.isOurNick(user) {
uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
uc.channels.Delete(channel)
uc.channels.Del(channel)
} else {
ch, err := uc.getChannel(channel)
if err != nil {
return err
}
ch.Members.Delete(user)
ch.Members.Del(user)
}
uc.produce(channel, msg, 0)
@ -1067,14 +1064,12 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("quit")
}
for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
if ch.Members.Has(msg.Prefix.Name) {
ch.Members.Delete(msg.Prefix.Name)
ch.Members.Del(msg.Prefix.Name)
uc.appendLog(ch.Name, msg)
}
}
})
if msg.Prefix.Name != uc.nick {
uc.forEachDownstream(func(dc *downstreamConn) {
@ -1147,7 +1142,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.appendLog(ch.Name, msg)
c := uc.network.channels.Value(name)
c := uc.network.channels.Get(name)
if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) {
params := make([]string, len(msg.Params))
@ -1211,7 +1206,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err
}
c := uc.network.channels.Value(channel)
c := uc.network.channels.Get(channel)
if firstMode && (c == nil || !c.Detached) {
modeStr, modeParams := ch.modes.Format()
@ -1240,7 +1235,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
firstCreationTime := ch.creationTime == ""
ch.creationTime = creationTime
c := uc.network.channels.Value(channel)
c := uc.network.channels.Get(channel)
if firstCreationTime && (c == nil || !c.Detached) {
uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{
@ -1269,7 +1264,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
ch.TopicTime = time.Unix(sec, 0)
c := uc.network.channels.Value(channel)
c := uc.network.channels.Get(channel)
if firstTopicWhoTime && (c == nil || !c.Detached) {
uc.forEachDownstream(func(dc *downstreamConn) {
topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
@ -1322,7 +1317,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err
}
ch := uc.channels.Value(name)
ch := uc.channels.Get(name)
if ch == nil {
// NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
@ -1351,7 +1346,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
for _, s := range splitSpace(members) {
memberships, nick := uc.parseMembershipPrefix(s)
ch.Members.SetValue(nick, memberships)
ch.Members.Set(nick, &memberships)
}
case irc.RPL_ENDOFNAMES:
var name string
@ -1359,7 +1354,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return err
}
ch := uc.channels.Value(name)
ch := uc.channels.Get(name)
if ch == nil {
// NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
@ -1379,7 +1374,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
ch.complete = true
c := uc.network.channels.Value(name)
c := uc.network.channels.Get(name)
if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(ctx, dc, ch)
@ -1542,7 +1537,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
online := msg.Command == irc.RPL_MONONLINE
for _, target := range targets {
prefix := irc.ParsePrefix(target)
uc.monitored.SetValue(prefix.Name, online)
uc.monitored.Set(prefix.Name, online)
}
// Check if the nick we want is now free
@ -2112,7 +2107,7 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64
}
// Don't forward messages if it's a detached channel
ch := uc.network.channels.Value(target)
ch := uc.network.channels.Get(target)
detached := ch != nil && ch.Detached
uc.forEachDownstream(func(dc *downstreamConn) {
@ -2148,11 +2143,11 @@ func (uc *upstreamConn) updateAway() {
}
func (uc *upstreamConn) updateChannelAutoDetach(name string) {
uch := uc.channels.Value(name)
uch := uc.channels.Get(name)
if uch == nil {
return
}
ch := uc.network.channels.Value(name)
ch := uc.network.channels.Get(name)
if ch == nil || ch.Detached {
return
}
@ -2170,7 +2165,7 @@ func (uc *upstreamConn) updateMonitor() {
var addList []string
seen := make(map[string]struct{})
uc.forEachDownstream(func(dc *downstreamConn) {
for _, entry := range dc.monitored.innerMap {
for _, entry := range dc.monitored.m {
targetCM := uc.network.casemap(entry.originalKey)
if targetCM == serviceNickCM {
continue
@ -2195,13 +2190,13 @@ func (uc *upstreamConn) updateMonitor() {
removeAll := true
var removeList []string
for targetCM, entry := range uc.monitored.innerMap {
if _, ok := seen[targetCM]; ok {
uc.monitored.ForEach(func(nick string, online bool) {
if _, ok := seen[uc.network.casemap(nick)]; ok {
removeAll = false
} else {
removeList = append(removeList, entry.originalKey)
removeList = append(removeList, nick)
}
}
})
// TODO: better handle the case where len(uc.monitored) + len(addList)
// exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
@ -2221,6 +2216,6 @@ func (uc *upstreamConn) updateMonitor() {
}
for _, target := range removeList {
uc.monitored.Delete(target)
uc.monitored.Del(target)
}
}

48
user.go
View file

@ -85,11 +85,11 @@ func newDeliveredStore() deliveredStore {
}
func (ds deliveredStore) HasTarget(target string) bool {
return ds.m.Value(target) != nil
return ds.m.Get(target) != nil
}
func (ds deliveredStore) LoadID(target, clientName string) string {
clients := ds.m.Value(target)
clients := ds.m.Get(target)
if clients == nil {
return ""
}
@ -97,28 +97,27 @@ func (ds deliveredStore) LoadID(target, clientName string) string {
}
func (ds deliveredStore) StoreID(target, clientName, msgID string) {
clients := ds.m.Value(target)
clients := ds.m.Get(target)
if clients == nil {
clients = make(deliveredClientMap)
ds.m.SetValue(target, clients)
ds.m.Set(target, clients)
}
clients[clientName] = msgID
}
func (ds deliveredStore) ForEachTarget(f func(target string)) {
for _, entry := range ds.m.innerMap {
f(entry.originalKey)
}
ds.m.ForEach(func(name string, _ deliveredClientMap) {
f(name)
})
}
func (ds deliveredStore) ForEachClient(f func(clientName string)) {
clients := make(map[string]struct{})
for _, entry := range ds.m.innerMap {
delivered := entry.value.(deliveredClientMap)
ds.m.ForEach(func(name string, delivered deliveredClientMap) {
for clientName := range delivered {
clients[clientName] = struct{}{}
}
}
})
for clientName := range clients {
f(clientName)
@ -144,7 +143,7 @@ func newNetwork(user *user, record *database.Network, channels []database.Channe
m := channelCasemapMap{newCasemapMap()}
for _, ch := range channels {
ch := ch
m.SetValue(ch.Name, &ch)
m.Set(ch.Name, &ch)
}
return &network{
@ -300,7 +299,7 @@ func (net *network) detach(ch *database.Channel) {
}
if net.conn != nil {
uch := net.conn.channels.Value(ch.Name)
uch := net.conn.channels.Get(ch.Name)
if uch != nil {
uch.updateAutoDetach(0)
}
@ -328,7 +327,7 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
var uch *upstreamChannel
if net.conn != nil {
uch = net.conn.channels.Value(ch.Name)
uch = net.conn.channels.Get(ch.Name)
net.conn.updateChannelAutoDetach(ch.Name)
}
@ -351,12 +350,12 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) {
}
func (net *network) deleteChannel(ctx context.Context, name string) error {
ch := net.channels.Value(name)
ch := net.channels.Get(name)
if ch == nil {
return fmt.Errorf("unknown channel %q", name)
}
if net.conn != nil {
uch := net.conn.channels.Value(ch.Name)
uch := net.conn.channels.Get(ch.Name)
if uch != nil {
uch.updateAutoDetach(0)
}
@ -365,7 +364,7 @@ func (net *network) deleteChannel(ctx context.Context, name string) error {
if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
return err
}
net.channels.Delete(name)
net.channels.Del(name)
return nil
}
@ -375,10 +374,9 @@ func (net *network) updateCasemapping(newCasemap casemapping) {
net.delivered.m.SetCasemapping(newCasemap)
if uc := net.conn; uc != nil {
uc.channels.SetCasemapping(newCasemap)
for _, entry := range uc.channels.innerMap {
uch := entry.value.(*upstreamChannel)
uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
uch.Members.SetCasemapping(newCasemap)
}
})
uc.monitored.SetCasemapping(newCasemap)
}
net.forEachDownstream(func(dc *downstreamConn) {
@ -623,7 +621,7 @@ func (u *user) run() {
}
case eventChannelDetach:
uc, name := e.uc, e.name
c := uc.network.channels.Value(name)
c := uc.network.channels.Get(name)
if c == nil || c.Detached {
continue
}
@ -746,10 +744,9 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.abortPendingCommands()
for _, entry := range uc.channels.innerMap {
uch := entry.value.(*upstreamChannel)
uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
uch.updateAutoDetach(0)
}
})
uc.forEachDownstream(func(dc *downstreamConn) {
dc.updateSupportedCaps()
@ -924,10 +921,9 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne
// Most network changes require us to re-connect to the upstream server
channels := make([]database.Channel, 0, network.channels.Len())
for _, entry := range network.channels.innerMap {
ch := entry.value.(*database.Channel)
network.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, *ch)
}
})
updatedNetwork := newNetwork(u, record, channels)