diff options
| author | Shulhan <ms@kilabit.info> | 2023-07-03 22:27:30 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2023-07-03 23:10:38 +0700 |
| commit | e94de9a9b7a474c092aeccaf91bb43ecb35f6acb (patch) | |
| tree | ea83918c96456f633ca595d08ad5e93408dd83be /lib/websocket | |
| parent | 7b01a3fc41b8611aa26c7d5d0f5c0f2bcd012ee0 (diff) | |
| download | pakakeh.go-e94de9a9b7a474c092aeccaf91bb43ecb35f6acb.tar.xz | |
lib/websocket: stop goroutines when no queue received after N duration
When the goroutine for upgrade, reader, or pinger does not receive
any input from its queue after N duration, stop it.
Currently, N equal to ServerOptions ReadWriteTimeout.
This allow unused goroutines released back to system, minimizing
resources usage.
Diffstat (limited to 'lib/websocket')
| -rw-r--r-- | lib/websocket/server.go | 232 |
1 files changed, 133 insertions, 99 deletions
diff --git a/lib/websocket/server.go b/lib/websocket/server.go index 586de037..4bc55d83 100644 --- a/lib/websocket/server.go +++ b/lib/websocket/server.go @@ -272,7 +272,8 @@ func (serv *Server) ClientRemove(conn int) { func (serv *Server) upgrader() { var ( - logp = `upgrader` + logp = `upgrader` + timer = time.NewTimer(serv.Options.ReadWriteTimeout) ctx context.Context hs *Handshake @@ -284,59 +285,70 @@ func (serv *Server) upgrader() { err error ) - for conn = range serv.chUpgrade { - packet, err = Recv(conn, serv.Options.ReadWriteTimeout) - if err != nil { - log.Printf(`%s: %s`, logp, err) - unix.Close(conn) - continue - } - if len(packet) == 0 { - unix.Close(conn) - continue - } + for { + select { + case conn = <-serv.chUpgrade: + packet, err = Recv(conn, serv.Options.ReadWriteTimeout) + if err != nil { + log.Printf(`%s: %s`, logp, err) + unix.Close(conn) + break + } + if len(packet) == 0 { + unix.Close(conn) + break + } - hs, err = newHandshake(packet) - if err != nil { - serv.handleError(conn, http.StatusBadRequest, err.Error()) - continue - } + hs, err = newHandshake(packet) + if err != nil { + serv.handleError(conn, http.StatusBadRequest, err.Error()) + break + } - if hs.URL.Path == serv.Options.StatusPath { - serv.handleStatus(conn) - continue - } - if hs.URL.Path != serv.Options.ConnectPath { - serv.handleError(conn, http.StatusNotFound, "unknown path") - continue - } + if hs.URL.Path == serv.Options.StatusPath { + serv.handleStatus(conn) + break + } + if hs.URL.Path != serv.Options.ConnectPath { + serv.handleError(conn, http.StatusNotFound, "unknown path") + break + } - ctx, key, err = serv.handleUpgrade(hs) - if err != nil { - serv.handleError(conn, http.StatusBadRequest, err.Error()) - continue - } + ctx, key, err = serv.handleUpgrade(hs) + if err != nil { + serv.handleError(conn, http.StatusBadRequest, err.Error()) + break + } - wsAccept = generateHandshakeAccept(key) + wsAccept = generateHandshakeAccept(key) - httpRes = _resUpgradeOK + wsAccept + "\r\n\r\n" + httpRes = _resUpgradeOK + wsAccept + "\r\n\r\n" - err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout) - if err != nil { - log.Printf(`%s: %s`, logp, err) - unix.Close(conn) - continue - } + err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout) + if err != nil { + log.Printf(`%s: %s`, logp, err) + unix.Close(conn) + break + } - if ctx == nil { - ctx = context.Background() - } + if ctx == nil { + ctx = context.Background() + } - err = serv.clientAdd(ctx, conn) - if err != nil { - log.Printf(`%s: %s`, logp, err) - unix.Close(conn) + err = serv.clientAdd(ctx, conn) + if err != nil { + log.Printf(`%s: %s`, logp, err) + unix.Close(conn) + } + + case <-timer.C: + serv.numGoUpgrade.Add(-1) + return + } + if !timer.Stop() { + <-timer.C } + timer.Reset(serv.Options.ReadWriteTimeout) } } @@ -747,7 +759,8 @@ func (serv *Server) pollReader() { // reader goroutine that consume channel that are ready to be read. func (serv *Server) reader() { var ( - logp = `reader` + logp = `reader` + timer = time.NewTimer(serv.Options.ReadWriteTimeout) frames *Frames frame *Frame @@ -757,66 +770,76 @@ func (serv *Server) reader() { isClosing bool ) - for conn = range serv.qreader { - packet, err = Recv(conn, serv.Options.ReadWriteTimeout) - if err != nil { - log.Printf(`%s: %s`, logp, err) - serv.ClientRemove(conn) - continue - } - if len(packet) == 0 { - log.Printf(`%s: empty packet`, logp) - serv.ClientRemove(conn) - continue - } + for { + select { + case conn = <-serv.qreader: + packet, err = Recv(conn, serv.Options.ReadWriteTimeout) + if err != nil { + log.Printf(`%s: %s`, logp, err) + serv.ClientRemove(conn) + break + } + if len(packet) == 0 { + log.Printf(`%s: empty packet`, logp) + serv.ClientRemove(conn) + break + } + + // Handle chopped, unfinished packet or payload. + frame, _ = serv.Clients.getFrame(conn) + if frame != nil { + packet = frame.unpack(packet) + if frame.isComplete { + serv.Clients.setFrame(conn, nil) + isClosing = serv.handleFrame(conn, frame) + if isClosing { + break + } + } + if len(packet) == 0 { + err = serv.poll.RegisterRead(conn) + if err != nil { + log.Printf(`%s: %s`, logp, err) + serv.ClientRemove(conn) + } + break + } + } + + frames = Unpack(packet) + if frames == nil { + log.Printf(`%s: empty frames`, logp) + serv.ClientRemove(conn) + break + } + + var isClosing bool + for _, frame = range frames.v { + if !frame.isComplete { + serv.Clients.setFrame(conn, frame) + continue + } - // Handle chopped, unfinished packet or payload. - frame, _ = serv.Clients.getFrame(conn) - if frame != nil { - packet = frame.unpack(packet) - if frame.isComplete { - serv.Clients.setFrame(conn, nil) isClosing = serv.handleFrame(conn, frame) if isClosing { - continue + break } } - if len(packet) == 0 { + if !isClosing { err = serv.poll.RegisterRead(conn) if err != nil { log.Printf(`%s: %s`, logp, err) serv.ClientRemove(conn) } - continue - } - } - - frames = Unpack(packet) - if frames == nil { - log.Printf(`%s: empty frames`, logp) - serv.ClientRemove(conn) - continue - } - - var isClosing bool - for _, frame = range frames.v { - if !frame.isComplete { - serv.Clients.setFrame(conn, frame) - continue - } - - isClosing = serv.handleFrame(conn, frame) - if isClosing { - break } + case <-timer.C: + serv.numGoReader.Add(-1) + return } - if !isClosing { - err = serv.poll.RegisterRead(conn) - if err != nil { - log.Printf(`%s: %s`, logp, err) - serv.ClientRemove(conn) - } + if !timer.Stop() { + <-timer.C } + timer.Reset(serv.Options.ReadWriteTimeout) } } @@ -883,17 +906,28 @@ func (serv *Server) pollPinger() { func (serv *Server) pinger() { var ( framePing = NewFramePing(false, nil) + timer = time.NewTimer(serv.Options.ReadWriteTimeout) conn int err error ) - for conn = range serv.qpinger { - err = Send(conn, framePing, serv.Options.ReadWriteTimeout) - if err != nil { - // Error on sending PING will be assumed as bad - // connection. - serv.ClientRemove(conn) + for { + select { + case conn = <-serv.qpinger: + err = Send(conn, framePing, serv.Options.ReadWriteTimeout) + if err != nil { + // Error on sending PING will be assumed as bad + // connection. + serv.ClientRemove(conn) + } + case <-timer.C: + serv.numGoPinger.Add(-1) + return + } + if !timer.Stop() { + <-timer.C } + timer.Reset(serv.Options.ReadWriteTimeout) } } |
