diff --git a/downstream.go b/downstream.go index e0d4f40..228502e 100644 --- a/downstream.go +++ b/downstream.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "strings" "gopkg.in/irc.v3" ) @@ -167,8 +168,9 @@ func (c *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { } func (c *downstreamConn) register() error { - u := c.srv.getUser(c.username) + u := c.srv.getUser(strings.TrimPrefix(c.username, "~")) if u == nil { + c.logger.Printf("failed authentication: unknown username %q", c.username) c.messages <- &irc.Message{ Prefix: c.srv.prefix(), Command: irc.ERR_PASSWDMISMATCH, @@ -206,20 +208,14 @@ func (c *downstreamConn) register() error { Params: []string{c.nick, "No MOTD"}, } - u.lock.Lock() - for _, uc := range u.upstreamConns { + u.forEachUpstream(func(uc *upstreamConn) { // TODO: fix races accessing upstream connection data - if !uc.registered { - continue - } for _, ch := range uc.channels { if ch.complete { forwardChannel(c, ch) } } - } - u.lock.Unlock() - + }) return nil } diff --git a/server.go b/server.go index 3076eb8..ec30a38 100644 --- a/server.go +++ b/server.go @@ -39,6 +39,17 @@ type user struct { upstreamConns []*upstreamConn } +func (u *user) forEachUpstream(f func(uc *upstreamConn)) { + u.lock.Lock() + for _, uc := range u.upstreamConns { + if !uc.registered { + continue + } + f(uc) + } + u.lock.Unlock() +} + type Upstream struct { Addr string Nick string