diff options
| author | Shulhan <ms@kilabit.info> | 2019-03-11 23:30:22 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-03-11 23:30:22 +0700 |
| commit | 929eef95cd27c798d8c128ae0eea83ec34220a3f (patch) | |
| tree | b9db116953d711b5ef24af2d8a7065c7d35dc0e4 | |
| parent | 1bedaa4d003eb7595bc4bd9edd74423a1f004e25 (diff) | |
| download | pakakeh.go-929eef95cd27c798d8c128ae0eea83ec34220a3f.tar.xz | |
websocket: handle chopped frame with interjected control frame
Another possibility is fragmented frames with one of the frame is
chopped and in the middle of it is a control FRAME. For example,
C> Frame{fin:false opcode:TEXT len:1024 payload:512}
C> payload:512 FRAME{fin:TRUE opcode:PING len:0}
C> Frame{fin:true opcode:0 len:1024 payload:1024}
| -rw-r--r-- | lib/websocket/clientmanager.go | 97 | ||||
| -rw-r--r-- | lib/websocket/frame.go | 155 | ||||
| -rw-r--r-- | lib/websocket/frames.go | 10 | ||||
| -rw-r--r-- | lib/websocket/server.go | 105 | ||||
| -rw-r--r-- | lib/websocket/websocket_test.go | 2 |
5 files changed, 314 insertions, 55 deletions
diff --git a/lib/websocket/clientmanager.go b/lib/websocket/clientmanager.go index bfd6a5f7..7e3f9961 100644 --- a/lib/websocket/clientmanager.go +++ b/lib/websocket/clientmanager.go @@ -35,9 +35,13 @@ type ClientManager struct { // HandleAuth on Server. ctx map[int]context.Context - // frame contains a one-to-one mapping between a socket connection - // and a continuous frame. + // frame contains a one-to-one mapping between a socket and a frame. + // Its usually used to handle chopped frame. frame map[int]*Frame + + // frames contains a one-to-one mapping between a socket and + // continuous frame. + frames map[int]*Frames } // @@ -45,13 +49,32 @@ type ClientManager struct { // func newClientManager() *ClientManager { return &ClientManager{ - conns: make(map[uint64][]int), - ctx: make(map[int]context.Context), - frame: make(map[int]*Frame), + conns: make(map[uint64][]int), + ctx: make(map[int]context.Context), + frame: make(map[int]*Frame), + frames: make(map[int]*Frames), } } // +// AddFrame add a frame as part of continuous frame on a client connection. +// +func (cls *ClientManager) AddFrame(conn int, frame *Frame) { + cls.Lock() + frames, ok := cls.frames[conn] + if !ok { + frames = new(Frames) + } + frames.Append(frame) + if !ok { + cls.frames[conn] = frames + } + delete(cls.frame, conn) + cls.Unlock() + return +} + +// // All return a copy of all client connections. // func (cls *ClientManager) All() (conns []int) { @@ -65,6 +88,39 @@ func (cls *ClientManager) All() (conns []int) { } // +// finFrames merge all continuous frames into single frame and clear the +// stored frame and frames on behalf of connection. +// +func (cls *ClientManager) finFrames(conn int, fin *Frame) (f *Frame) { + cls.Lock() + frames, ok := cls.frames[conn] + if !ok { + cls.Unlock() + return fin + } + + f = frames.v[0] + for x := 1; x < len(frames.v); x++ { + if frames.v[x].opcode == opcodeClose { + break + } + + // Ignore control PING or PONG frame. + if frames.v[x].opcode == opcodePing || frames.v[x].opcode == opcodePong { + continue + } + + f.payload = append(f.payload, frames.v[x].payload...) + } + f.payload = append(f.payload, fin.payload...) + delete(cls.frames, conn) + delete(cls.frame, conn) + + cls.Unlock() + return +} + +// // Conns return list of connections by user ID. // // Each user may have more than one connection (e.g. from Android, iOS, or @@ -89,7 +145,7 @@ func (cls *ClientManager) Context(conn int) (ctx context.Context) { } // -// Frame return continuous frame on a client connection. +// Frame return an active frame on a client connection. // func (cls *ClientManager) Frame(conn int) (frame *Frame, ok bool) { cls.Lock() @@ -99,8 +155,18 @@ func (cls *ClientManager) Frame(conn int) (frame *Frame, ok bool) { } // -// SetFrame set the continuous frame on client connection. If frame is nil, -// it will delete the stored frame in connection. +// Frames return continuous frames on behalf of connection. +// +func (cls *ClientManager) Frames(conn int) (frames *Frames, ok bool) { + cls.Lock() + frames, ok = cls.frames[conn] + cls.Unlock() + return +} + +// +// SetFrame set the active, chopped frame on client connection. If frame is +// nil, it will delete the stored frame in connection. // func (cls *ClientManager) SetFrame(conn int, frame *Frame) { cls.Lock() @@ -113,6 +179,20 @@ func (cls *ClientManager) SetFrame(conn int, frame *Frame) { } // +// SetFrames set continuous frames on client connection. If frames is nil it +// will clear the stored frames. +// +func (cls *ClientManager) SetFrames(conn int, frames *Frames) { + cls.Lock() + if frames == nil { + delete(cls.frames, conn) + } else { + cls.frames[conn] = frames + } + cls.Unlock() +} + +// // add new socket connection to user ID with its context. // func (cls *ClientManager) add(ctx context.Context, conn int) { @@ -153,6 +233,7 @@ func (cls *ClientManager) remove(conn int) { cls.Lock() delete(cls.frame, conn) + delete(cls.frames, conn) cls.all, _ = ints.Remove(cls.all, conn) ctx, ok := cls.ctx[conn] diff --git a/lib/websocket/frame.go b/lib/websocket/frame.go index f044681d..f3c06621 100644 --- a/lib/websocket/frame.go +++ b/lib/websocket/frame.go @@ -39,6 +39,7 @@ type Frame struct { // closeCode represent the status of control frame close request. closeCode CloseCode + codes []byte // // len represent Payload length: 7 bits, 7+16 bits, or 7+64 bits @@ -68,7 +69,7 @@ type Frame struct { // is set to 0. See Section 5.3 for further information on client- // to-server masking. // - maskKey [4]byte + maskKey []byte // // Payload data: (x+y) bytes @@ -93,6 +94,12 @@ type Frame struct { // data". // payload []byte + + // + // chopped contains the unfinished frame, excluding mask keys and + // payload. + // + chopped []byte } // @@ -220,32 +227,33 @@ func frameUnpack(in []byte) (f *Frame, rest []byte) { f.opcode = opcode(in[x] & 0x0F) x++ if x >= len(in) { - return f, in[x:] + f.chopped = append(f.chopped, in...) + return f, nil } f.masked = in[x] & frameIsMasked f.len = uint64(in[x] & 0x7F) x++ if x >= len(in) { - // Unmasked frame, with 0 payload. - if f.masked == 0 { - return f, in[x:] + if f.len > 0 { + f.chopped = append(f.chopped, in...) } - // Missing mask. - return nil, nil + return f, nil } switch f.len { case frameLargePayload: if x+8 >= len(in) { - return nil, nil + f.chopped = append(f.chopped, in...) + return f, nil } f.len = binary.BigEndian.Uint64(in[x : x+8]) x += 8 case frameMediumPayload: if x+2 >= len(in) { - return nil, nil + f.chopped = append(f.chopped, in...) + return f, nil } f.len = uint64(binary.BigEndian.Uint16(in[x : x+2])) @@ -254,28 +262,32 @@ func frameUnpack(in []byte) (f *Frame, rest []byte) { if f.masked == frameIsMasked { if x >= len(in) { - return nil, nil + f.chopped = append(f.chopped, in...) + return f, nil } - f.maskKey[0] = in[x] + f.maskKey = append(f.maskKey, in[x]) x++ if x >= len(in) { - return nil, nil + f.chopped = append(f.chopped, in...) + return f, nil } - f.maskKey[1] = in[x] + f.maskKey = append(f.maskKey, in[x]) x++ if x >= len(in) { - return nil, nil + f.chopped = append(f.chopped, in...) + return f, nil } - f.maskKey[2] = in[x] + f.maskKey = append(f.maskKey, in[x]) x++ if x >= len(in) { - return nil, nil + f.chopped = append(f.chopped, in...) + return f, nil } - f.maskKey[3] = in[x] + f.maskKey = append(f.maskKey, in[x]) x++ } @@ -298,10 +310,13 @@ func frameUnpack(in []byte) (f *Frame, rest []byte) { if f.opcode == opcodeClose { switch len(f.payload) { case 0: + f.codes = []byte{0, 0} f.closeCode = StatusNormal case 1: + f.codes = []byte{f.payload[0], 0} f.closeCode = StatusBadRequest default: + f.codes = []byte{f.payload[0], f.payload[1]} f.closeCode = CloseCode(binary.BigEndian.Uint16(f.payload[:2])) f.payload = f.payload[2:] } @@ -379,7 +394,8 @@ func (f *Frame) Pack(randomMask bool) (out []byte) { if _rng == nil { _rng = rand.New(rand.NewSource(time.Now().UnixNano())) } - binary.LittleEndian.PutUint32(f.maskKey[0:], _rng.Uint32()) + f.maskKey = make([]byte, 4) + binary.LittleEndian.PutUint32(f.maskKey, _rng.Uint32()) } out[x] = f.maskKey[0] @@ -408,3 +424,106 @@ func (f *Frame) Pack(randomMask bool) (out []byte) { func (f *Frame) Payload() []byte { return f.payload } + +// +// continueUnpack unpack frame header (fin, opcode, masked, length, and mask +// keys) based on chopped length. +// +func (f *Frame) continueUnpack(packet []byte) []byte { + var isHaveLen bool + + for len(packet) > 0 && !isHaveLen { + switch len(f.chopped) { + case 0: + f.fin = packet[0] & frameIsFinished + f.rsv1 = packet[0] & 0x40 + f.rsv2 = packet[0] & 0x20 + f.rsv3 = packet[0] & 0x10 + f.opcode = opcode(packet[0] & 0x0F) + f.chopped = append(f.chopped, packet[0]) + packet = packet[1:] + case 1: + f.masked = packet[0] & frameIsMasked + f.len = uint64(packet[0] & 0x7F) + f.chopped = append(f.chopped, packet[0]) + packet = packet[1:] + default: + // We got the masked and len, lets check and get the + // extended length. + switch f.len { + case frameLargePayload: + if len(f.chopped) < 10 { + exp := 10 - len(f.chopped) + if len(packet) < exp { + f.chopped = append(f.chopped, packet...) + return nil + } + // chopped: 81 FF 0 0 0 1 0 0 = 10 - 8) = 2 + // exp: 0 0 + f.chopped = append(f.chopped, packet[:exp]...) + f.len = binary.BigEndian.Uint64(f.chopped[2:10]) + packet = packet[exp:] + } + case frameMediumPayload: + if len(f.chopped) < 4 { + exp := 4 - len(f.chopped) + if len(packet) < exp { + f.chopped = append(f.chopped, packet...) + return nil + } + f.chopped = append(f.chopped, packet[:exp]...) + f.len = uint64(binary.BigEndian.Uint16(f.chopped[2:4])) + packet = packet[exp:] + } + } + isHaveLen = true + } + } + if len(packet) == 0 { + return nil + } + if f.masked == frameIsMasked && len(f.maskKey) != 4 { + exp := 4 - len(f.maskKey) + if len(packet) < exp { + f.maskKey = append(f.maskKey, packet...) + return nil + } + + f.maskKey = append(f.maskKey, packet[:exp]...) + + packet = packet[exp:] + } + if f.opcode == opcodeClose && len(f.codes) != 2 { + exp := 2 - len(f.codes) + if len(packet) < exp { + f.codes = append(f.codes, packet...) + return nil + } + f.codes = append(f.codes, packet[:exp]...) + f.closeCode = CloseCode(binary.BigEndian.Uint16(f.codes)) + packet = packet[exp:] + } + if f.len > 0 && cap(f.payload) == 0 { + f.payload = make([]byte, 0, f.len) + + if len(packet) > 0 { + paclen := len(packet) + if uint64(paclen) > f.len { + paclen = int(f.len) + } + + f.payload = append(f.payload, packet[:paclen]...) + + if f.masked == frameIsMasked { + for x := 0; x < paclen; x++ { + f.payload[x] ^= f.maskKey[x%4] + } + } + + packet = packet[paclen:] + } + } + f.chopped = nil + + return packet +} diff --git a/lib/websocket/frames.go b/lib/websocket/frames.go index 66274ca3..1fa40df2 100644 --- a/lib/websocket/frames.go +++ b/lib/websocket/frames.go @@ -95,6 +95,16 @@ func (frames *Frames) Len() int { } // +// Opcode return the operation code of the first frame. +// +func (frames *Frames) Opcode() opcode { + if len(frames.v) == 0 { + return opcodeCont + } + return frames.v[0].opcode +} + +// // Payload return the concatenation of continuous data frame's payload. // // The first frame must be a data frame, either text or binary, otherwise it diff --git a/lib/websocket/server.go b/lib/websocket/server.go index cadbad05..ea8108da 100644 --- a/lib/websocket/server.go +++ b/lib/websocket/server.go @@ -311,7 +311,7 @@ func (serv *Server) upgrader() { } // -// handleChopped handle possible chopped payload. +// handleChopped handle possible chopped packet or payload. // // There are three possible cases that will returned from this function. // First, there is no continuous frame. Packet is new frame, return it as @@ -332,10 +332,22 @@ func (serv *Server) upgrader() { // func (serv *Server) handleChopped(x, conn int, packet []byte) (rest []byte, isClosing bool) { frame, _ := serv.Clients.Frame(conn) + frames, _ := serv.Clients.Frames(conn) if frame == nil { return packet, false } + + // Check if frame contains chopped packet. + if len(frame.chopped) > 0 { + packet = frame.continueUnpack(packet) + if len(packet) == 0 { + serv.Clients.SetFrame(conn, frame) + serv.epollRegisterRead(x, conn) + return nil, false + } + } + if frame.len == uint64(len(frame.payload)) { // Connection contains continuous frame, but its already // filled. @@ -356,18 +368,62 @@ func (serv *Server) handleChopped(x, conn int, packet []byte) (rest []byte, isCl frame.payload = append(frame.payload, packet...) if uint64(len(frame.payload)) < frame.len { - // We still got unfinished payload. + // We got frame with unfinished payload. serv.Clients.SetFrame(conn, frame) serv.epollRegisterRead(x, conn) return rest, false } if frame.fin == 0 { - serv.Clients.SetFrame(conn, frame) + if frames == nil { + if frame.opcode == 0 { + serv.handleBadRequest(conn) + isClosing = true + return + } + } + // We got frame with finished payload but with unfinished + // frames. + serv.Clients.AddFrame(conn, frame) serv.epollRegisterRead(x, conn) return rest, false } - serv.Clients.SetFrame(conn, nil) + // FIN:0x80 + + if frame.opcode == 0 { + if frames == nil { + // Got frame with opcode CONT in non fragmented + // frames. + serv.handleBadRequest(conn) + isClosing = true + return + } + + frame = serv.Clients.finFrames(conn, frame) + } else { + serv.Clients.SetFrame(conn, nil) + } + + if frame.masked != frameIsMasked { + serv.handleBadRequest(conn) + isClosing = true + return + } + if frame.rsv1 > 0 && !serv.allowRsv1 { + serv.handleBadRequest(conn) + isClosing = true + return + } + if frame.rsv2 > 0 && !serv.allowRsv2 { + serv.handleBadRequest(conn) + isClosing = true + return + } + if frame.rsv3 > 0 && !serv.allowRsv3 { + serv.handleBadRequest(conn) + isClosing = true + return + } switch frame.opcode { case opcodeText: @@ -432,53 +488,41 @@ func (serv *Server) handleChopped(x, conn int, packet []byte) (rest []byte, isCl // (RFC 6455 Section 5.4 Page 34) // func (serv *Server) handleFragment(conn int, req *Frame) (isInvalid bool) { - frame, ok := serv.Clients.Frame(conn) + frames, ok := serv.Clients.Frames(conn) if debug.Value >= 3 { - log.Printf("websocket: Server.handleFragment: frame: {fin:%d opcode:%d len:%d, payload.len:%d}\n", - req.fin, req.opcode, req.len, len(req.payload)) + log.Printf("websocket: Server.handleFragment: frame: %+v\n", req) } - if frame == nil { + if frames == nil { // If a connection does not have continuous frame, then // current frame opcode must not be 0. if req.opcode == opcodeCont { return true } - - frame = req + frames = new(Frames) } else { // If a connection have continuous frame, the next frame // opcode must be 0. if req.opcode != opcodeCont { return true } - - frame.payload = append(frame.payload, req.payload...) - if req.len > 0 { - frame.len += req.len - } } if req.fin == 0 { - serv.Clients.SetFrame(conn, frame) + frames.Append(req) + serv.Clients.SetFrame(conn, req) + serv.Clients.SetFrames(conn, frames) return false } - // Frame with fin set with chopped payload. - if debug.Value >= 3 { - log.Printf("handleFragment: payload.len:%d frame.len:%d\n", - len(frame.payload), frame.len) - } - - if uint64(len(frame.payload)) < frame.len { - serv.Clients.SetFrame(conn, frame) + if !ok && uint64(len(req.payload)) < req.len { + // Finished frame with unfinished payload. + serv.Clients.SetFrame(conn, req) return false } - if ok { - serv.Clients.SetFrame(conn, nil) - } + frame := serv.Clients.finFrames(conn, req) if frame.opcode == opcodeText { if !utf8.Valid(frame.payload) { @@ -724,7 +768,7 @@ func (serv *Server) reader() { len(packet), packet) } - // Handle chopped, unfinished payload. + // Handle chopped, unfinished packet or payload. packet, isClosing = serv.handleChopped(x, conn, packet) if isClosing || len(packet) == 0 { continue @@ -744,6 +788,11 @@ func (serv *Server) reader() { isClosing = false for _, frame := range frames.v { + if len(frame.chopped) > 0 { + serv.Clients.SetFrame(conn, frame) + continue + } + if frame.masked != frameIsMasked { serv.handleBadRequest(conn) isClosing = true diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go index f6a66c70..86815518 100644 --- a/lib/websocket/websocket_test.go +++ b/lib/websocket/websocket_test.go @@ -25,7 +25,7 @@ var ( _testWSAddr string //nolint: gochecknoglobals _testHdrValWSAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" //nolint: gochecknoglobals _testHdrValWSKey = "dGhlIHNhbXBsZSBub25jZQ==" //nolint: gochecknoglobals - _testMaskKey = [4]byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals + _testMaskKey = []byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals ) var ( |
