Per-user dispatcher goroutine

This allows message handlers to read upstream/downstream connection
information without causing any race condition.

References: https://todo.sr.ht/~emersion/soju/1
This commit is contained in:
Simon Ser 2020-03-16 12:44:59 +01:00
parent cdab0dc825
commit 3919ee2036
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
4 changed files with 71 additions and 21 deletions

View file

@ -191,7 +191,7 @@ func (dc *downstreamConn) isClosed() bool {
} }
} }
func (dc *downstreamConn) readMessages() error { func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
dc.logger.Printf("new connection") dc.logger.Printf("new connection")
for { for {
@ -206,17 +206,7 @@ func (dc *downstreamConn) readMessages() error {
dc.logger.Printf("received: %v", msg) dc.logger.Printf("received: %v", msg)
} }
err = dc.handleMessage(msg) ch <- downstreamIncomingMessage{msg, dc}
if ircErr, ok := err.(ircError); ok {
ircErr.Message.Prefix = dc.srv.prefix()
dc.SendMessage(ircErr.Message)
} else if err != nil {
return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
}
if dc.isClosed() {
return nil
}
} }
return nil return nil
@ -484,6 +474,27 @@ func (dc *downstreamConn) register() error {
return nil return nil
} }
func (dc *downstreamConn) runUntilRegistered() error {
for !dc.registered {
msg, err := dc.irc.ReadMessage()
if err == io.EOF {
break
} else if err != nil {
return fmt.Errorf("failed to read IRC command: %v", err)
}
err = dc.handleMessage(msg)
if ircErr, ok := err.(ircError); ok {
ircErr.Message.Prefix = dc.srv.prefix()
dc.SendMessage(ircErr.Message)
} else if err != nil {
return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
}
}
return nil
}
func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
switch msg.Command { switch msg.Command {
case "USER": case "USER":

View file

@ -114,8 +114,12 @@ func (s *Server) Serve(ln net.Listener) error {
s.downstreamConns = append(s.downstreamConns, dc) s.downstreamConns = append(s.downstreamConns, dc)
s.lock.Unlock() s.lock.Unlock()
if err := dc.readMessages(); err != nil { if err := dc.runUntilRegistered(); err != nil {
dc.logger.Printf("failed to handle messages: %v", err) dc.logger.Print(err)
} else {
if err := dc.readMessages(dc.user.downstreamIncoming); err != nil {
dc.logger.Print(err)
}
} }
dc.Close() dc.Close()

View file

@ -659,7 +659,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
return nil return nil
} }
func (uc *upstreamConn) readMessages() error { func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
for { for {
msg, err := uc.irc.ReadMessage() msg, err := uc.irc.ReadMessage()
if err == io.EOF { if err == io.EOF {
@ -672,9 +672,7 @@ func (uc *upstreamConn) readMessages() error {
uc.logger.Printf("received: %v", msg) uc.logger.Printf("received: %v", msg)
} }
if err := uc.handleMessage(msg); err != nil { ch <- upstreamIncomingMessage{msg, uc}
uc.logger.Printf("failed to handle message %q: %v", msg, err)
}
} }
return nil return nil

39
user.go
View file

@ -3,8 +3,20 @@ package soju
import ( import (
"sync" "sync"
"time" "time"
"gopkg.in/irc.v3"
) )
type upstreamIncomingMessage struct {
msg *irc.Message
uc *upstreamConn
}
type downstreamIncomingMessage struct {
msg *irc.Message
dc *downstreamConn
}
type network struct { type network struct {
Network Network
user *user user *user
@ -40,7 +52,7 @@ func (net *network) run() {
net.conn = uc net.conn = uc
net.user.lock.Unlock() net.user.lock.Unlock()
if err := uc.readMessages(); err != nil { if err := uc.readMessages(net.user.upstreamIncoming); err != nil {
uc.logger.Printf("failed to handle messages: %v", err) uc.logger.Printf("failed to handle messages: %v", err)
} }
uc.Close() uc.Close()
@ -55,6 +67,9 @@ type user struct {
User User
srv *Server srv *Server
upstreamIncoming chan upstreamIncomingMessage
downstreamIncoming chan downstreamIncomingMessage
lock sync.Mutex lock sync.Mutex
networks []*network networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
@ -64,6 +79,8 @@ func newUser(srv *Server, record *User) *user {
return &user{ return &user{
User: *record, User: *record,
srv: srv, srv: srv,
upstreamIncoming: make(chan upstreamIncomingMessage, 64),
downstreamIncoming: make(chan downstreamIncomingMessage, 64),
} }
} }
@ -119,6 +136,26 @@ func (u *user) run() {
go network.run() go network.run()
} }
u.lock.Unlock() u.lock.Unlock()
for {
select {
case upstreamMsg := <-u.upstreamIncoming:
msg, uc := upstreamMsg.msg, upstreamMsg.uc
if err := uc.handleMessage(msg); err != nil {
uc.logger.Printf("failed to handle message %q: %v", msg, err)
}
case downstreamMsg := <-u.downstreamIncoming:
msg, dc := downstreamMsg.msg, downstreamMsg.dc
err := dc.handleMessage(msg)
if ircErr, ok := err.(ircError); ok {
ircErr.Message.Prefix = dc.srv.prefix()
dc.SendMessage(ircErr.Message)
} else if err != nil {
dc.logger.Printf("failed to handle message %q: %v", msg, err)
dc.Close()
}
}
}
} }
func (u *user) createNetwork(addr, nick string) (*network, error) { func (u *user) createNetwork(addr, nick string) (*network, error) {