diff --git a/bridge.go b/bridge.go index 09c66bb..1fa4aa1 100644 --- a/bridge.go +++ b/bridge.go @@ -9,23 +9,25 @@ func forwardChannel(dc *downstreamConn, ch *upstreamChannel) { panic("Tried to forward a partial channel") } + downstreamName := dc.marshalChannel(ch.conn, ch.Name) + dc.SendMessage(&irc.Message{ Prefix: dc.prefix(), Command: "JOIN", - Params: []string{ch.Name}, + Params: []string{downstreamName}, }) if ch.Topic != "" { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_TOPIC, - Params: []string{dc.nick, ch.Name, ch.Topic}, + Params: []string{dc.nick, downstreamName, ch.Topic}, }) } else { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_NOTOPIC, - Params: []string{dc.nick, ch.Name, "No topic is set"}, + Params: []string{dc.nick, downstreamName, "No topic is set"}, }) } @@ -33,21 +35,21 @@ func forwardChannel(dc *downstreamConn, ch *upstreamChannel) { // TODO: send multiple members in each message for nick, membership := range ch.Members { - s := nick + s := dc.marshalNick(ch.conn, nick) if membership != 0 { - s = string(membership) + nick + s = string(membership) + s } dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_NAMREPLY, - Params: []string{dc.nick, string(ch.Status), ch.Name, s}, + Params: []string{dc.nick, string(ch.Status), downstreamName, s}, }) } dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFNAMES, - Params: []string{dc.nick, ch.Name, "End of /NAMES list"}, + Params: []string{dc.nick, downstreamName, "End of /NAMES list"}, }) } diff --git a/downstream.go b/downstream.go index e505f0d..a58b381 100644 --- a/downstream.go +++ b/downstream.go @@ -39,14 +39,19 @@ func (err ircError) Error() string { return err.Message.String() } +type consumption struct { + consumer *RingConsumer + upstreamConn *upstreamConn +} + type downstreamConn struct { - net net.Conn - irc *irc.Conn - srv *Server - logger Logger - messages chan *irc.Message - consumers chan *RingConsumer - closed chan struct{} + net net.Conn + irc *irc.Conn + srv *Server + logger Logger + messages chan *irc.Message + consumptions chan consumption + closed chan struct{} registered bool user *user @@ -57,13 +62,13 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { dc := &downstreamConn{ - net: netConn, - irc: irc.NewConn(netConn), - srv: srv, - logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, - messages: make(chan *irc.Message, 64), - consumers: make(chan *RingConsumer), - closed: make(chan struct{}), + net: netConn, + irc: irc.NewConn(netConn), + srv: srv, + logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, + messages: make(chan *irc.Message, 64), + consumptions: make(chan consumption), + closed: make(chan struct{}), } go func() { @@ -88,6 +93,33 @@ func (dc *downstreamConn) prefix() *irc.Prefix { } } +func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string { + return name +} + +func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) { + // TODO: extract network name from channel name + ch, err := dc.user.getChannel(name) + if err != nil { + return nil, "", err + } + return ch.conn, ch.Name, nil +} + +func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string { + if nick == uc.nick { + return dc.nick + } + return nick +} + +func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix { + if prefix.Name == uc.nick { + return dc.prefix() + } + return prefix +} + func (dc *downstreamConn) isClosed() bool { select { case <-dc.closed: @@ -138,12 +170,21 @@ func (dc *downstreamConn) writeMessages() error { dc.logger.Printf("sent: %v", msg) } err = dc.irc.WriteMessage(msg) - case consumer := <-dc.consumers: + case consumption := <-dc.consumptions: + consumer, uc := consumption.consumer, consumption.upstreamConn for { msg := consumer.Peek() if msg == nil { break } + msg = msg.Copy() + switch msg.Command { + case "PRIVMSG": + // TODO: detect whether it's a user or a channel + msg.Params[0] = dc.marshalChannel(uc, msg.Params[0]) + default: + panic("expected to consume a PRIVMSG message") + } if dc.srv.Debug { dc.logger.Printf("sent: %v", msg) } @@ -303,7 +344,7 @@ func (dc *downstreamConn) register() error { var closed bool select { case <-ch: - dc.consumers <- consumer + dc.consumptions <- consumption{consumer, uc} case <-dc.closed: closed = true } @@ -338,35 +379,30 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { dc.user.forEachUpstream(func(uc *upstreamConn) { uc.SendMessage(msg) }) - case "JOIN": + case "JOIN", "PART": var name string if err := parseMessageParams(msg, &name); err != nil { return err } - if ch, _ := dc.user.getChannel(name); ch != nil { - break // already joined - } - - // TODO: extract network name from channel name - return ircError{&irc.Message{ - Command: irc.ERR_NOSUCHCHANNEL, - Params: []string{name, "Channel name ambiguous"}, - }} - case "PART": - var name string - if err := parseMessageParams(msg, &name); err != nil { - return err - } - - ch, err := dc.user.getChannel(name) + uc, upstreamName, err := dc.unmarshalChannel(name) if err != nil { - return err + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, err.Error()}, + }} } - ch.conn.SendMessage(msg) - // TODO: remove channel from upstream config + uc.SendMessage(&irc.Message{ + Command: msg.Command, + Params: []string{upstreamName}, + }) + // TODO: add/remove channel from upstream config case "MODE": + if msg.Prefix == nil { + return fmt.Errorf("missing prefix") + } + var name string if err := parseMessageParams(msg, &name); err != nil { return err @@ -378,18 +414,30 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } if msg.Prefix.Name != name { - ch, err := dc.user.getChannel(name) + uc, upstreamName, err := dc.unmarshalChannel(name) if err != nil { return err } if modeStr != "" { - ch.conn.SendMessage(msg) + uc.SendMessage(&irc.Message{ + Prefix: uc.prefix(), + Command: "MODE", + Params: []string{upstreamName, modeStr}, + }) } else { + ch, ok := uc.channels[upstreamName] + if !ok { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, "No such channel"}, + }} + } + dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_CHANNELMODEIS, - Params: []string{ch.Name, string(ch.modes)}, + Params: []string{name, string(ch.modes)}, }) } } else { @@ -402,7 +450,11 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { if modeStr != "" { dc.user.forEachUpstream(func(uc *upstreamConn) { - uc.SendMessage(msg) + uc.SendMessage(&irc.Message{ + Prefix: uc.prefix(), + Command: "MODE", + Params: []string{uc.nick, modeStr}, + }) }) } else { dc.SendMessage(&irc.Message{ @@ -419,15 +471,15 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } for _, name := range strings.Split(targetsStr, ",") { - ch, err := dc.user.getChannel(name) + uc, upstreamName, err := dc.unmarshalChannel(name) if err != nil { return err } - ch.conn.SendMessage(&irc.Message{ - Prefix: msg.Prefix, + uc.SendMessage(&irc.Message{ + Prefix: uc.prefix(), Command: "PRIVMSG", - Params: []string{name, text}, + Params: []string{upstreamName, text}, }) } default: diff --git a/upstream.go b/upstream.go index 236da88..c0ece12 100644 --- a/upstream.go +++ b/upstream.go @@ -91,6 +91,14 @@ func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) { return uc, nil } +func (uc *upstreamConn) prefix() *irc.Prefix { + return &irc.Prefix{ + Name: uc.nick, + User: uc.upstream.Username, + // TODO: fill the host? + } +} + func (uc *upstreamConn) Close() error { if uc.closed { return fmt.Errorf("upstream connection already closed") @@ -117,6 +125,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { }) return nil case "MODE": + if msg.Prefix == nil { + return fmt.Errorf("missing prefix") + } + var name, modeStr string if err := parseMessageParams(msg, &name, &modeStr); err != nil { return err @@ -135,11 +147,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { if err := ch.modes.Apply(modeStr); err != nil { return err } - } - uc.user.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.user.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc, msg.Prefix), + Command: "MODE", + Params: []string{dc.marshalChannel(uc, name), modeStr}, + }) + }) + } case "NOTICE": uc.logger.Print(msg) case irc.RPL_WELCOME: @@ -176,11 +192,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { ch.Members[newNick] = membership } } - - uc.user.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) - }) case "JOIN": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + var channels string if err := parseMessageParams(msg, &channels); err != nil { return err @@ -201,12 +217,20 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } ch.Members[msg.Prefix.Name] = 0 } + + uc.user.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc, msg.Prefix), + Command: "JOIN", + Params: []string{dc.marshalChannel(uc, ch)}, + }) + }) + } + case "PART": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") } - uc.user.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) - }) - case "PART": var channels string if err := parseMessageParams(msg, &channels); err != nil { return err @@ -223,11 +247,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } delete(ch.Members, msg.Prefix.Name) } - } - uc.user.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) - }) + uc.user.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc, msg.Prefix), + Command: "PART", + Params: []string{dc.marshalChannel(uc, ch)}, + }) + }) + } case irc.RPL_TOPIC, irc.RPL_NOTOPIC: var name, topic string if err := parseMessageParams(msg, nil, &name, &topic); err != nil { @@ -310,6 +338,9 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { forwardChannel(dc, ch) }) case "PRIVMSG": + if err := parseMessageParams(msg, nil, nil); err != nil { + return err + } uc.ring.Produce(msg) case irc.RPL_YOURHOST, irc.RPL_CREATED: // Ignore @@ -331,7 +362,7 @@ func (uc *upstreamConn) register() { uc.nick = uc.upstream.Nick uc.SendMessage(&irc.Message{ Command: "NICK", - Params: []string{uc.upstream.Nick}, + Params: []string{uc.nick}, }) uc.SendMessage(&irc.Message{ Command: "USER",