summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2023-06-26 02:21:06 +0700
committerShulhan <ms@kilabit.info>2023-06-29 21:25:29 +0700
commit8375e6cff58ed31d40354186daa7891937c2bdc3 (patch)
tree17b2fd91ed2f05e37d85b9ee1667dbe61a75a340
parentac872acf70909c2de6d01a53d73771778de2f922 (diff)
downloadpakakeh.go-8375e6cff58ed31d40354186daa7891937c2bdc3.tar.xz
lib/websocket: add option to set read/write timeout on Server
The ReadWriteTimeout define the maximum duration the server wait when receiving/sending packet from/to client before considering the connection as broken. Default read-write timeout is 30 seconds if not set. This changes affect the exported function Send and Recv by adding additional parameter timeout to both of them.
-rw-r--r--lib/websocket/examples/cmd/server/main.go7
-rw-r--r--lib/websocket/funcs.go77
-rw-r--r--lib/websocket/server.go26
-rw-r--r--lib/websocket/server_options.go12
-rw-r--r--lib/websocket/testdata/server/main.go30
-rw-r--r--lib/websocket/websocket_test.go4
6 files changed, 115 insertions, 41 deletions
diff --git a/lib/websocket/examples/cmd/server/main.go b/lib/websocket/examples/cmd/server/main.go
index d029036b..49733fb4 100644
--- a/lib/websocket/examples/cmd/server/main.go
+++ b/lib/websocket/examples/cmd/server/main.go
@@ -11,6 +11,7 @@ import (
"fmt"
"log"
"net/http"
+ "time"
"github.com/shuLhan/share/lib/websocket"
"github.com/shuLhan/share/lib/websocket/examples"
@@ -90,7 +91,7 @@ func handleClientAdd(ctx context.Context, conn int) {
if c == conn {
continue
}
- err = websocket.Send(c, packet)
+ err = websocket.Send(c, packet, 1*time.Second)
if err != nil {
log.Println(err)
}
@@ -122,7 +123,7 @@ func handleClientRemove(ctx context.Context, conn int) {
if c == conn {
continue
}
- err = websocket.Send(c, packet)
+ err = websocket.Send(c, packet, 1*time.Second)
if err != nil {
log.Println(err)
}
@@ -156,7 +157,7 @@ func handlePostMessage(ctx context.Context, req *websocket.Request) (res websock
if conn == req.Conn {
continue
}
- err = websocket.Send(conn, packet)
+ err = websocket.Send(conn, packet, 1*time.Second)
if err != nil {
log.Println(err)
}
diff --git a/lib/websocket/funcs.go b/lib/websocket/funcs.go
index 804bdbf0..99db9d06 100644
--- a/lib/websocket/funcs.go
+++ b/lib/websocket/funcs.go
@@ -8,8 +8,11 @@ import (
"crypto/sha1"
"encoding/base64"
"encoding/binary"
+ "fmt"
"math/rand"
+ "os"
"syscall"
+ "time"
"golang.org/x/sys/unix"
)
@@ -19,31 +22,49 @@ import (
// This number should be lower than MTU for better handling larger payload.
const maxBuffer = 1024
-// Recv read all content from file descriptor into slice of bytes.
-//
-// On success it will return buffer from pool. Caller must put the buffer back
-// to the pool.
-//
-// On fail it will return nil buffer and error.
-func Recv(fd int) (packet []byte, err error) {
+// Recv read packet from socket fd.
+// The timeout parameter is optional, define the timeout when reading from
+// socket.
+// If timeout is zero the Recv operation will block until a data arrived.
+// If timeout is greater than zero, the Recv operation will return
+// os.ErrDeadlineExceeded when no data received after timeout duration.
+func Recv(fd int, timeout time.Duration) (packet []byte, err error) {
var (
- buf []byte = make([]byte, maxBuffer)
+ logp = `Recv`
+ buf = make([]byte, maxBuffer)
+ timeval = unix.Timeval{}
errno syscall.Errno
n int
ok bool
)
+ err = unix.SetNonblock(fd, false)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: SetNonblock: %w`, logp, err)
+ }
+
+ if timeout > 0 {
+ timeval.Sec = int64(timeout.Seconds())
+ err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeval)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: SetsockoptTimeval: %w`, logp, err)
+ }
+ }
+
for {
n, err = unix.Read(fd, buf)
if err != nil {
errno, ok = err.(unix.Errno)
if ok {
- if errno == unix.EAGAIN || errno == unix.EINTR {
+ if errno == unix.EINTR {
continue
}
+ if errno == unix.EAGAIN || errno == unix.EWOULDBLOCK {
+ return nil, fmt.Errorf(`%s: %w`, logp, os.ErrDeadlineExceeded)
+ }
}
- break
+ return nil, fmt.Errorf(`%s: Read: %w`, logp, err)
}
if n > 0 {
packet = append(packet, buf[:n]...)
@@ -53,18 +74,39 @@ func Recv(fd int) (packet []byte, err error) {
}
}
- return packet, err
+ return packet, nil
}
-// Send the packet through web socket file descriptor `fd`.
-func Send(fd int, packet []byte) (err error) {
+// Send the packet through socket file descriptor fd.
+// The timeout parameter is optional, its define the maximum duration when
+// socket write should wait before considered fail.
+// If timeout is zero, Send will block until buffer is available.
+// If timeout is greater than zero, and Send has wait for this duration for
+// buffer available then it will return os.ErrDeadlineExceeded.
+func Send(fd int, packet []byte, timeout time.Duration) (err error) {
var (
+ logp = `Send`
+ timeval = unix.Timeval{}
+
errno syscall.Errno
max int
n int
ok bool
)
+ err = unix.SetNonblock(fd, false)
+ if err != nil {
+ return fmt.Errorf(`%s: SetNonblock: %w`, logp, err)
+ }
+
+ if timeout > 0 {
+ timeval.Sec = int64(timeout.Seconds())
+ err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeval)
+ if err != nil {
+ return fmt.Errorf(`%s: SetsockoptTimeval: %w`, logp, err)
+ }
+ }
+
for len(packet) > 0 {
if len(packet) < maxBuffer {
max = len(packet)
@@ -76,11 +118,14 @@ func Send(fd int, packet []byte) (err error) {
if err != nil {
errno, ok = err.(unix.Errno)
if ok {
- if errno == unix.EAGAIN {
+ if errno == unix.EINTR {
continue
}
+ if errno == unix.EAGAIN || errno == unix.EWOULDBLOCK {
+ return fmt.Errorf(`%s: %w`, logp, os.ErrDeadlineExceeded)
+ }
}
- return err
+ return fmt.Errorf(`%s: Write: %w`, logp, err)
}
if n > 0 {
@@ -88,7 +133,7 @@ func Send(fd int, packet []byte) (err error) {
}
}
- return err
+ return nil
}
// generateHandshakeAccept generate server accept key by concatenating key,
diff --git a/lib/websocket/server.go b/lib/websocket/server.go
index 11c7b383..2a4315f7 100644
--- a/lib/websocket/server.go
+++ b/lib/websocket/server.go
@@ -175,7 +175,7 @@ func (serv *Server) handleError(conn int, code int, msg string) {
err error
)
- err = Send(conn, []byte(rspBody))
+ err = Send(conn, []byte(rspBody), serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handleError: " + err.Error())
}
@@ -263,7 +263,7 @@ func (serv *Server) upgrader() {
)
for conn = range serv.chUpgrade {
- packet, err = Recv(conn)
+ packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.upgrader: " + err.Error())
unix.Close(conn)
@@ -299,7 +299,7 @@ func (serv *Server) upgrader() {
httpRes = _resUpgradeOK + wsAccept + "\r\n\r\n"
- err = Send(conn, []byte(httpRes))
+ err = Send(conn, []byte(httpRes), serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.upgrader: Send: ",
err.Error())
@@ -525,7 +525,7 @@ func (serv *Server) handleStatus(conn int) {
err error
)
- err = Send(conn, []byte(res))
+ err = Send(conn, []byte(res), serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket /health: Send: ", err.Error())
}
@@ -584,7 +584,7 @@ func (serv *Server) handleClose(conn int, req *Frame) {
err error
)
- err = Send(conn, packet)
+ err = Send(conn, packet, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handleClose: Send: " + err.Error())
}
@@ -599,12 +599,12 @@ func (serv *Server) handleBadRequest(conn int) {
err error
)
- err = Send(conn, frameClose)
+ err = Send(conn, frameClose, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handleBadRequest: " + err.Error())
}
- _, err = Recv(conn)
+ _, err = Recv(conn, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handleBadRequest: " + err.Error())
}
@@ -619,12 +619,12 @@ func (serv *Server) handleInvalidData(conn int) {
err error
)
- err = Send(conn, frameClose)
+ err = Send(conn, frameClose, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handleInvalidData: " + err.Error())
}
- _, err = Recv(conn)
+ _, err = Recv(conn, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handleInvalidData: " + err.Error())
}
@@ -649,7 +649,7 @@ func (serv *Server) handlePing(conn int, req *Frame) {
err error
)
- err = Send(conn, res)
+ err = Send(conn, res, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.handlePing: " + err.Error())
serv.ClientRemove(conn)
@@ -689,7 +689,7 @@ func (serv *Server) reader() {
for x = 0; x < len(fds); x++ {
conn = fds[x]
- packet, err = Recv(conn)
+ packet, err = Recv(conn, serv.Options.ReadWriteTimeout)
if err != nil || len(packet) == 0 {
serv.ClientRemove(conn)
continue
@@ -755,7 +755,7 @@ func (serv *Server) pinger() {
all = serv.Clients.All()
for _, conn = range all {
- err = Send(conn, framePing)
+ err = Send(conn, framePing, serv.Options.ReadWriteTimeout)
if err != nil {
// Error on sending PING will be
// assumed as bad connection.
@@ -830,7 +830,7 @@ func (serv *Server) sendResponse(conn int, res *Response) (err error) {
packet = NewFrameText(false, packet)
- err = Send(conn, packet)
+ err = Send(conn, packet, serv.Options.ReadWriteTimeout)
if err != nil {
log.Println("websocket: server.sendResponse: " + err.Error())
}
diff --git a/lib/websocket/server_options.go b/lib/websocket/server_options.go
index e2511360..c3247e1e 100644
--- a/lib/websocket/server_options.go
+++ b/lib/websocket/server_options.go
@@ -6,12 +6,15 @@ package websocket
import (
"path"
+ "time"
)
const (
defServerAddress = ":80"
defServerConnectPath = "/"
defServerStatusPath = "/status"
+
+ defServerReadWriteTimeout = 30 * time.Second
)
// ServerOptions contain options to configure the WebSocket server.
@@ -60,6 +63,12 @@ type ServerOptions struct {
// Default to ConnectPath +"/status" if its empty.
// The StatusPath is handled by HandleStatus callback in the server.
StatusPath string
+
+ // ReadWriteTimeout define the maximum duration the server wait for
+ // receiving/sending packet from/to client before considering the
+ // connection as broken.
+ // Default to 30 seconds.
+ ReadWriteTimeout time.Duration
}
func (opts *ServerOptions) init() {
@@ -72,4 +81,7 @@ func (opts *ServerOptions) init() {
if len(opts.StatusPath) == 0 {
opts.StatusPath = path.Join(opts.ConnectPath, defServerStatusPath)
}
+ if opts.ReadWriteTimeout <= 0 {
+ opts.ReadWriteTimeout = defServerReadWriteTimeout
+ }
}
diff --git a/lib/websocket/testdata/server/main.go b/lib/websocket/testdata/server/main.go
index b0e2e46d..7330c689 100644
--- a/lib/websocket/testdata/server/main.go
+++ b/lib/websocket/testdata/server/main.go
@@ -11,6 +11,7 @@ import (
"flag"
"fmt"
"log"
+ "time"
"github.com/shuLhan/share/lib/websocket"
)
@@ -22,6 +23,8 @@ const (
// handleBin from websocket by echo-ing back the payload.
func main() {
var (
+ timeout = 30 * time.Second
+
srv *websocket.Server
err error
)
@@ -40,29 +43,42 @@ func main() {
Address: `0.0.0.0:9001`,
HandleBin: func(conn int, payload []byte) {
var (
- packet []byte = websocket.NewFrameBin(false, payload)
- err error
+ timeStart = time.Now()
+ packet []byte = websocket.NewFrameBin(false, payload)
+ err error
)
- err = websocket.Send(conn, packet)
+ err = websocket.Send(conn, packet, timeout)
if err != nil {
log.Println("handleBin: " + err.Error())
}
+
+ var elapsed = time.Now().Sub(timeStart)
+ if elapsed >= timeout {
+ log.Printf(`HandleBin: %s`, elapsed)
+ }
},
HandleText: func(conn int, payload []byte) {
var (
- packet []byte = websocket.NewFrameText(false, payload)
- err error
+ timeStart = time.Now()
+ packet []byte = websocket.NewFrameText(false, payload)
+ err error
)
- err = websocket.Send(conn, packet)
+
+ err = websocket.Send(conn, packet, timeout)
if err != nil {
log.Println("handleText: " + err.Error())
}
+ var elapsed = time.Now().Sub(timeStart)
+ if elapsed >= timeout {
+ log.Printf(`HandleText: %s`, elapsed)
+ }
+
if bytes.Equal(payload, []byte(cmdShutdown)) {
log.Println(`Shutting down server...`)
- srv.Stop()
+ os.Exit(0)
}
},
}
diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go
index d2fcc404..6e1e6c39 100644
--- a/lib/websocket/websocket_test.go
+++ b/lib/websocket/websocket_test.go
@@ -57,7 +57,7 @@ func testHandleText(conn int, payload []byte) {
err error
)
- err = Send(conn, packet)
+ err = Send(conn, packet, 1*time.Second)
if err != nil {
log.Println("handlePayloadText: " + err.Error())
}
@@ -70,7 +70,7 @@ func testHandleBin(conn int, payload []byte) {
err error
)
- err = Send(conn, packet)
+ err = Send(conn, packet, 1*time.Second)
if err != nil {
log.Println("handlePayloadBin: " + err.Error())
}