diff options
| author | Shulhan <ms@kilabit.info> | 2023-06-26 02:21:06 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2023-06-29 21:25:29 +0700 |
| commit | 8375e6cff58ed31d40354186daa7891937c2bdc3 (patch) | |
| tree | 17b2fd91ed2f05e37d85b9ee1667dbe61a75a340 | |
| parent | ac872acf70909c2de6d01a53d73771778de2f922 (diff) | |
| download | pakakeh.go-8375e6cff58ed31d40354186daa7891937c2bdc3.tar.xz | |
lib/websocket: add option to set read/write timeout on Server
The ReadWriteTimeout define the maximum duration the server wait when
receiving/sending packet from/to client before considering the
connection as broken.
Default read-write timeout is 30 seconds if not set.
This changes affect the exported function Send and Recv by adding
additional parameter timeout to both of them.
| -rw-r--r-- | lib/websocket/examples/cmd/server/main.go | 7 | ||||
| -rw-r--r-- | lib/websocket/funcs.go | 77 | ||||
| -rw-r--r-- | lib/websocket/server.go | 26 | ||||
| -rw-r--r-- | lib/websocket/server_options.go | 12 | ||||
| -rw-r--r-- | lib/websocket/testdata/server/main.go | 30 | ||||
| -rw-r--r-- | lib/websocket/websocket_test.go | 4 |
6 files changed, 115 insertions, 41 deletions
diff --git a/lib/websocket/examples/cmd/server/main.go b/lib/websocket/examples/cmd/server/main.go index d029036b..49733fb4 100644 --- a/lib/websocket/examples/cmd/server/main.go +++ b/lib/websocket/examples/cmd/server/main.go @@ -11,6 +11,7 @@ import ( "fmt" "log" "net/http" + "time" "github.com/shuLhan/share/lib/websocket" "github.com/shuLhan/share/lib/websocket/examples" @@ -90,7 +91,7 @@ func handleClientAdd(ctx context.Context, conn int) { if c == conn { continue } - err = websocket.Send(c, packet) + err = websocket.Send(c, packet, 1*time.Second) if err != nil { log.Println(err) } @@ -122,7 +123,7 @@ func handleClientRemove(ctx context.Context, conn int) { if c == conn { continue } - err = websocket.Send(c, packet) + err = websocket.Send(c, packet, 1*time.Second) if err != nil { log.Println(err) } @@ -156,7 +157,7 @@ func handlePostMessage(ctx context.Context, req *websocket.Request) (res websock if conn == req.Conn { continue } - err = websocket.Send(conn, packet) + err = websocket.Send(conn, packet, 1*time.Second) if err != nil { log.Println(err) } diff --git a/lib/websocket/funcs.go b/lib/websocket/funcs.go index 804bdbf0..99db9d06 100644 --- a/lib/websocket/funcs.go +++ b/lib/websocket/funcs.go @@ -8,8 +8,11 @@ import ( "crypto/sha1" "encoding/base64" "encoding/binary" + "fmt" "math/rand" + "os" "syscall" + "time" "golang.org/x/sys/unix" ) @@ -19,31 +22,49 @@ import ( // This number should be lower than MTU for better handling larger payload. const maxBuffer = 1024 -// Recv read all content from file descriptor into slice of bytes. -// -// On success it will return buffer from pool. Caller must put the buffer back -// to the pool. -// -// On fail it will return nil buffer and error. -func Recv(fd int) (packet []byte, err error) { +// Recv read packet from socket fd. +// The timeout parameter is optional, define the timeout when reading from +// socket. +// If timeout is zero the Recv operation will block until a data arrived. +// If timeout is greater than zero, the Recv operation will return +// os.ErrDeadlineExceeded when no data received after timeout duration. +func Recv(fd int, timeout time.Duration) (packet []byte, err error) { var ( - buf []byte = make([]byte, maxBuffer) + logp = `Recv` + buf = make([]byte, maxBuffer) + timeval = unix.Timeval{} errno syscall.Errno n int ok bool ) + err = unix.SetNonblock(fd, false) + if err != nil { + return nil, fmt.Errorf(`%s: SetNonblock: %w`, logp, err) + } + + if timeout > 0 { + timeval.Sec = int64(timeout.Seconds()) + err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeval) + if err != nil { + return nil, fmt.Errorf(`%s: SetsockoptTimeval: %w`, logp, err) + } + } + for { n, err = unix.Read(fd, buf) if err != nil { errno, ok = err.(unix.Errno) if ok { - if errno == unix.EAGAIN || errno == unix.EINTR { + if errno == unix.EINTR { continue } + if errno == unix.EAGAIN || errno == unix.EWOULDBLOCK { + return nil, fmt.Errorf(`%s: %w`, logp, os.ErrDeadlineExceeded) + } } - break + return nil, fmt.Errorf(`%s: Read: %w`, logp, err) } if n > 0 { packet = append(packet, buf[:n]...) @@ -53,18 +74,39 @@ func Recv(fd int) (packet []byte, err error) { } } - return packet, err + return packet, nil } -// Send the packet through web socket file descriptor `fd`. -func Send(fd int, packet []byte) (err error) { +// Send the packet through socket file descriptor fd. +// The timeout parameter is optional, its define the maximum duration when +// socket write should wait before considered fail. +// If timeout is zero, Send will block until buffer is available. +// If timeout is greater than zero, and Send has wait for this duration for +// buffer available then it will return os.ErrDeadlineExceeded. +func Send(fd int, packet []byte, timeout time.Duration) (err error) { var ( + logp = `Send` + timeval = unix.Timeval{} + errno syscall.Errno max int n int ok bool ) + err = unix.SetNonblock(fd, false) + if err != nil { + return fmt.Errorf(`%s: SetNonblock: %w`, logp, err) + } + + if timeout > 0 { + timeval.Sec = int64(timeout.Seconds()) + err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeval) + if err != nil { + return fmt.Errorf(`%s: SetsockoptTimeval: %w`, logp, err) + } + } + for len(packet) > 0 { if len(packet) < maxBuffer { max = len(packet) @@ -76,11 +118,14 @@ func Send(fd int, packet []byte) (err error) { if err != nil { errno, ok = err.(unix.Errno) if ok { - if errno == unix.EAGAIN { + if errno == unix.EINTR { continue } + if errno == unix.EAGAIN || errno == unix.EWOULDBLOCK { + return fmt.Errorf(`%s: %w`, logp, os.ErrDeadlineExceeded) + } } - return err + return fmt.Errorf(`%s: Write: %w`, logp, err) } if n > 0 { @@ -88,7 +133,7 @@ func Send(fd int, packet []byte) (err error) { } } - return err + return nil } // generateHandshakeAccept generate server accept key by concatenating key, diff --git a/lib/websocket/server.go b/lib/websocket/server.go index 11c7b383..2a4315f7 100644 --- a/lib/websocket/server.go +++ b/lib/websocket/server.go @@ -175,7 +175,7 @@ func (serv *Server) handleError(conn int, code int, msg string) { err error ) - err = Send(conn, []byte(rspBody)) + err = Send(conn, []byte(rspBody), serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handleError: " + err.Error()) } @@ -263,7 +263,7 @@ func (serv *Server) upgrader() { ) for conn = range serv.chUpgrade { - packet, err = Recv(conn) + packet, err = Recv(conn, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.upgrader: " + err.Error()) unix.Close(conn) @@ -299,7 +299,7 @@ func (serv *Server) upgrader() { httpRes = _resUpgradeOK + wsAccept + "\r\n\r\n" - err = Send(conn, []byte(httpRes)) + err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.upgrader: Send: ", err.Error()) @@ -525,7 +525,7 @@ func (serv *Server) handleStatus(conn int) { err error ) - err = Send(conn, []byte(res)) + err = Send(conn, []byte(res), serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket /health: Send: ", err.Error()) } @@ -584,7 +584,7 @@ func (serv *Server) handleClose(conn int, req *Frame) { err error ) - err = Send(conn, packet) + err = Send(conn, packet, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handleClose: Send: " + err.Error()) } @@ -599,12 +599,12 @@ func (serv *Server) handleBadRequest(conn int) { err error ) - err = Send(conn, frameClose) + err = Send(conn, frameClose, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handleBadRequest: " + err.Error()) } - _, err = Recv(conn) + _, err = Recv(conn, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handleBadRequest: " + err.Error()) } @@ -619,12 +619,12 @@ func (serv *Server) handleInvalidData(conn int) { err error ) - err = Send(conn, frameClose) + err = Send(conn, frameClose, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handleInvalidData: " + err.Error()) } - _, err = Recv(conn) + _, err = Recv(conn, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handleInvalidData: " + err.Error()) } @@ -649,7 +649,7 @@ func (serv *Server) handlePing(conn int, req *Frame) { err error ) - err = Send(conn, res) + err = Send(conn, res, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.handlePing: " + err.Error()) serv.ClientRemove(conn) @@ -689,7 +689,7 @@ func (serv *Server) reader() { for x = 0; x < len(fds); x++ { conn = fds[x] - packet, err = Recv(conn) + packet, err = Recv(conn, serv.Options.ReadWriteTimeout) if err != nil || len(packet) == 0 { serv.ClientRemove(conn) continue @@ -755,7 +755,7 @@ func (serv *Server) pinger() { all = serv.Clients.All() for _, conn = range all { - err = Send(conn, framePing) + err = Send(conn, framePing, serv.Options.ReadWriteTimeout) if err != nil { // Error on sending PING will be // assumed as bad connection. @@ -830,7 +830,7 @@ func (serv *Server) sendResponse(conn int, res *Response) (err error) { packet = NewFrameText(false, packet) - err = Send(conn, packet) + err = Send(conn, packet, serv.Options.ReadWriteTimeout) if err != nil { log.Println("websocket: server.sendResponse: " + err.Error()) } diff --git a/lib/websocket/server_options.go b/lib/websocket/server_options.go index e2511360..c3247e1e 100644 --- a/lib/websocket/server_options.go +++ b/lib/websocket/server_options.go @@ -6,12 +6,15 @@ package websocket import ( "path" + "time" ) const ( defServerAddress = ":80" defServerConnectPath = "/" defServerStatusPath = "/status" + + defServerReadWriteTimeout = 30 * time.Second ) // ServerOptions contain options to configure the WebSocket server. @@ -60,6 +63,12 @@ type ServerOptions struct { // Default to ConnectPath +"/status" if its empty. // The StatusPath is handled by HandleStatus callback in the server. StatusPath string + + // ReadWriteTimeout define the maximum duration the server wait for + // receiving/sending packet from/to client before considering the + // connection as broken. + // Default to 30 seconds. + ReadWriteTimeout time.Duration } func (opts *ServerOptions) init() { @@ -72,4 +81,7 @@ func (opts *ServerOptions) init() { if len(opts.StatusPath) == 0 { opts.StatusPath = path.Join(opts.ConnectPath, defServerStatusPath) } + if opts.ReadWriteTimeout <= 0 { + opts.ReadWriteTimeout = defServerReadWriteTimeout + } } diff --git a/lib/websocket/testdata/server/main.go b/lib/websocket/testdata/server/main.go index b0e2e46d..7330c689 100644 --- a/lib/websocket/testdata/server/main.go +++ b/lib/websocket/testdata/server/main.go @@ -11,6 +11,7 @@ import ( "flag" "fmt" "log" + "time" "github.com/shuLhan/share/lib/websocket" ) @@ -22,6 +23,8 @@ const ( // handleBin from websocket by echo-ing back the payload. func main() { var ( + timeout = 30 * time.Second + srv *websocket.Server err error ) @@ -40,29 +43,42 @@ func main() { Address: `0.0.0.0:9001`, HandleBin: func(conn int, payload []byte) { var ( - packet []byte = websocket.NewFrameBin(false, payload) - err error + timeStart = time.Now() + packet []byte = websocket.NewFrameBin(false, payload) + err error ) - err = websocket.Send(conn, packet) + err = websocket.Send(conn, packet, timeout) if err != nil { log.Println("handleBin: " + err.Error()) } + + var elapsed = time.Now().Sub(timeStart) + if elapsed >= timeout { + log.Printf(`HandleBin: %s`, elapsed) + } }, HandleText: func(conn int, payload []byte) { var ( - packet []byte = websocket.NewFrameText(false, payload) - err error + timeStart = time.Now() + packet []byte = websocket.NewFrameText(false, payload) + err error ) - err = websocket.Send(conn, packet) + + err = websocket.Send(conn, packet, timeout) if err != nil { log.Println("handleText: " + err.Error()) } + var elapsed = time.Now().Sub(timeStart) + if elapsed >= timeout { + log.Printf(`HandleText: %s`, elapsed) + } + if bytes.Equal(payload, []byte(cmdShutdown)) { log.Println(`Shutting down server...`) - srv.Stop() + os.Exit(0) } }, } diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go index d2fcc404..6e1e6c39 100644 --- a/lib/websocket/websocket_test.go +++ b/lib/websocket/websocket_test.go @@ -57,7 +57,7 @@ func testHandleText(conn int, payload []byte) { err error ) - err = Send(conn, packet) + err = Send(conn, packet, 1*time.Second) if err != nil { log.Println("handlePayloadText: " + err.Error()) } @@ -70,7 +70,7 @@ func testHandleBin(conn int, payload []byte) { err error ) - err = Send(conn, packet) + err = Send(conn, packet, 1*time.Second) if err != nil { log.Println("handlePayloadBin: " + err.Error()) } |
