Implement rate limiting for upstream messages

Allow up to 10 outgoing messages in a burst, then throttle to 1 message
each 2 seconds.

Closes: https://todo.sr.ht/~emersion/soju/87
This commit is contained in:
Simon Ser 2020-08-19 19:28:29 +02:00
parent 9f26422592
commit bdb132ad98
No known key found for this signature in database
GPG key ID: 0FDE7BE0E88F5E48
5 changed files with 72 additions and 7 deletions

60
conn.go
View file

@ -106,6 +106,52 @@ func (wa websocketAddr) String() string {
return string(wa)
}
type rateLimiter struct {
C <-chan struct{}
ticker *time.Ticker
stopped chan struct{}
}
func newRateLimiter(delay time.Duration, burst int) *rateLimiter {
ch := make(chan struct{}, burst)
for i := 0; i < burst; i++ {
ch <- struct{}{}
}
ticker := time.NewTicker(delay)
stopped := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
select {
case ch <- struct{}{}:
// This space is intentionally left blank
case <-stopped:
return
}
case <-stopped:
return
}
}
}()
return &rateLimiter{
C: ch,
ticker: ticker,
stopped: stopped,
}
}
func (rl *rateLimiter) Stop() {
rl.ticker.Stop()
close(rl.stopped)
}
type connOptions struct {
Logger Logger
RateLimitDelay time.Duration
RateLimitBurst int
}
type conn struct {
conn ircConn
srv *Server
@ -116,17 +162,27 @@ type conn struct {
closed bool
}
func newConn(srv *Server, ic ircConn, logger Logger) *conn {
func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
outgoing := make(chan *irc.Message, 64)
c := &conn{
conn: ic,
srv: srv,
outgoing: outgoing,
logger: logger,
logger: options.Logger,
}
go func() {
var rl *rateLimiter
if options.RateLimitDelay > 0 && options.RateLimitBurst > 0 {
rl = newRateLimiter(options.RateLimitDelay, options.RateLimitBurst)
defer rl.Stop()
}
for msg := range outgoing {
if rl != nil {
<-rl.C
}
if c.srv.Debug {
c.logger.Printf("sent: %v", msg)
}

View file

@ -102,8 +102,9 @@ type downstreamConn struct {
func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
remoteAddr := ic.RemoteAddr().String()
logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
options := connOptions{Logger: logger}
dc := &downstreamConn{
conn: *newConn(srv, ic, logger),
conn: *newConn(srv, ic, &options),
id: id,
supportedCaps: make(map[string]string),
caps: make(map[string]bool),

View file

@ -16,9 +16,11 @@ import (
)
// TODO: make configurable
var retryConnectMinDelay = time.Minute
var retryConnectDelay = time.Minute
var connectTimeout = 15 * time.Second
var writeTimeout = 10 * time.Second
var upstreamMessageDelay = 2 * time.Second
var upstreamMessageBurst = 10
type Logger interface {
Print(v ...interface{})

View file

@ -157,8 +157,14 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
}
options := connOptions{
Logger: logger,
RateLimitDelay: upstreamMessageDelay,
RateLimitBurst: upstreamMessageBurst,
}
uc := &upstreamConn{
conn: *newConn(network.user.srv, newNetIRCConn(netConn), logger),
conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
network: network,
user: network.user,
channels: make(map[string]*upstreamChannel),

View file

@ -120,8 +120,8 @@ func (net *network) run() {
return
}
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
delay := retryConnectMinDelay - dur
if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
delay := retryConnectDelay - dur
net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
time.Sleep(delay)
}