Use generics for casemapMap

This commit is contained in:
Simon Ser 2023-03-01 13:15:38 +01:00
parent 3da6c23ad4
commit 07cd1f2f5d
5 changed files with 56 additions and 178 deletions

View file

@ -348,7 +348,7 @@ type downstreamConn struct {
lastBatchRef uint64 lastBatchRef uint64
monitored casemapMap monitored casemapMap[struct{}]
} }
func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
@ -362,7 +362,7 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
nickCM: "*", nickCM: "*",
username: "~u", username: "~u",
caps: xirc.NewCapRegistry(), caps: xirc.NewCapRegistry(),
monitored: newCasemapMap(), monitored: newCasemapMap[struct{}](),
registration: new(downstreamRegistration), registration: new(downstreamRegistration),
} }
dc.monitored.SetCasemapping(casemapASCII) dc.monitored.SetCasemapping(casemapASCII)
@ -1553,7 +1553,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
} }
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
uc.channels.ForEach(func(ch *upstreamChannel) { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
if !ch.complete { if !ch.complete {
return return
} }
@ -1928,7 +1928,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Name: name, Name: name,
Key: key, Key: key,
} }
uc.network.channels.Set(ch) uc.network.channels.Set(ch.Name, ch)
} }
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", name, err) dc.logger.Printf("failed to create or update channel %q: %v", name, err)
@ -1960,7 +1960,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Name: name, Name: name,
Detached: true, Detached: true,
} }
uc.network.channels.Set(ch) uc.network.channels.Set(ch.Name, ch)
} }
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", name, err) dc.logger.Printf("failed to create or update channel %q: %v", name, err)
@ -2621,7 +2621,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue continue
} }
dc.monitored.set(target, nil) dc.monitored.Set(target, struct{}{})
if uc.network.casemap(target) == serviceNickCM { if uc.network.casemap(target) == serviceNickCM {
// BouncerServ is never tired // BouncerServ is never tired
@ -2651,7 +2651,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
uc.updateMonitor() uc.updateMonitor()
case "C": // clear case "C": // clear
dc.monitored = newCasemapMap() dc.monitored = newCasemapMap[struct{}]()
dc.monitored.SetCasemapping(uc.network.casemap) dc.monitored.SetCasemapping(uc.network.casemap)
uc.updateMonitor() uc.updateMonitor()
case "L": // list case "L": // list

168
irc.go
View file

@ -9,7 +9,6 @@ import (
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
"git.sr.ht/~emersion/soju/database"
"git.sr.ht/~emersion/soju/xirc" "git.sr.ht/~emersion/soju/xirc"
) )
@ -287,45 +286,46 @@ func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
return casemap, true return casemap, true
} }
type casemapMap struct { type casemapMap[V interface{}] struct {
m map[string]casemapEntry m map[string]casemapEntry[V]
casemap casemapping casemap casemapping
} }
type casemapEntry struct { type casemapEntry[V interface{}] struct {
originalKey string originalKey string
value interface{} value V
} }
func newCasemapMap() casemapMap { func newCasemapMap[V interface{}]() casemapMap[V] {
return casemapMap{ return casemapMap[V]{
m: make(map[string]casemapEntry), m: make(map[string]casemapEntry[V]),
casemap: casemapNone, casemap: casemapNone,
} }
} }
func (cm *casemapMap) Has(name string) bool { func (cm *casemapMap[V]) Has(name string) bool {
_, ok := cm.m[cm.casemap(name)] _, ok := cm.m[cm.casemap(name)]
return ok return ok
} }
func (cm *casemapMap) Len() int { func (cm *casemapMap[V]) Len() int {
return len(cm.m) return len(cm.m)
} }
func (cm *casemapMap) get(name string) interface{} { func (cm *casemapMap[V]) Get(name string) V {
entry, ok := cm.m[cm.casemap(name)] entry, ok := cm.m[cm.casemap(name)]
if !ok { if !ok {
return nil var v V
return v
} }
return entry.value return entry.value
} }
func (cm *casemapMap) set(name string, value interface{}) { func (cm *casemapMap[V]) Set(name string, value V) {
nameCM := cm.casemap(name) nameCM := cm.casemap(name)
entry, ok := cm.m[nameCM] entry, ok := cm.m[nameCM]
if !ok { if !ok {
cm.m[nameCM] = casemapEntry{ cm.m[nameCM] = casemapEntry[V]{
originalKey: name, originalKey: name,
value: value, value: value,
} }
@ -335,147 +335,25 @@ func (cm *casemapMap) set(name string, value interface{}) {
cm.m[nameCM] = entry cm.m[nameCM] = entry
} }
func (cm *casemapMap) Del(name string) { func (cm *casemapMap[V]) Del(name string) {
delete(cm.m, cm.casemap(name)) delete(cm.m, cm.casemap(name))
} }
func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { func (cm *casemapMap[V]) ForEach(f func(string, V)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value)
}
}
func (cm *casemapMap[V]) SetCasemapping(newCasemap casemapping) {
cm.casemap = newCasemap cm.casemap = newCasemap
m := make(map[string]casemapEntry, len(cm.m)) m := make(map[string]casemapEntry[V], len(cm.m))
for _, entry := range cm.m { for _, entry := range cm.m {
m[cm.casemap(entry.originalKey)] = entry m[cm.casemap(entry.originalKey)] = entry
} }
cm.m = m cm.m = m
} }
type upstreamChannelCasemapMap struct{ casemapMap }
func (cm *upstreamChannelCasemapMap) Get(name string) *upstreamChannel {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*upstreamChannel)
}
}
func (cm *upstreamChannelCasemapMap) Set(uch *upstreamChannel) {
cm.set(uch.Name, uch)
}
func (cm *upstreamChannelCasemapMap) ForEach(f func(*upstreamChannel)) {
for _, entry := range cm.m {
f(entry.value.(*upstreamChannel))
}
}
type upstreamUserCasemapMap struct{ casemapMap }
func (cm *upstreamUserCasemapMap) Get(name string) *upstreamUser {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*upstreamUser)
}
}
func (cm *upstreamUserCasemapMap) Set(u *upstreamUser) {
cm.set(u.Nickname, u)
}
type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Get(name string) *database.Channel {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*database.Channel)
}
}
func (cm *channelCasemapMap) Set(ch *database.Channel) {
cm.set(ch.Name, ch)
}
func (cm *channelCasemapMap) ForEach(f func(*database.Channel)) {
for _, entry := range cm.m {
f(entry.value.(*database.Channel))
}
}
type membershipsCasemapMap struct{ casemapMap }
func (cm *membershipsCasemapMap) Get(name string) *xirc.MembershipSet {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(*xirc.MembershipSet)
}
}
func (cm *membershipsCasemapMap) Set(name string, ms *xirc.MembershipSet) {
cm.set(name, ms)
}
func (cm *membershipsCasemapMap) ForEach(f func(string, *xirc.MembershipSet)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(*xirc.MembershipSet))
}
}
type deliveredCasemapMap struct{ casemapMap }
func (cm *deliveredCasemapMap) Get(name string) deliveredClientMap {
if v := cm.get(name); v == nil {
return nil
} else {
return v.(deliveredClientMap)
}
}
func (cm *deliveredCasemapMap) Set(name string, m deliveredClientMap) {
cm.set(name, m)
}
func (cm *deliveredCasemapMap) ForEach(f func(string, deliveredClientMap)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(deliveredClientMap))
}
}
type monitorCasemapMap struct{ casemapMap }
func (cm *monitorCasemapMap) Get(name string) (online bool) {
if v := cm.get(name); v == nil {
return false
} else {
return v.(bool)
}
}
func (cm *monitorCasemapMap) Set(name string, online bool) {
cm.set(name, online)
}
func (cm *monitorCasemapMap) ForEach(f func(name string, online bool)) {
for _, entry := range cm.m {
f(entry.originalKey, entry.value.(bool))
}
}
type pushTargetCasemapMap struct{ casemapMap }
func (cm *pushTargetCasemapMap) Get(name string) (last time.Time) {
if v := cm.get(name); v == nil {
return time.Time{}
} else {
return v.(time.Time)
}
}
func (cm *pushTargetCasemapMap) Set(name string, last time.Time) {
cm.set(name, last)
}
func isWordBoundary(r rune) bool { func isWordBoundary(r rune) bool {
switch r { switch r {
case '-', '_', '|': // inspired from weechat.look.highlight_regex case '-', '_', '|': // inspired from weechat.look.highlight_regex

View file

@ -1224,7 +1224,7 @@ func handleServiceChannelStatus(ctx *serviceContext, params []string) error {
sendNetwork := func(net *network) { sendNetwork := func(net *network) {
var channels []*database.Channel var channels []*database.Channel
net.channels.ForEach(func(ch *database.Channel) { net.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, ch) channels = append(channels, ch)
}) })

View file

@ -89,7 +89,7 @@ type upstreamChannel struct {
Status xirc.ChannelStatus Status xirc.ChannelStatus
modes channelModes modes channelModes
creationTime string creationTime string
Members membershipsCasemapMap Members casemapMap[*xirc.MembershipSet]
complete bool complete bool
detachTimer *time.Timer detachTimer *time.Timer
} }
@ -208,14 +208,14 @@ type upstreamConn struct {
realname string realname string
hostname string hostname string
modes userModes modes userModes
channels upstreamChannelCasemapMap channels casemapMap[*upstreamChannel]
users upstreamUserCasemapMap users casemapMap[*upstreamUser]
caps xirc.CapRegistry caps xirc.CapRegistry
batches map[string]upstreamBatch batches map[string]upstreamBatch
away bool away bool
account string account string
nextLabelID uint64 nextLabelID uint64
monitored monitorCasemapMap monitored casemapMap[bool]
saslClient sasl.Client saslClient sasl.Client
saslStarted bool saslStarted bool
@ -366,8 +366,8 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
network: network, network: network,
user: network.user, user: network.user,
channels: upstreamChannelCasemapMap{newCasemapMap()}, channels: newCasemapMap[*upstreamChannel](),
users: upstreamUserCasemapMap{newCasemapMap()}, users: newCasemapMap[*upstreamUser](),
caps: xirc.NewCapRegistry(), caps: xirc.NewCapRegistry(),
batches: make(map[string]upstreamBatch), batches: make(map[string]upstreamBatch),
serverPrefix: &irc.Prefix{Name: "*"}, serverPrefix: &irc.Prefix{Name: "*"},
@ -376,7 +376,7 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
availableMemberships: stdMemberships, availableMemberships: stdMemberships,
isupport: make(map[string]*string), isupport: make(map[string]*string),
pendingCmds: make(map[string][]pendingUpstreamCommand), pendingCmds: make(map[string][]pendingUpstreamCommand),
monitored: monitorCasemapMap{newCasemapMap()}, monitored: newCasemapMap[bool](),
hasDesiredNick: true, hasDesiredNick: true,
} }
return uc, nil return uc, nil
@ -898,7 +898,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.network.channels.Len() > 0 { if uc.network.channels.Len() > 0 {
var channels, keys []string var channels, keys []string
uc.network.channels.ForEach(func(ch *database.Channel) { uc.network.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, ch.Name) channels = append(channels, ch.Name)
keys = append(keys, ch.Key) keys = append(keys, ch.Key)
}) })
@ -1067,7 +1067,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
} }
uc.channels.ForEach(func(ch *upstreamChannel) { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
memberships := ch.Members.Get(msg.Prefix.Name) memberships := ch.Members.Get(msg.Prefix.Name)
if memberships != nil { if memberships != nil {
ch.Members.Del(msg.Prefix.Name) ch.Members.Del(msg.Prefix.Name)
@ -1173,9 +1173,9 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
for _, ch := range strings.Split(channels, ",") { for _, ch := range strings.Split(channels, ",") {
if uc.isOurNick(msg.Prefix.Name) { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("joined channel %q", ch) uc.logger.Printf("joined channel %q", ch)
members := membershipsCasemapMap{newCasemapMap()} members := newCasemapMap[*xirc.MembershipSet]()
members.casemap = uc.network.casemap members.casemap = uc.network.casemap
uc.channels.Set(&upstreamChannel{ uc.channels.Set(ch, &upstreamChannel{
Name: ch, Name: ch,
conn: uc, conn: uc,
Members: members, Members: members,
@ -1264,7 +1264,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.logger.Printf("quit") uc.logger.Printf("quit")
} }
uc.channels.ForEach(func(ch *upstreamChannel) { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
if ch.Members.Has(msg.Prefix.Name) { if ch.Members.Has(msg.Prefix.Name) {
ch.Members.Del(msg.Prefix.Name) ch.Members.Del(msg.Prefix.Name)
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
@ -2440,12 +2440,12 @@ func (uc *upstreamConn) cacheUserInfo(nick string, info *upstreamUser) {
} else { } else {
info.Nickname = nick info.Nickname = nick
} }
uc.users.Set(info) uc.users.Set(info.Nickname, info)
} else { } else {
uu.updateFrom(info) uu.updateFrom(info)
if info.Nickname != "" && nick != info.Nickname { if info.Nickname != "" && nick != info.Nickname {
uc.users.Del(nick) uc.users.Del(nick)
uc.users.Set(uu) uc.users.Set(uu.Nickname, uu)
} }
} }
} }
@ -2459,7 +2459,7 @@ func (uc *upstreamConn) shouldCacheUserInfo(nick string) bool {
return true return true
} }
found := false found := false
uc.channels.ForEach(func(ch *upstreamChannel) { uc.channels.ForEach(func(_ string, ch *upstreamChannel) {
found = found || ch.Members.Has(nick) found = found || ch.Members.Has(nick)
}) })
return found return found

20
user.go
View file

@ -93,11 +93,11 @@ type eventUserRun struct {
type deliveredClientMap map[string]string // client name -> msg ID type deliveredClientMap map[string]string // client name -> msg ID
type deliveredStore struct { type deliveredStore struct {
m deliveredCasemapMap m casemapMap[deliveredClientMap]
} }
func newDeliveredStore() deliveredStore { func newDeliveredStore() deliveredStore {
return deliveredStore{deliveredCasemapMap{newCasemapMap()}} return deliveredStore{newCasemapMap[deliveredClientMap]()}
} }
func (ds deliveredStore) HasTarget(target string) bool { func (ds deliveredStore) HasTarget(target string) bool {
@ -147,9 +147,9 @@ type network struct {
stopped chan struct{} stopped chan struct{}
conn *upstreamConn conn *upstreamConn
channels channelCasemapMap channels casemapMap[*database.Channel]
delivered deliveredStore delivered deliveredStore
pushTargets pushTargetCasemapMap pushTargets casemapMap[time.Time]
lastError error lastError error
casemap casemapping casemap casemapping
} }
@ -157,10 +157,10 @@ type network struct {
func newNetwork(user *user, record *database.Network, channels []database.Channel) *network { func newNetwork(user *user, record *database.Network, channels []database.Channel) *network {
logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
m := channelCasemapMap{newCasemapMap()} m := newCasemapMap[*database.Channel]()
for _, ch := range channels { for _, ch := range channels {
ch := ch ch := ch
m.Set(&ch) m.Set(ch.Name, &ch)
} }
return &network{ return &network{
@ -170,7 +170,7 @@ func newNetwork(user *user, record *database.Network, channels []database.Channe
stopped: make(chan struct{}), stopped: make(chan struct{}),
channels: m, channels: m,
delivered: newDeliveredStore(), delivered: newDeliveredStore(),
pushTargets: pushTargetCasemapMap{newCasemapMap()}, pushTargets: newCasemapMap[time.Time](),
casemap: casemapRFC1459, casemap: casemapRFC1459,
} }
} }
@ -394,7 +394,7 @@ func (net *network) updateCasemapping(newCasemap casemapping) {
net.pushTargets.SetCasemapping(newCasemap) net.pushTargets.SetCasemapping(newCasemap)
if uc := net.conn; uc != nil { if uc := net.conn; uc != nil {
uc.channels.SetCasemapping(newCasemap) uc.channels.SetCasemapping(newCasemap)
uc.channels.ForEach(func(uch *upstreamChannel) { uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
uch.Members.SetCasemapping(newCasemap) uch.Members.SetCasemapping(newCasemap)
}) })
uc.users.SetCasemapping(newCasemap) uc.users.SetCasemapping(newCasemap)
@ -856,7 +856,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.stopRegainNickTimer() uc.stopRegainNickTimer()
uc.abortPendingCommands() uc.abortPendingCommands()
uc.channels.ForEach(func(uch *upstreamChannel) { uc.channels.ForEach(func(_ string, uch *upstreamChannel) {
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
}) })
@ -1036,7 +1036,7 @@ func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*ne
// Most network changes require us to re-connect to the upstream server // Most network changes require us to re-connect to the upstream server
channels := make([]database.Channel, 0, network.channels.Len()) channels := make([]database.Channel, 0, network.channels.Len())
network.channels.ForEach(func(ch *database.Channel) { network.channels.ForEach(func(_ string, ch *database.Channel) {
channels = append(channels, *ch) channels = append(channels, *ch)
}) })