From c35fae8aee8bf26eadb1cfbf42bfc02cb47a686e Mon Sep 17 00:00:00 2001 From: Shulhan Date: Mon, 11 Mar 2019 05:03:44 +0700 Subject: websocket: handle receiving chopped frame on server Due to nature of TCP, frame with larger payload size may be get chopped in the middle by operating system during transmission. To handle this we check if connection have continuous frame with payload length is less that the length defined in the first frame. --- lib/websocket/server.go | 152 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 132 insertions(+), 20 deletions(-) (limited to 'lib/websocket') diff --git a/lib/websocket/server.go b/lib/websocket/server.go index 6698fc64..2443c792 100644 --- a/lib/websocket/server.go +++ b/lib/websocket/server.go @@ -39,8 +39,11 @@ type Server struct { port int sock int chUpgrade chan int - epollRead int - routes *rootRoute + + epollEvents [128]unix.EpollEvent + epollRead int + + routes *rootRoute // HandleAuth callback that will be called when receiving // client handshake. @@ -250,6 +253,20 @@ func (serv *Server) ClientRemove(conn int) { } } +// +// epollRegisterRead register the connection for read in epoll. +// +func (serv *Server) epollRegisterRead(idx, conn int) { + // See https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/ + serv.epollEvents[idx].Events = unix.EPOLLIN | unix.EPOLLONESHOT + + err := unix.EpollCtl(serv.epollRead, unix.EPOLL_CTL_MOD, conn, &serv.epollEvents[idx]) + if err != nil { + log.Println("websocket: server.reader: unix.EpollCtl: " + err.Error()) + serv.ClientRemove(conn) + } +} + func (serv *Server) upgrader() { for conn := range serv.chUpgrade { packet, err := Recv(conn) @@ -280,6 +297,10 @@ func (serv *Server) upgrader() { continue } + if ctx == nil { + ctx = context.Background() + } + err = serv.clientAdd(ctx, conn) if err != nil { log.Println("websocket: server.upgrader: clientAdd: " + err.Error()) @@ -288,6 +309,81 @@ func (serv *Server) upgrader() { } } +// +// handleChopped handle possible chopped payload. +// +// It will return true if continuous frame exist and its length is greater +// than payload. +// +// It will return false if no continuous frame exist. +// +func (serv *Server) handleChopped(x, conn int, packet []byte) bool { + frame, _ := serv.Clients.Frame(conn) + + if frame == nil { + return false + } + if frame.len == uint64(len(frame.payload)) { + // Connection contains continuous frame, but its already + // filled. + return false + } + + start := len(frame.payload) % 4 + for y := 0; y < len(packet); y++ { + packet[y] ^= frame.maskKey[start%4] + start++ + } + + frame.payload = append(frame.payload, packet...) + if uint64(len(frame.payload)) < frame.len { + // We still got unfinished payload. + serv.Clients.SetFrame(conn, frame) + serv.epollRegisterRead(x, conn) + return true + } + if frame.fin == 0 { + serv.Clients.SetFrame(conn, frame) + serv.epollRegisterRead(x, conn) + return true + } + + serv.Clients.SetFrame(conn, nil) + + // Handle full frame. + var isClosing bool + + switch frame.opcode { + case opcodeText: + serv.HandleText(conn, frame.payload) + case opcodeBin: + serv.HandleBin(conn, frame.payload) + case opcodeDataRsv3, opcodeDataRsv4, opcodeDataRsv5, opcodeDataRsv6, opcodeDataRsv7: + serv.handleBadRequest(conn) + isClosing = true + case opcodeClose: + serv.handleClose(conn, frame) + isClosing = true + case opcodePing: + serv.handlePing(conn, frame) + case opcodePong: + if serv.handlePong != nil { + serv.handlePong(conn, frame) + } + case opcodeControlRsvB, opcodeControlRsvC, opcodeControlRsvD, opcodeControlRsvE, opcodeControlRsvF: + if serv.HandleRsvControl != nil { + serv.HandleRsvControl(conn, frame) + } else { + serv.handleClose(conn, frame) + isClosing = true + } + } + if !isClosing { + serv.epollRegisterRead(x, conn) + } + return true +} + // // handleFragment will handle continuation frame (fragmentation). // @@ -317,6 +413,11 @@ func (serv *Server) upgrader() { func (serv *Server) handleFragment(conn int, req *Frame) { frame, ok := serv.Clients.Frame(conn) + if debug.Value >= 3 { + log.Printf("websocket: Server.handleFragment: frame: {fin:%d opcode:%d len:%d, payload.len:%d}\n", + req.fin, req.opcode, req.len, len(req.payload)) + } + if frame == nil { frame = req } else { @@ -330,14 +431,21 @@ func (serv *Server) handleFragment(conn int, req *Frame) { serv.Clients.SetFrame(conn, frame) return } + + // Frame with fin set with chopped payload. + if uint64(len(frame.payload)) < frame.len { + serv.Clients.SetFrame(conn, frame) + return + } + if ok { serv.Clients.SetFrame(conn, nil) } if frame.opcode == opcodeText { - go serv.HandleText(conn, frame.payload) + serv.HandleText(conn, frame.payload) } else { - go serv.HandleBin(conn, frame.payload) + serv.HandleBin(conn, frame.payload) } } @@ -412,14 +520,20 @@ func (serv *Server) handleBin(conn int, payload []byte) { // handleClose request from client. // func (serv *Server) handleClose(conn int, req *Frame) { - req.opcode = opcodeClose - req.masked = 0 + if req.closeCode == 0 { + req.closeCode = StatusNormal + } - res := req.Pack(false) + packet := NewFrameClose(false, req.closeCode, req.payload) - err := Send(conn, res) + if debug.Value >= 3 { + log.Printf("websocket: Server.handleClose: req: %+v\n", req) + log.Printf("websocket: Server.handleClose: packet: % x\n", packet) + } + + err := Send(conn, packet) if err != nil { - log.Println("websocket: server.handleClose: " + err.Error()) + log.Println("websocket: server.handleClose: Send: " + err.Error()) } serv.ClientRemove(conn) @@ -486,19 +600,18 @@ func (serv *Server) handlePing(conn int, req *Frame) { // func (serv *Server) reader() { var ( - events [128]unix.EpollEvent isClosing bool ) for { - nevents, err := unix.EpollWait(serv.epollRead, events[:], -1) + nevents, err := unix.EpollWait(serv.epollRead, serv.epollEvents[:], -1) if err != nil { log.Println("websocket: server.reader: unix.EpollWait: " + err.Error()) break } for x := 0; x < nevents; x++ { - conn := int(events[x].Fd) + conn := int(serv.epollEvents[x].Fd) packet, err := Recv(conn) if err != nil || len(packet) == 0 { @@ -511,6 +624,12 @@ func (serv *Server) reader() { len(packet), packet) } + // Handle chopped, unfinished payload. + isChopped := serv.handleChopped(x, conn, packet) + if isChopped { + continue + } + frames := Unpack(packet) if frames == nil { serv.ClientRemove(conn) @@ -575,14 +694,7 @@ func (serv *Server) reader() { } if !isClosing { - // See https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/ - events[x].Events = unix.EPOLLIN | unix.EPOLLONESHOT - - err = unix.EpollCtl(serv.epollRead, unix.EPOLL_CTL_MOD, conn, &events[x]) - if err != nil { - log.Println("websocket: server.reader: unix.EpollCtl: " + err.Error()) - continue - } + serv.epollRegisterRead(x, conn) } } } -- cgit v1.3-5-g9baa