aboutsummaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/websocket/server.go152
1 files changed, 132 insertions, 20 deletions
diff --git a/lib/websocket/server.go b/lib/websocket/server.go
index 6698fc64..2443c792 100644
--- a/lib/websocket/server.go
+++ b/lib/websocket/server.go
@@ -39,8 +39,11 @@ type Server struct {
port int
sock int
chUpgrade chan int
- epollRead int
- routes *rootRoute
+
+ epollEvents [128]unix.EpollEvent
+ epollRead int
+
+ routes *rootRoute
// HandleAuth callback that will be called when receiving
// client handshake.
@@ -250,6 +253,20 @@ func (serv *Server) ClientRemove(conn int) {
}
}
+//
+// epollRegisterRead register the connection for read in epoll.
+//
+func (serv *Server) epollRegisterRead(idx, conn int) {
+ // See https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/
+ serv.epollEvents[idx].Events = unix.EPOLLIN | unix.EPOLLONESHOT
+
+ err := unix.EpollCtl(serv.epollRead, unix.EPOLL_CTL_MOD, conn, &serv.epollEvents[idx])
+ if err != nil {
+ log.Println("websocket: server.reader: unix.EpollCtl: " + err.Error())
+ serv.ClientRemove(conn)
+ }
+}
+
func (serv *Server) upgrader() {
for conn := range serv.chUpgrade {
packet, err := Recv(conn)
@@ -280,6 +297,10 @@ func (serv *Server) upgrader() {
continue
}
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
err = serv.clientAdd(ctx, conn)
if err != nil {
log.Println("websocket: server.upgrader: clientAdd: " + err.Error())
@@ -289,6 +310,81 @@ func (serv *Server) upgrader() {
}
//
+// handleChopped handle possible chopped payload.
+//
+// It will return true if continuous frame exist and its length is greater
+// than payload.
+//
+// It will return false if no continuous frame exist.
+//
+func (serv *Server) handleChopped(x, conn int, packet []byte) bool {
+ frame, _ := serv.Clients.Frame(conn)
+
+ if frame == nil {
+ return false
+ }
+ if frame.len == uint64(len(frame.payload)) {
+ // Connection contains continuous frame, but its already
+ // filled.
+ return false
+ }
+
+ start := len(frame.payload) % 4
+ for y := 0; y < len(packet); y++ {
+ packet[y] ^= frame.maskKey[start%4]
+ start++
+ }
+
+ frame.payload = append(frame.payload, packet...)
+ if uint64(len(frame.payload)) < frame.len {
+ // We still got unfinished payload.
+ serv.Clients.SetFrame(conn, frame)
+ serv.epollRegisterRead(x, conn)
+ return true
+ }
+ if frame.fin == 0 {
+ serv.Clients.SetFrame(conn, frame)
+ serv.epollRegisterRead(x, conn)
+ return true
+ }
+
+ serv.Clients.SetFrame(conn, nil)
+
+ // Handle full frame.
+ var isClosing bool
+
+ switch frame.opcode {
+ case opcodeText:
+ serv.HandleText(conn, frame.payload)
+ case opcodeBin:
+ serv.HandleBin(conn, frame.payload)
+ case opcodeDataRsv3, opcodeDataRsv4, opcodeDataRsv5, opcodeDataRsv6, opcodeDataRsv7:
+ serv.handleBadRequest(conn)
+ isClosing = true
+ case opcodeClose:
+ serv.handleClose(conn, frame)
+ isClosing = true
+ case opcodePing:
+ serv.handlePing(conn, frame)
+ case opcodePong:
+ if serv.handlePong != nil {
+ serv.handlePong(conn, frame)
+ }
+ case opcodeControlRsvB, opcodeControlRsvC, opcodeControlRsvD, opcodeControlRsvE, opcodeControlRsvF:
+ if serv.HandleRsvControl != nil {
+ serv.HandleRsvControl(conn, frame)
+ } else {
+ serv.handleClose(conn, frame)
+ isClosing = true
+ }
+ }
+ if !isClosing {
+ serv.epollRegisterRead(x, conn)
+ }
+ return true
+}
+
+//
// handleFragment will handle continuation frame (fragmentation).
//
// (RFC 6455 Section 5.4 Page 34)
@@ -317,6 +413,11 @@ func (serv *Server) upgrader() {
func (serv *Server) handleFragment(conn int, req *Frame) {
frame, ok := serv.Clients.Frame(conn)
+ if debug.Value >= 3 {
+ log.Printf("websocket: Server.handleFragment: frame: {fin:%d opcode:%d len:%d, payload.len:%d}\n",
+ req.fin, req.opcode, req.len, len(req.payload))
+ }
+
if frame == nil {
frame = req
} else {
@@ -330,14 +431,21 @@ func (serv *Server) handleFragment(conn int, req *Frame) {
serv.Clients.SetFrame(conn, frame)
return
}
+
+ // Frame with fin set with chopped payload.
+ if uint64(len(frame.payload)) < frame.len {
+ serv.Clients.SetFrame(conn, frame)
+ return
+ }
+
if ok {
serv.Clients.SetFrame(conn, nil)
}
if frame.opcode == opcodeText {
- go serv.HandleText(conn, frame.payload)
+ serv.HandleText(conn, frame.payload)
} else {
- go serv.HandleBin(conn, frame.payload)
+ serv.HandleBin(conn, frame.payload)
}
}
@@ -412,14 +520,20 @@ func (serv *Server) handleBin(conn int, payload []byte) {
// handleClose request from client.
//
func (serv *Server) handleClose(conn int, req *Frame) {
- req.opcode = opcodeClose
- req.masked = 0
+ if req.closeCode == 0 {
+ req.closeCode = StatusNormal
+ }
- res := req.Pack(false)
+ packet := NewFrameClose(false, req.closeCode, req.payload)
- err := Send(conn, res)
+ if debug.Value >= 3 {
+ log.Printf("websocket: Server.handleClose: req: %+v\n", req)
+ log.Printf("websocket: Server.handleClose: packet: % x\n", packet)
+ }
+
+ err := Send(conn, packet)
if err != nil {
- log.Println("websocket: server.handleClose: " + err.Error())
+ log.Println("websocket: server.handleClose: Send: " + err.Error())
}
serv.ClientRemove(conn)
@@ -486,19 +600,18 @@ func (serv *Server) handlePing(conn int, req *Frame) {
//
func (serv *Server) reader() {
var (
- events [128]unix.EpollEvent
isClosing bool
)
for {
- nevents, err := unix.EpollWait(serv.epollRead, events[:], -1)
+ nevents, err := unix.EpollWait(serv.epollRead, serv.epollEvents[:], -1)
if err != nil {
log.Println("websocket: server.reader: unix.EpollWait: " + err.Error())
break
}
for x := 0; x < nevents; x++ {
- conn := int(events[x].Fd)
+ conn := int(serv.epollEvents[x].Fd)
packet, err := Recv(conn)
if err != nil || len(packet) == 0 {
@@ -511,6 +624,12 @@ func (serv *Server) reader() {
len(packet), packet)
}
+ // Handle chopped, unfinished payload.
+ isChopped := serv.handleChopped(x, conn, packet)
+ if isChopped {
+ continue
+ }
+
frames := Unpack(packet)
if frames == nil {
serv.ClientRemove(conn)
@@ -575,14 +694,7 @@ func (serv *Server) reader() {
}
if !isClosing {
- // See https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/
- events[x].Events = unix.EPOLLIN | unix.EPOLLONESHOT
-
- err = unix.EpollCtl(serv.epollRead, unix.EPOLL_CTL_MOD, conn, &events[x])
- if err != nil {
- log.Println("websocket: server.reader: unix.EpollCtl: " + err.Error())
- continue
- }
+ serv.epollRegisterRead(x, conn)
}
}
}