aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2023-07-03 22:27:30 +0700
committerShulhan <ms@kilabit.info>2023-07-03 23:10:38 +0700
commite94de9a9b7a474c092aeccaf91bb43ecb35f6acb (patch)
treeea83918c96456f633ca595d08ad5e93408dd83be
parent7b01a3fc41b8611aa26c7d5d0f5c0f2bcd012ee0 (diff)
downloadpakakeh.go-e94de9a9b7a474c092aeccaf91bb43ecb35f6acb.tar.xz
lib/websocket: stop goroutines when no queue received after N duration
When the goroutine for upgrade, reader, or pinger does not receive any input from its queue after N duration, stop it. Currently, N equal to ServerOptions ReadWriteTimeout. This allow unused goroutines released back to system, minimizing resources usage.
-rw-r--r--lib/websocket/server.go232
1 files changed, 133 insertions, 99 deletions
diff --git a/lib/websocket/server.go b/lib/websocket/server.go
index 586de037..4bc55d83 100644
--- a/lib/websocket/server.go
+++ b/lib/websocket/server.go
@@ -272,7 +272,8 @@ func (serv *Server) ClientRemove(conn int) {
func (serv *Server) upgrader() {
var (
- logp = `upgrader`
+ logp = `upgrader`
+ timer = time.NewTimer(serv.Options.ReadWriteTimeout)
ctx context.Context
hs *Handshake
@@ -284,59 +285,70 @@ func (serv *Server) upgrader() {
err error
)
- for conn = range serv.chUpgrade {
- packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
- if err != nil {
- log.Printf(`%s: %s`, logp, err)
- unix.Close(conn)
- continue
- }
- if len(packet) == 0 {
- unix.Close(conn)
- continue
- }
+ for {
+ select {
+ case conn = <-serv.chUpgrade:
+ packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
+ if err != nil {
+ log.Printf(`%s: %s`, logp, err)
+ unix.Close(conn)
+ break
+ }
+ if len(packet) == 0 {
+ unix.Close(conn)
+ break
+ }
- hs, err = newHandshake(packet)
- if err != nil {
- serv.handleError(conn, http.StatusBadRequest, err.Error())
- continue
- }
+ hs, err = newHandshake(packet)
+ if err != nil {
+ serv.handleError(conn, http.StatusBadRequest, err.Error())
+ break
+ }
- if hs.URL.Path == serv.Options.StatusPath {
- serv.handleStatus(conn)
- continue
- }
- if hs.URL.Path != serv.Options.ConnectPath {
- serv.handleError(conn, http.StatusNotFound, "unknown path")
- continue
- }
+ if hs.URL.Path == serv.Options.StatusPath {
+ serv.handleStatus(conn)
+ break
+ }
+ if hs.URL.Path != serv.Options.ConnectPath {
+ serv.handleError(conn, http.StatusNotFound, "unknown path")
+ break
+ }
- ctx, key, err = serv.handleUpgrade(hs)
- if err != nil {
- serv.handleError(conn, http.StatusBadRequest, err.Error())
- continue
- }
+ ctx, key, err = serv.handleUpgrade(hs)
+ if err != nil {
+ serv.handleError(conn, http.StatusBadRequest, err.Error())
+ break
+ }
- wsAccept = generateHandshakeAccept(key)
+ wsAccept = generateHandshakeAccept(key)
- httpRes = _resUpgradeOK + wsAccept + "\r\n\r\n"
+ httpRes = _resUpgradeOK + wsAccept + "\r\n\r\n"
- err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout)
- if err != nil {
- log.Printf(`%s: %s`, logp, err)
- unix.Close(conn)
- continue
- }
+ err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout)
+ if err != nil {
+ log.Printf(`%s: %s`, logp, err)
+ unix.Close(conn)
+ break
+ }
- if ctx == nil {
- ctx = context.Background()
- }
+ if ctx == nil {
+ ctx = context.Background()
+ }
- err = serv.clientAdd(ctx, conn)
- if err != nil {
- log.Printf(`%s: %s`, logp, err)
- unix.Close(conn)
+ err = serv.clientAdd(ctx, conn)
+ if err != nil {
+ log.Printf(`%s: %s`, logp, err)
+ unix.Close(conn)
+ }
+
+ case <-timer.C:
+ serv.numGoUpgrade.Add(-1)
+ return
+ }
+ if !timer.Stop() {
+ <-timer.C
}
+ timer.Reset(serv.Options.ReadWriteTimeout)
}
}
@@ -747,7 +759,8 @@ func (serv *Server) pollReader() {
// reader goroutine that consume channel that are ready to be read.
func (serv *Server) reader() {
var (
- logp = `reader`
+ logp = `reader`
+ timer = time.NewTimer(serv.Options.ReadWriteTimeout)
frames *Frames
frame *Frame
@@ -757,66 +770,76 @@ func (serv *Server) reader() {
isClosing bool
)
- for conn = range serv.qreader {
- packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
- if err != nil {
- log.Printf(`%s: %s`, logp, err)
- serv.ClientRemove(conn)
- continue
- }
- if len(packet) == 0 {
- log.Printf(`%s: empty packet`, logp)
- serv.ClientRemove(conn)
- continue
- }
+ for {
+ select {
+ case conn = <-serv.qreader:
+ packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
+ if err != nil {
+ log.Printf(`%s: %s`, logp, err)
+ serv.ClientRemove(conn)
+ break
+ }
+ if len(packet) == 0 {
+ log.Printf(`%s: empty packet`, logp)
+ serv.ClientRemove(conn)
+ break
+ }
+
+ // Handle chopped, unfinished packet or payload.
+ frame, _ = serv.Clients.getFrame(conn)
+ if frame != nil {
+ packet = frame.unpack(packet)
+ if frame.isComplete {
+ serv.Clients.setFrame(conn, nil)
+ isClosing = serv.handleFrame(conn, frame)
+ if isClosing {
+ break
+ }
+ }
+ if len(packet) == 0 {
+ err = serv.poll.RegisterRead(conn)
+ if err != nil {
+ log.Printf(`%s: %s`, logp, err)
+ serv.ClientRemove(conn)
+ }
+ break
+ }
+ }
+
+ frames = Unpack(packet)
+ if frames == nil {
+ log.Printf(`%s: empty frames`, logp)
+ serv.ClientRemove(conn)
+ break
+ }
+
+ var isClosing bool
+ for _, frame = range frames.v {
+ if !frame.isComplete {
+ serv.Clients.setFrame(conn, frame)
+ continue
+ }
- // Handle chopped, unfinished packet or payload.
- frame, _ = serv.Clients.getFrame(conn)
- if frame != nil {
- packet = frame.unpack(packet)
- if frame.isComplete {
- serv.Clients.setFrame(conn, nil)
isClosing = serv.handleFrame(conn, frame)
if isClosing {
- continue
+ break
}
}
- if len(packet) == 0 {
+ if !isClosing {
err = serv.poll.RegisterRead(conn)
if err != nil {
log.Printf(`%s: %s`, logp, err)
serv.ClientRemove(conn)
}
- continue
- }
- }
-
- frames = Unpack(packet)
- if frames == nil {
- log.Printf(`%s: empty frames`, logp)
- serv.ClientRemove(conn)
- continue
- }
-
- var isClosing bool
- for _, frame = range frames.v {
- if !frame.isComplete {
- serv.Clients.setFrame(conn, frame)
- continue
- }
-
- isClosing = serv.handleFrame(conn, frame)
- if isClosing {
- break
}
+ case <-timer.C:
+ serv.numGoReader.Add(-1)
+ return
}
- if !isClosing {
- err = serv.poll.RegisterRead(conn)
- if err != nil {
- log.Printf(`%s: %s`, logp, err)
- serv.ClientRemove(conn)
- }
+ if !timer.Stop() {
+ <-timer.C
}
+ timer.Reset(serv.Options.ReadWriteTimeout)
}
}
@@ -883,17 +906,28 @@ func (serv *Server) pollPinger() {
func (serv *Server) pinger() {
var (
framePing = NewFramePing(false, nil)
+ timer = time.NewTimer(serv.Options.ReadWriteTimeout)
conn int
err error
)
- for conn = range serv.qpinger {
- err = Send(conn, framePing, serv.Options.ReadWriteTimeout)
- if err != nil {
- // Error on sending PING will be assumed as bad
- // connection.
- serv.ClientRemove(conn)
+ for {
+ select {
+ case conn = <-serv.qpinger:
+ err = Send(conn, framePing, serv.Options.ReadWriteTimeout)
+ if err != nil {
+ // Error on sending PING will be assumed as bad
+ // connection.
+ serv.ClientRemove(conn)
+ }
+ case <-timer.C:
+ serv.numGoPinger.Add(-1)
+ return
+ }
+ if !timer.Stop() {
+ <-timer.C
}
+ timer.Reset(serv.Options.ReadWriteTimeout)
}
}