summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2023-06-26 18:34:54 +0700
committerShulhan <ms@kilabit.info>2023-07-01 17:23:26 +0700
commit58b60e8cc1a71a026e77c549c1089aa25096d18e (patch)
tree7fb59742c2816992042f8a0d93560306a4b6f14f
parente3d1ce9c956dbb8c951bb8a7d52bcad482350940 (diff)
downloadpakakeh.go-58b60e8cc1a71a026e77c549c1089aa25096d18e.tar.xz
lib/websocket: handle concurrent upgrade using goroutine
The maxGoroutineUpgrader define maximum goroutines running at the same time to handle client upgrade. The new goroutine only dispatched when others are full, so it will run incrementally not all at once. Default to defServerMaxGoroutineUpgrader (128) if its not set.
-rw-r--r--lib/websocket/server.go63
-rw-r--r--lib/websocket/server_options.go13
-rw-r--r--lib/websocket/server_test.go48
3 files changed, 111 insertions, 13 deletions
diff --git a/lib/websocket/server.go b/lib/websocket/server.go
index 2a4315f7..2c7e756a 100644
--- a/lib/websocket/server.go
+++ b/lib/websocket/server.go
@@ -13,6 +13,7 @@ import (
"net"
"net/http"
"strconv"
+ "sync/atomic"
"time"
"unicode/utf8"
@@ -63,6 +64,8 @@ type Server struct {
sock int
+ numGoUpgrade atomic.Int32
+
allowRsv1 bool
allowRsv2 bool
allowRsv3 bool
@@ -75,10 +78,11 @@ func NewServer(opts *ServerOptions) (serv *Server) {
}
serv = &Server{
- Options: opts,
- Clients: newClientManager(),
- routes: newRootRoute(),
- running: make(chan struct{}, 1),
+ Options: opts,
+ Clients: newClientManager(),
+ routes: newRootRoute(),
+ chUpgrade: make(chan int),
+ running: make(chan struct{}, 1),
}
opts.init()
@@ -252,6 +256,8 @@ func (serv *Server) ClientRemove(conn int) {
func (serv *Server) upgrader() {
var (
+ logp = `upgrader`
+
ctx context.Context
hs *Handshake
httpRes string
@@ -265,7 +271,7 @@ func (serv *Server) upgrader() {
for conn = range serv.chUpgrade {
packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
if err != nil {
- log.Println("websocket: server.upgrader: " + err.Error())
+ log.Printf(`%s: %s`, logp, err)
unix.Close(conn)
continue
}
@@ -301,8 +307,7 @@ func (serv *Server) upgrader() {
err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout)
if err != nil {
- log.Println("websocket: server.upgrader: Send: ",
- err.Error())
+ log.Printf(`%s: %s`, logp, err)
unix.Close(conn)
continue
}
@@ -313,8 +318,7 @@ func (serv *Server) upgrader() {
err = serv.clientAdd(ctx, conn)
if err != nil {
- log.Println("websocket: server.upgrader: clientAdd: ",
- err.Error())
+ log.Printf(`%s: %s`, logp, err)
unix.Close(conn)
}
}
@@ -775,8 +779,8 @@ func (serv *Server) Start() (err error) {
return
}
- serv.chUpgrade = make(chan int, _maxQueue)
go serv.upgrader()
+ serv.numGoUpgrade.Add(1)
serv.poll, err = libnet.NewPoll()
if err != nil {
@@ -786,7 +790,10 @@ func (serv *Server) Start() (err error) {
go serv.reader()
go serv.pinger()
- var conn int
+ var (
+ conn int
+ numUpgrader int32
+ )
for {
conn, _, err = unix.Accept(serv.sock)
if err != nil {
@@ -798,8 +805,40 @@ func (serv *Server) Start() (err error) {
return
}
- serv.chUpgrade <- conn
+ select {
+ case serv.chUpgrade <- conn:
+ default:
+ numUpgrader = serv.numGoUpgrade.Load()
+ if numUpgrader < serv.Options.maxGoroutineUpgrader {
+ go serv.upgrader()
+ serv.numGoUpgrade.Add(1)
+ serv.chUpgrade <- conn
+ } else {
+ go serv.delayUpgrade(conn)
+ }
+ }
+ }
+}
+
+// delayUpgrade the maximum goroutine for upgrader has reached, we wait for
+// 300 milliseconds and try to push to upgrade queue again until total wait is
+// greater than ReadWriteTimeout.
+// If its still full, close the connection.
+func (serv *Server) delayUpgrade(conn int) {
+ var (
+ delay = 300 * time.Millisecond
+ total time.Duration
+ )
+ for total < serv.Options.ReadWriteTimeout {
+ time.Sleep(delay)
+ select {
+ case serv.chUpgrade <- conn:
+ return
+ default:
+ total += delay
+ }
}
+ unix.Close(conn)
}
// Stop the server.
diff --git a/lib/websocket/server_options.go b/lib/websocket/server_options.go
index c3247e1e..a82cecdb 100644
--- a/lib/websocket/server_options.go
+++ b/lib/websocket/server_options.go
@@ -14,7 +14,8 @@ const (
defServerConnectPath = "/"
defServerStatusPath = "/status"
- defServerReadWriteTimeout = 30 * time.Second
+ defServerReadWriteTimeout = 30 * time.Second
+ defServerMaxGoroutineUpgrader int32 = 128
)
// ServerOptions contain options to configure the WebSocket server.
@@ -69,6 +70,13 @@ type ServerOptions struct {
// connection as broken.
// Default to 30 seconds.
ReadWriteTimeout time.Duration
+
+ // maxGoroutineUpgrader define maximum goroutines running at the same
+ // time to handle client upgrade.
+ // The new goroutine only dispatched when others are full, so it will
+ // run incrementally not all at once.
+ // Default to defServerMaxGoroutineUpgrader if its not set.
+ maxGoroutineUpgrader int32
}
func (opts *ServerOptions) init() {
@@ -84,4 +92,7 @@ func (opts *ServerOptions) init() {
if opts.ReadWriteTimeout <= 0 {
opts.ReadWriteTimeout = defServerReadWriteTimeout
}
+ if opts.maxGoroutineUpgrader <= 0 {
+ opts.maxGoroutineUpgrader = defServerMaxGoroutineUpgrader
+ }
}
diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go
index 0bc1334f..941cd508 100644
--- a/lib/websocket/server_test.go
+++ b/lib/websocket/server_test.go
@@ -7,6 +7,7 @@ package websocket
import (
"bytes"
"fmt"
+ "net"
"net/http"
"net/url"
"testing"
@@ -283,3 +284,50 @@ func TestServer_Health(t *testing.T) {
test.Assert(t, "/status response code", http.StatusOK, res.StatusCode)
}
+
+// TestServer_upgrader test to make sure that server upgrade does not block
+// other requests.
+func TestServer_upgrader_nonblocking(t *testing.T) {
+ var (
+ err error
+ )
+
+ // Open new connection that does not send anything, that will trigger
+ // the server Accept and continue to Recv.
+ _, err = net.Dial(`tcp`, _testAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create new client that send text.
+ // The client should able to receive response without waiting the
+ // above connection for timeout.
+ var (
+ qtext = make(chan []byte, 1)
+ cl = &Client{
+ Endpoint: _testEndpointAuth,
+ HandleText: func(cl *Client, frame *Frame) (err error) {
+ qtext <- frame.Payload()
+ return nil
+ },
+ }
+ )
+
+ err = cl.Connect()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ msg = []byte(`hello world`)
+ got []byte
+ )
+
+ err = cl.SendText(msg)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ got = <-qtext
+ test.Assert(t, `SendText`, msg, got)
+}