From 07cd1f2f5dda092f2721a61d83c1e4836daa5178 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 1 Mar 2023 13:15:38 +0100 Subject: [PATCH] Use generics for casemapMap --- downstream.go | 14 ++--- irc.go | 168 +++++++------------------------------------------- service.go | 2 +- upstream.go | 30 ++++----- user.go | 20 +++--- 5 files changed, 56 insertions(+), 178 deletions(-) diff --git a/downstream.go b/downstream.go index c3e9c6f..c463344 100644 --- a/downstream.go +++ b/downstream.go @@ -348,7 +348,7 @@ type downstreamConn struct { lastBatchRef uint64 - monitored casemapMap + monitored casemapMap[struct{}] } func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { @@ -362,7 +362,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { nickCM: "*", username: "~u", caps: xirc.NewCapRegistry(), - monitored: newCasemapMap(), + monitored: newCasemapMap[struct{}](), registration: new(downstreamRegistration), } dc.monitored.SetCasemapping(casemapASCII) @@ -1553,7 +1553,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error { } dc.forEachUpstream(func(uc *upstreamConn) { - uc.channels.ForEach(func(ch *upstreamChannel) { + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { if !ch.complete { return } @@ -1928,7 +1928,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Name: name, Key: key, } - uc.network.channels.Set(ch) + uc.network.channels.Set(ch.Name, ch) } if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", name, err) @@ -1960,7 +1960,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Name: name, Detached: true, } - uc.network.channels.Set(ch) + uc.network.channels.Set(ch.Name, ch) } if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", name, err) @@ -2621,7 +2621,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. continue } - dc.monitored.set(target, nil) + dc.monitored.Set(target, struct{}{}) if uc.network.casemap(target) == serviceNickCM { // BouncerServ is never tired @@ -2651,7 +2651,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc.updateMonitor() case "C": // clear - dc.monitored = newCasemapMap() + dc.monitored = newCasemapMap[struct{}]() dc.monitored.SetCasemapping(uc.network.casemap) uc.updateMonitor() case "L": // list diff --git a/irc.go b/irc.go index 391a809..25a7dce 100644 --- a/irc.go +++ b/irc.go @@ -9,7 +9,6 @@ import ( "gopkg.in/irc.v4" - "git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/xirc" ) @@ -287,45 +286,46 @@ func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) { return casemap, true } -type casemapMap struct { - m map[string]casemapEntry +type casemapMap[V interface{}] struct { + m map[string]casemapEntry[V] casemap casemapping } -type casemapEntry struct { +type casemapEntry[V interface{}] struct { originalKey string - value interface{} + value V } -func newCasemapMap() casemapMap { - return casemapMap{ - m: make(map[string]casemapEntry), +func newCasemapMap[V interface{}]() casemapMap[V] { + return casemapMap[V]{ + m: make(map[string]casemapEntry[V]), casemap: casemapNone, } } -func (cm *casemapMap) Has(name string) bool { +func (cm *casemapMap[V]) Has(name string) bool { _, ok := cm.m[cm.casemap(name)] return ok } -func (cm *casemapMap) Len() int { +func (cm *casemapMap[V]) Len() int { return len(cm.m) } -func (cm *casemapMap) get(name string) interface{} { +func (cm *casemapMap[V]) Get(name string) V { entry, ok := cm.m[cm.casemap(name)] if !ok { - return nil + var v V + return v } return entry.value } -func (cm *casemapMap) set(name string, value interface{}) { +func (cm *casemapMap[V]) Set(name string, value V) { nameCM := cm.casemap(name) entry, ok := cm.m[nameCM] if !ok { - cm.m[nameCM] = casemapEntry{ + cm.m[nameCM] = casemapEntry[V]{ originalKey: name, value: value, } @@ -335,147 +335,25 @@ func (cm *casemapMap) set(name string, value interface{}) { cm.m[nameCM] = entry } -func (cm *casemapMap) Del(name string) { +func (cm *casemapMap[V]) Del(name string) { delete(cm.m, cm.casemap(name)) } -func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { +func (cm *casemapMap[V]) ForEach(f func(string, V)) { + for _, entry := range cm.m { + f(entry.originalKey, entry.value) + } +} + +func (cm *casemapMap[V]) SetCasemapping(newCasemap casemapping) { cm.casemap = newCasemap - m := make(map[string]casemapEntry, len(cm.m)) + m := make(map[string]casemapEntry[V], len(cm.m)) for _, entry := range cm.m { m[cm.casemap(entry.originalKey)] = entry } cm.m = m } -type upstreamChannelCasemapMap struct{ casemapMap } - -func (cm *upstreamChannelCasemapMap) Get(name string) *upstreamChannel { - if v := cm.get(name); v == nil { - return nil - } else { - return v.(*upstreamChannel) - } -} - -func (cm *upstreamChannelCasemapMap) Set(uch *upstreamChannel) { - cm.set(uch.Name, uch) -} - -func (cm *upstreamChannelCasemapMap) ForEach(f func(*upstreamChannel)) { - for _, entry := range cm.m { - f(entry.value.(*upstreamChannel)) - } -} - -type upstreamUserCasemapMap struct{ casemapMap } - -func (cm *upstreamUserCasemapMap) Get(name string) *upstreamUser { - if v := cm.get(name); v == nil { - return nil - } else { - return v.(*upstreamUser) - } -} - -func (cm *upstreamUserCasemapMap) Set(u *upstreamUser) { - cm.set(u.Nickname, u) -} - -type channelCasemapMap struct{ casemapMap } - -func (cm *channelCasemapMap) Get(name string) *database.Channel { - if v := cm.get(name); v == nil { - return nil - } else { - return v.(*database.Channel) - } -} - -func (cm *channelCasemapMap) Set(ch *database.Channel) { - cm.set(ch.Name, ch) -} - -func (cm *channelCasemapMap) ForEach(f func(*database.Channel)) { - for _, entry := range cm.m { - f(entry.value.(*database.Channel)) - } -} - -type membershipsCasemapMap struct{ casemapMap } - -func (cm *membershipsCasemapMap) Get(name string) *xirc.MembershipSet { - if v := cm.get(name); v == nil { - return nil - } else { - return v.(*xirc.MembershipSet) - } -} - -func (cm *membershipsCasemapMap) Set(name string, ms *xirc.MembershipSet) { - cm.set(name, ms) -} - -func (cm *membershipsCasemapMap) ForEach(f func(string, *xirc.MembershipSet)) { - for _, entry := range cm.m { - f(entry.originalKey, entry.value.(*xirc.MembershipSet)) - } -} - -type deliveredCasemapMap struct{ casemapMap } - -func (cm *deliveredCasemapMap) Get(name string) deliveredClientMap { - if v := cm.get(name); v == nil { - return nil - } else { - return v.(deliveredClientMap) - } -} - -func (cm *deliveredCasemapMap) Set(name string, m deliveredClientMap) { - cm.set(name, m) -} - -func (cm *deliveredCasemapMap) ForEach(f func(string, deliveredClientMap)) { - for _, entry := range cm.m { - f(entry.originalKey, entry.value.(deliveredClientMap)) - } -} - -type monitorCasemapMap struct{ casemapMap } - -func (cm *monitorCasemapMap) Get(name string) (online bool) { - if v := cm.get(name); v == nil { - return false - } else { - return v.(bool) - } -} - -func (cm *monitorCasemapMap) Set(name string, online bool) { - cm.set(name, online) -} - -func (cm *monitorCasemapMap) ForEach(f func(name string, online bool)) { - for _, entry := range cm.m { - f(entry.originalKey, entry.value.(bool)) - } -} - -type pushTargetCasemapMap struct{ casemapMap } - -func (cm *pushTargetCasemapMap) Get(name string) (last time.Time) { - if v := cm.get(name); v == nil { - return time.Time{} - } else { - return v.(time.Time) - } -} - -func (cm *pushTargetCasemapMap) Set(name string, last time.Time) { - cm.set(name, last) -} - func isWordBoundary(r rune) bool { switch r { case '-', '_', '|': // inspired from weechat.look.highlight_regex diff --git a/service.go b/service.go index 92cfb73..66f74fe 100644 --- a/service.go +++ b/service.go @@ -1224,7 +1224,7 @@ func handleServiceChannelStatus(ctx *serviceContext, params []string) error { sendNetwork := func(net *network) { var channels []*database.Channel - net.channels.ForEach(func(ch *database.Channel) { + net.channels.ForEach(func(_ string, ch *database.Channel) { channels = append(channels, ch) }) diff --git a/upstream.go b/upstream.go index 0586d80..8b3d5b3 100644 --- a/upstream.go +++ b/upstream.go @@ -89,7 +89,7 @@ type upstreamChannel struct { Status xirc.ChannelStatus modes channelModes creationTime string - Members membershipsCasemapMap + Members casemapMap[*xirc.MembershipSet] complete bool detachTimer *time.Timer } @@ -208,14 +208,14 @@ type upstreamConn struct { realname string hostname string modes userModes - channels upstreamChannelCasemapMap - users upstreamUserCasemapMap + channels casemapMap[*upstreamChannel] + users casemapMap[*upstreamUser] caps xirc.CapRegistry batches map[string]upstreamBatch away bool account string nextLabelID uint64 - monitored monitorCasemapMap + monitored casemapMap[bool] saslClient sasl.Client saslStarted bool @@ -366,8 +366,8 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), network: network, user: network.user, - channels: upstreamChannelCasemapMap{newCasemapMap()}, - users: upstreamUserCasemapMap{newCasemapMap()}, + channels: newCasemapMap[*upstreamChannel](), + users: newCasemapMap[*upstreamUser](), caps: xirc.NewCapRegistry(), batches: make(map[string]upstreamBatch), serverPrefix: &irc.Prefix{Name: "*"}, @@ -376,7 +376,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er availableMemberships: stdMemberships, isupport: make(map[string]*string), pendingCmds: make(map[string][]pendingUpstreamCommand), - monitored: monitorCasemapMap{newCasemapMap()}, + monitored: newCasemapMap[bool](), hasDesiredNick: true, } return uc, nil @@ -898,7 +898,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.network.channels.Len() > 0 { var channels, keys []string - uc.network.channels.ForEach(func(ch *database.Channel) { + uc.network.channels.ForEach(func(_ string, ch *database.Channel) { channels = append(channels, ch.Name) keys = append(keys, ch.Key) }) @@ -1067,7 +1067,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } } - uc.channels.ForEach(func(ch *upstreamChannel) { + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { memberships := ch.Members.Get(msg.Prefix.Name) if memberships != nil { ch.Members.Del(msg.Prefix.Name) @@ -1173,9 +1173,9 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err for _, ch := range strings.Split(channels, ",") { if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("joined channel %q", ch) - members := membershipsCasemapMap{newCasemapMap()} + members := newCasemapMap[*xirc.MembershipSet]() members.casemap = uc.network.casemap - uc.channels.Set(&upstreamChannel{ + uc.channels.Set(ch, &upstreamChannel{ Name: ch, conn: uc, Members: members, @@ -1264,7 +1264,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.logger.Printf("quit") } - uc.channels.ForEach(func(ch *upstreamChannel) { + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { if ch.Members.Has(msg.Prefix.Name) { ch.Members.Del(msg.Prefix.Name) uc.appendLog(ch.Name, msg) @@ -2440,12 +2440,12 @@ func (uc *upstreamConn) cacheUserInfo(nick string, info *upstreamUser) { } else { info.Nickname = nick } - uc.users.Set(info) + uc.users.Set(info.Nickname, info) } else { uu.updateFrom(info) if info.Nickname != "" && nick != info.Nickname { uc.users.Del(nick) - uc.users.Set(uu) + uc.users.Set(uu.Nickname, uu) } } } @@ -2459,7 +2459,7 @@ func (uc *upstreamConn) shouldCacheUserInfo(nick string) bool { return true } found := false - uc.channels.ForEach(func(ch *upstreamChannel) { + uc.channels.ForEach(func(_ string, ch *upstreamChannel) { found = found || ch.Members.Has(nick) }) return found diff --git a/user.go b/user.go index 0e434aa..d038cb6 100644 --- a/user.go +++ b/user.go @@ -93,11 +93,11 @@ type eventUserRun struct { type deliveredClientMap map[string]string // client name -> msg ID type deliveredStore struct { - m deliveredCasemapMap + m casemapMap[deliveredClientMap] } func newDeliveredStore() deliveredStore { - return deliveredStore{deliveredCasemapMap{newCasemapMap()}} + return deliveredStore{newCasemapMap[deliveredClientMap]()} } func (ds deliveredStore) HasTarget(target string) bool { @@ -147,9 +147,9 @@ type network struct { stopped chan struct{} conn *upstreamConn - channels channelCasemapMap + channels casemapMap[*database.Channel] delivered deliveredStore - pushTargets pushTargetCasemapMap + pushTargets casemapMap[time.Time] lastError error casemap casemapping } @@ -157,10 +157,10 @@ type network struct { func newNetwork(user *user, record *database.Network, channels []database.Channel) *network { logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} - m := channelCasemapMap{newCasemapMap()} + m := newCasemapMap[*database.Channel]() for _, ch := range channels { ch := ch - m.Set(&ch) + m.Set(ch.Name, &ch) } return &network{ @@ -170,7 +170,7 @@ func newNetwork(user *user, record *database.Network, channels []database.Channe stopped: make(chan struct{}), channels: m, delivered: newDeliveredStore(), - pushTargets: pushTargetCasemapMap{newCasemapMap()}, + pushTargets: newCasemapMap[time.Time](), casemap: casemapRFC1459, } } @@ -394,7 +394,7 @@ func (net *network) updateCasemapping(newCasemap casemapping) { net.pushTargets.SetCasemapping(newCasemap) if uc := net.conn; uc != nil { uc.channels.SetCasemapping(newCasemap) - uc.channels.ForEach(func(uch *upstreamChannel) { + uc.channels.ForEach(func(_ string, uch *upstreamChannel) { uch.Members.SetCasemapping(newCasemap) }) uc.users.SetCasemapping(newCasemap) @@ -856,7 +856,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { uc.stopRegainNickTimer() uc.abortPendingCommands() - uc.channels.ForEach(func(uch *upstreamChannel) { + uc.channels.ForEach(func(_ string, uch *upstreamChannel) { uch.updateAutoDetach(0) }) @@ -1036,7 +1036,7 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne // Most network changes require us to re-connect to the upstream server channels := make([]database.Channel, 0, network.channels.Len()) - network.channels.ForEach(func(ch *database.Channel) { + network.channels.ForEach(func(_ string, ch *database.Channel) { channels = append(channels, *ch) })