diff options
Diffstat (limited to 'lib/websocket/server.go')
| -rw-r--r-- | lib/websocket/server.go | 152 |
1 files changed, 132 insertions, 20 deletions
diff --git a/lib/websocket/server.go b/lib/websocket/server.go index 6698fc64..2443c792 100644 --- a/lib/websocket/server.go +++ b/lib/websocket/server.go @@ -39,8 +39,11 @@ type Server struct { port int sock int chUpgrade chan int - epollRead int - routes *rootRoute + + epollEvents [128]unix.EpollEvent + epollRead int + + routes *rootRoute // HandleAuth callback that will be called when receiving // client handshake. @@ -250,6 +253,20 @@ func (serv *Server) ClientRemove(conn int) { } } +// +// epollRegisterRead register the connection for read in epoll. +// +func (serv *Server) epollRegisterRead(idx, conn int) { + // See https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/ + serv.epollEvents[idx].Events = unix.EPOLLIN | unix.EPOLLONESHOT + + err := unix.EpollCtl(serv.epollRead, unix.EPOLL_CTL_MOD, conn, &serv.epollEvents[idx]) + if err != nil { + log.Println("websocket: server.reader: unix.EpollCtl: " + err.Error()) + serv.ClientRemove(conn) + } +} + func (serv *Server) upgrader() { for conn := range serv.chUpgrade { packet, err := Recv(conn) @@ -280,6 +297,10 @@ func (serv *Server) upgrader() { continue } + if ctx == nil { + ctx = context.Background() + } + err = serv.clientAdd(ctx, conn) if err != nil { log.Println("websocket: server.upgrader: clientAdd: " + err.Error()) @@ -289,6 +310,81 @@ func (serv *Server) upgrader() { } // +// handleChopped handle possible chopped payload. +// +// It will return true if continuous frame exist and its length is greater +// than payload. +// +// It will return false if no continuous frame exist. +// +func (serv *Server) handleChopped(x, conn int, packet []byte) bool { + frame, _ := serv.Clients.Frame(conn) + + if frame == nil { + return false + } + if frame.len == uint64(len(frame.payload)) { + // Connection contains continuous frame, but its already + // filled. + return false + } + + start := len(frame.payload) % 4 + for y := 0; y < len(packet); y++ { + packet[y] ^= frame.maskKey[start%4] + start++ + } + + frame.payload = append(frame.payload, packet...) + if uint64(len(frame.payload)) < frame.len { + // We still got unfinished payload. + serv.Clients.SetFrame(conn, frame) + serv.epollRegisterRead(x, conn) + return true + } + if frame.fin == 0 { + serv.Clients.SetFrame(conn, frame) + serv.epollRegisterRead(x, conn) + return true + } + + serv.Clients.SetFrame(conn, nil) + + // Handle full frame. + var isClosing bool + + switch frame.opcode { + case opcodeText: + serv.HandleText(conn, frame.payload) + case opcodeBin: + serv.HandleBin(conn, frame.payload) + case opcodeDataRsv3, opcodeDataRsv4, opcodeDataRsv5, opcodeDataRsv6, opcodeDataRsv7: + serv.handleBadRequest(conn) + isClosing = true + case opcodeClose: + serv.handleClose(conn, frame) + isClosing = true + case opcodePing: + serv.handlePing(conn, frame) + case opcodePong: + if serv.handlePong != nil { + serv.handlePong(conn, frame) + } + case opcodeControlRsvB, opcodeControlRsvC, opcodeControlRsvD, opcodeControlRsvE, opcodeControlRsvF: + if serv.HandleRsvControl != nil { + serv.HandleRsvControl(conn, frame) + } else { + serv.handleClose(conn, frame) + isClosing = true + } + } + if !isClosing { + serv.epollRegisterRead(x, conn) + } + return true +} + +// // handleFragment will handle continuation frame (fragmentation). // // (RFC 6455 Section 5.4 Page 34) @@ -317,6 +413,11 @@ func (serv *Server) upgrader() { func (serv *Server) handleFragment(conn int, req *Frame) { frame, ok := serv.Clients.Frame(conn) + if debug.Value >= 3 { + log.Printf("websocket: Server.handleFragment: frame: {fin:%d opcode:%d len:%d, payload.len:%d}\n", + req.fin, req.opcode, req.len, len(req.payload)) + } + if frame == nil { frame = req } else { @@ -330,14 +431,21 @@ func (serv *Server) handleFragment(conn int, req *Frame) { serv.Clients.SetFrame(conn, frame) return } + + // Frame with fin set with chopped payload. + if uint64(len(frame.payload)) < frame.len { + serv.Clients.SetFrame(conn, frame) + return + } + if ok { serv.Clients.SetFrame(conn, nil) } if frame.opcode == opcodeText { - go serv.HandleText(conn, frame.payload) + serv.HandleText(conn, frame.payload) } else { - go serv.HandleBin(conn, frame.payload) + serv.HandleBin(conn, frame.payload) } } @@ -412,14 +520,20 @@ func (serv *Server) handleBin(conn int, payload []byte) { // handleClose request from client. // func (serv *Server) handleClose(conn int, req *Frame) { - req.opcode = opcodeClose - req.masked = 0 + if req.closeCode == 0 { + req.closeCode = StatusNormal + } - res := req.Pack(false) + packet := NewFrameClose(false, req.closeCode, req.payload) - err := Send(conn, res) + if debug.Value >= 3 { + log.Printf("websocket: Server.handleClose: req: %+v\n", req) + log.Printf("websocket: Server.handleClose: packet: % x\n", packet) + } + + err := Send(conn, packet) if err != nil { - log.Println("websocket: server.handleClose: " + err.Error()) + log.Println("websocket: server.handleClose: Send: " + err.Error()) } serv.ClientRemove(conn) @@ -486,19 +600,18 @@ func (serv *Server) handlePing(conn int, req *Frame) { // func (serv *Server) reader() { var ( - events [128]unix.EpollEvent isClosing bool ) for { - nevents, err := unix.EpollWait(serv.epollRead, events[:], -1) + nevents, err := unix.EpollWait(serv.epollRead, serv.epollEvents[:], -1) if err != nil { log.Println("websocket: server.reader: unix.EpollWait: " + err.Error()) break } for x := 0; x < nevents; x++ { - conn := int(events[x].Fd) + conn := int(serv.epollEvents[x].Fd) packet, err := Recv(conn) if err != nil || len(packet) == 0 { @@ -511,6 +624,12 @@ func (serv *Server) reader() { len(packet), packet) } + // Handle chopped, unfinished payload. + isChopped := serv.handleChopped(x, conn, packet) + if isChopped { + continue + } + frames := Unpack(packet) if frames == nil { serv.ClientRemove(conn) @@ -575,14 +694,7 @@ func (serv *Server) reader() { } if !isClosing { - // See https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/ - events[x].Events = unix.EPOLLIN | unix.EPOLLONESHOT - - err = unix.EpollCtl(serv.epollRead, unix.EPOLL_CTL_MOD, conn, &events[x]) - if err != nil { - log.Println("websocket: server.reader: unix.EpollCtl: " + err.Error()) - continue - } + serv.epollRegisterRead(x, conn) } } } |
