diff options
| author | Shulhan <ms@kilabit.info> | 2023-06-26 18:34:54 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2023-07-01 17:23:26 +0700 |
| commit | 58b60e8cc1a71a026e77c549c1089aa25096d18e (patch) | |
| tree | 7fb59742c2816992042f8a0d93560306a4b6f14f | |
| parent | e3d1ce9c956dbb8c951bb8a7d52bcad482350940 (diff) | |
| download | pakakeh.go-58b60e8cc1a71a026e77c549c1089aa25096d18e.tar.xz | |
lib/websocket: handle concurrent upgrade using goroutine
The maxGoroutineUpgrader define maximum goroutines running at the same
time to handle client upgrade.
The new goroutine only dispatched when others are full, so it will
run incrementally not all at once.
Default to defServerMaxGoroutineUpgrader (128) if its not set.
| -rw-r--r-- | lib/websocket/server.go | 63 | ||||
| -rw-r--r-- | lib/websocket/server_options.go | 13 | ||||
| -rw-r--r-- | lib/websocket/server_test.go | 48 |
3 files changed, 111 insertions, 13 deletions
diff --git a/lib/websocket/server.go b/lib/websocket/server.go index 2a4315f7..2c7e756a 100644 --- a/lib/websocket/server.go +++ b/lib/websocket/server.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "strconv" + "sync/atomic" "time" "unicode/utf8" @@ -63,6 +64,8 @@ type Server struct { sock int + numGoUpgrade atomic.Int32 + allowRsv1 bool allowRsv2 bool allowRsv3 bool @@ -75,10 +78,11 @@ func NewServer(opts *ServerOptions) (serv *Server) { } serv = &Server{ - Options: opts, - Clients: newClientManager(), - routes: newRootRoute(), - running: make(chan struct{}, 1), + Options: opts, + Clients: newClientManager(), + routes: newRootRoute(), + chUpgrade: make(chan int), + running: make(chan struct{}, 1), } opts.init() @@ -252,6 +256,8 @@ func (serv *Server) ClientRemove(conn int) { func (serv *Server) upgrader() { var ( + logp = `upgrader` + ctx context.Context hs *Handshake httpRes string @@ -265,7 +271,7 @@ func (serv *Server) upgrader() { for conn = range serv.chUpgrade { packet, err = Recv(conn, serv.Options.ReadWriteTimeout) if err != nil { - log.Println("websocket: server.upgrader: " + err.Error()) + log.Printf(`%s: %s`, logp, err) unix.Close(conn) continue } @@ -301,8 +307,7 @@ func (serv *Server) upgrader() { err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout) if err != nil { - log.Println("websocket: server.upgrader: Send: ", - err.Error()) + log.Printf(`%s: %s`, logp, err) unix.Close(conn) continue } @@ -313,8 +318,7 @@ func (serv *Server) upgrader() { err = serv.clientAdd(ctx, conn) if err != nil { - log.Println("websocket: server.upgrader: clientAdd: ", - err.Error()) + log.Printf(`%s: %s`, logp, err) unix.Close(conn) } } @@ -775,8 +779,8 @@ func (serv *Server) Start() (err error) { return } - serv.chUpgrade = make(chan int, _maxQueue) go serv.upgrader() + serv.numGoUpgrade.Add(1) serv.poll, err = libnet.NewPoll() if err != nil { @@ -786,7 +790,10 @@ func (serv *Server) Start() (err error) { go serv.reader() go serv.pinger() - var conn int + var ( + conn int + numUpgrader int32 + ) for { conn, _, err = unix.Accept(serv.sock) if err != nil { @@ -798,8 +805,40 @@ func (serv *Server) Start() (err error) { return } - serv.chUpgrade <- conn + select { + case serv.chUpgrade <- conn: + default: + numUpgrader = serv.numGoUpgrade.Load() + if numUpgrader < serv.Options.maxGoroutineUpgrader { + go serv.upgrader() + serv.numGoUpgrade.Add(1) + serv.chUpgrade <- conn + } else { + go serv.delayUpgrade(conn) + } + } + } +} + +// delayUpgrade the maximum goroutine for upgrader has reached, we wait for +// 300 milliseconds and try to push to upgrade queue again until total wait is +// greater than ReadWriteTimeout. +// If its still full, close the connection. +func (serv *Server) delayUpgrade(conn int) { + var ( + delay = 300 * time.Millisecond + total time.Duration + ) + for total < serv.Options.ReadWriteTimeout { + time.Sleep(delay) + select { + case serv.chUpgrade <- conn: + return + default: + total += delay + } } + unix.Close(conn) } // Stop the server. diff --git a/lib/websocket/server_options.go b/lib/websocket/server_options.go index c3247e1e..a82cecdb 100644 --- a/lib/websocket/server_options.go +++ b/lib/websocket/server_options.go @@ -14,7 +14,8 @@ const ( defServerConnectPath = "/" defServerStatusPath = "/status" - defServerReadWriteTimeout = 30 * time.Second + defServerReadWriteTimeout = 30 * time.Second + defServerMaxGoroutineUpgrader int32 = 128 ) // ServerOptions contain options to configure the WebSocket server. @@ -69,6 +70,13 @@ type ServerOptions struct { // connection as broken. // Default to 30 seconds. ReadWriteTimeout time.Duration + + // maxGoroutineUpgrader define maximum goroutines running at the same + // time to handle client upgrade. + // The new goroutine only dispatched when others are full, so it will + // run incrementally not all at once. + // Default to defServerMaxGoroutineUpgrader if its not set. + maxGoroutineUpgrader int32 } func (opts *ServerOptions) init() { @@ -84,4 +92,7 @@ func (opts *ServerOptions) init() { if opts.ReadWriteTimeout <= 0 { opts.ReadWriteTimeout = defServerReadWriteTimeout } + if opts.maxGoroutineUpgrader <= 0 { + opts.maxGoroutineUpgrader = defServerMaxGoroutineUpgrader + } } diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go index 0bc1334f..941cd508 100644 --- a/lib/websocket/server_test.go +++ b/lib/websocket/server_test.go @@ -7,6 +7,7 @@ package websocket import ( "bytes" "fmt" + "net" "net/http" "net/url" "testing" @@ -283,3 +284,50 @@ func TestServer_Health(t *testing.T) { test.Assert(t, "/status response code", http.StatusOK, res.StatusCode) } + +// TestServer_upgrader test to make sure that server upgrade does not block +// other requests. +func TestServer_upgrader_nonblocking(t *testing.T) { + var ( + err error + ) + + // Open new connection that does not send anything, that will trigger + // the server Accept and continue to Recv. + _, err = net.Dial(`tcp`, _testAddr) + if err != nil { + t.Fatal(err) + } + + // Create new client that send text. + // The client should able to receive response without waiting the + // above connection for timeout. + var ( + qtext = make(chan []byte, 1) + cl = &Client{ + Endpoint: _testEndpointAuth, + HandleText: func(cl *Client, frame *Frame) (err error) { + qtext <- frame.Payload() + return nil + }, + } + ) + + err = cl.Connect() + if err != nil { + t.Fatal(err) + } + + var ( + msg = []byte(`hello world`) + got []byte + ) + + err = cl.SendText(msg) + if err != nil { + t.Fatal(err) + } + + got = <-qtext + test.Assert(t, `SendText`, msg, got) +} |
