aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-03-11 23:30:22 +0700
committerShulhan <ms@kilabit.info>2019-03-11 23:30:22 +0700
commit929eef95cd27c798d8c128ae0eea83ec34220a3f (patch)
treeb9db116953d711b5ef24af2d8a7065c7d35dc0e4
parent1bedaa4d003eb7595bc4bd9edd74423a1f004e25 (diff)
downloadpakakeh.go-929eef95cd27c798d8c128ae0eea83ec34220a3f.tar.xz
websocket: handle chopped frame with interjected control frame
Another possibility is fragmented frames with one of the frame is chopped and in the middle of it is a control FRAME. For example, C> Frame{fin:false opcode:TEXT len:1024 payload:512} C> payload:512 FRAME{fin:TRUE opcode:PING len:0} C> Frame{fin:true opcode:0 len:1024 payload:1024}
-rw-r--r--lib/websocket/clientmanager.go97
-rw-r--r--lib/websocket/frame.go155
-rw-r--r--lib/websocket/frames.go10
-rw-r--r--lib/websocket/server.go105
-rw-r--r--lib/websocket/websocket_test.go2
5 files changed, 314 insertions, 55 deletions
diff --git a/lib/websocket/clientmanager.go b/lib/websocket/clientmanager.go
index bfd6a5f7..7e3f9961 100644
--- a/lib/websocket/clientmanager.go
+++ b/lib/websocket/clientmanager.go
@@ -35,9 +35,13 @@ type ClientManager struct {
// HandleAuth on Server.
ctx map[int]context.Context
- // frame contains a one-to-one mapping between a socket connection
- // and a continuous frame.
+ // frame contains a one-to-one mapping between a socket and a frame.
+ // Its usually used to handle chopped frame.
frame map[int]*Frame
+
+ // frames contains a one-to-one mapping between a socket and
+ // continuous frame.
+ frames map[int]*Frames
}
//
@@ -45,13 +49,32 @@ type ClientManager struct {
//
func newClientManager() *ClientManager {
return &ClientManager{
- conns: make(map[uint64][]int),
- ctx: make(map[int]context.Context),
- frame: make(map[int]*Frame),
+ conns: make(map[uint64][]int),
+ ctx: make(map[int]context.Context),
+ frame: make(map[int]*Frame),
+ frames: make(map[int]*Frames),
}
}
//
+// AddFrame add a frame as part of continuous frame on a client connection.
+//
+func (cls *ClientManager) AddFrame(conn int, frame *Frame) {
+ cls.Lock()
+ frames, ok := cls.frames[conn]
+ if !ok {
+ frames = new(Frames)
+ }
+ frames.Append(frame)
+ if !ok {
+ cls.frames[conn] = frames
+ }
+ delete(cls.frame, conn)
+ cls.Unlock()
+ return
+}
+
+//
// All return a copy of all client connections.
//
func (cls *ClientManager) All() (conns []int) {
@@ -65,6 +88,39 @@ func (cls *ClientManager) All() (conns []int) {
}
//
+// finFrames merge all continuous frames into single frame and clear the
+// stored frame and frames on behalf of connection.
+//
+func (cls *ClientManager) finFrames(conn int, fin *Frame) (f *Frame) {
+ cls.Lock()
+ frames, ok := cls.frames[conn]
+ if !ok {
+ cls.Unlock()
+ return fin
+ }
+
+ f = frames.v[0]
+ for x := 1; x < len(frames.v); x++ {
+ if frames.v[x].opcode == opcodeClose {
+ break
+ }
+
+ // Ignore control PING or PONG frame.
+ if frames.v[x].opcode == opcodePing || frames.v[x].opcode == opcodePong {
+ continue
+ }
+
+ f.payload = append(f.payload, frames.v[x].payload...)
+ }
+ f.payload = append(f.payload, fin.payload...)
+ delete(cls.frames, conn)
+ delete(cls.frame, conn)
+
+ cls.Unlock()
+ return
+}
+
+//
// Conns return list of connections by user ID.
//
// Each user may have more than one connection (e.g. from Android, iOS, or
@@ -89,7 +145,7 @@ func (cls *ClientManager) Context(conn int) (ctx context.Context) {
}
//
-// Frame return continuous frame on a client connection.
+// Frame return an active frame on a client connection.
//
func (cls *ClientManager) Frame(conn int) (frame *Frame, ok bool) {
cls.Lock()
@@ -99,8 +155,18 @@ func (cls *ClientManager) Frame(conn int) (frame *Frame, ok bool) {
}
//
-// SetFrame set the continuous frame on client connection. If frame is nil,
-// it will delete the stored frame in connection.
+// Frames return continuous frames on behalf of connection.
+//
+func (cls *ClientManager) Frames(conn int) (frames *Frames, ok bool) {
+ cls.Lock()
+ frames, ok = cls.frames[conn]
+ cls.Unlock()
+ return
+}
+
+//
+// SetFrame set the active, chopped frame on client connection. If frame is
+// nil, it will delete the stored frame in connection.
//
func (cls *ClientManager) SetFrame(conn int, frame *Frame) {
cls.Lock()
@@ -113,6 +179,20 @@ func (cls *ClientManager) SetFrame(conn int, frame *Frame) {
}
//
+// SetFrames set continuous frames on client connection. If frames is nil it
+// will clear the stored frames.
+//
+func (cls *ClientManager) SetFrames(conn int, frames *Frames) {
+ cls.Lock()
+ if frames == nil {
+ delete(cls.frames, conn)
+ } else {
+ cls.frames[conn] = frames
+ }
+ cls.Unlock()
+}
+
+//
// add new socket connection to user ID with its context.
//
func (cls *ClientManager) add(ctx context.Context, conn int) {
@@ -153,6 +233,7 @@ func (cls *ClientManager) remove(conn int) {
cls.Lock()
delete(cls.frame, conn)
+ delete(cls.frames, conn)
cls.all, _ = ints.Remove(cls.all, conn)
ctx, ok := cls.ctx[conn]
diff --git a/lib/websocket/frame.go b/lib/websocket/frame.go
index f044681d..f3c06621 100644
--- a/lib/websocket/frame.go
+++ b/lib/websocket/frame.go
@@ -39,6 +39,7 @@ type Frame struct {
// closeCode represent the status of control frame close request.
closeCode CloseCode
+ codes []byte
//
// len represent Payload length: 7 bits, 7+16 bits, or 7+64 bits
@@ -68,7 +69,7 @@ type Frame struct {
// is set to 0. See Section 5.3 for further information on client-
// to-server masking.
//
- maskKey [4]byte
+ maskKey []byte
//
// Payload data: (x+y) bytes
@@ -93,6 +94,12 @@ type Frame struct {
// data".
//
payload []byte
+
+ //
+ // chopped contains the unfinished frame, excluding mask keys and
+ // payload.
+ //
+ chopped []byte
}
//
@@ -220,32 +227,33 @@ func frameUnpack(in []byte) (f *Frame, rest []byte) {
f.opcode = opcode(in[x] & 0x0F)
x++
if x >= len(in) {
- return f, in[x:]
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
f.masked = in[x] & frameIsMasked
f.len = uint64(in[x] & 0x7F)
x++
if x >= len(in) {
- // Unmasked frame, with 0 payload.
- if f.masked == 0 {
- return f, in[x:]
+ if f.len > 0 {
+ f.chopped = append(f.chopped, in...)
}
- // Missing mask.
- return nil, nil
+ return f, nil
}
switch f.len {
case frameLargePayload:
if x+8 >= len(in) {
- return nil, nil
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
f.len = binary.BigEndian.Uint64(in[x : x+8])
x += 8
case frameMediumPayload:
if x+2 >= len(in) {
- return nil, nil
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
f.len = uint64(binary.BigEndian.Uint16(in[x : x+2]))
@@ -254,28 +262,32 @@ func frameUnpack(in []byte) (f *Frame, rest []byte) {
if f.masked == frameIsMasked {
if x >= len(in) {
- return nil, nil
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
- f.maskKey[0] = in[x]
+ f.maskKey = append(f.maskKey, in[x])
x++
if x >= len(in) {
- return nil, nil
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
- f.maskKey[1] = in[x]
+ f.maskKey = append(f.maskKey, in[x])
x++
if x >= len(in) {
- return nil, nil
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
- f.maskKey[2] = in[x]
+ f.maskKey = append(f.maskKey, in[x])
x++
if x >= len(in) {
- return nil, nil
+ f.chopped = append(f.chopped, in...)
+ return f, nil
}
- f.maskKey[3] = in[x]
+ f.maskKey = append(f.maskKey, in[x])
x++
}
@@ -298,10 +310,13 @@ func frameUnpack(in []byte) (f *Frame, rest []byte) {
if f.opcode == opcodeClose {
switch len(f.payload) {
case 0:
+ f.codes = []byte{0, 0}
f.closeCode = StatusNormal
case 1:
+ f.codes = []byte{f.payload[0], 0}
f.closeCode = StatusBadRequest
default:
+ f.codes = []byte{f.payload[0], f.payload[1]}
f.closeCode = CloseCode(binary.BigEndian.Uint16(f.payload[:2]))
f.payload = f.payload[2:]
}
@@ -379,7 +394,8 @@ func (f *Frame) Pack(randomMask bool) (out []byte) {
if _rng == nil {
_rng = rand.New(rand.NewSource(time.Now().UnixNano()))
}
- binary.LittleEndian.PutUint32(f.maskKey[0:], _rng.Uint32())
+ f.maskKey = make([]byte, 4)
+ binary.LittleEndian.PutUint32(f.maskKey, _rng.Uint32())
}
out[x] = f.maskKey[0]
@@ -408,3 +424,106 @@ func (f *Frame) Pack(randomMask bool) (out []byte) {
func (f *Frame) Payload() []byte {
return f.payload
}
+
+//
+// continueUnpack unpack frame header (fin, opcode, masked, length, and mask
+// keys) based on chopped length.
+//
+func (f *Frame) continueUnpack(packet []byte) []byte {
+ var isHaveLen bool
+
+ for len(packet) > 0 && !isHaveLen {
+ switch len(f.chopped) {
+ case 0:
+ f.fin = packet[0] & frameIsFinished
+ f.rsv1 = packet[0] & 0x40
+ f.rsv2 = packet[0] & 0x20
+ f.rsv3 = packet[0] & 0x10
+ f.opcode = opcode(packet[0] & 0x0F)
+ f.chopped = append(f.chopped, packet[0])
+ packet = packet[1:]
+ case 1:
+ f.masked = packet[0] & frameIsMasked
+ f.len = uint64(packet[0] & 0x7F)
+ f.chopped = append(f.chopped, packet[0])
+ packet = packet[1:]
+ default:
+ // We got the masked and len, lets check and get the
+ // extended length.
+ switch f.len {
+ case frameLargePayload:
+ if len(f.chopped) < 10 {
+ exp := 10 - len(f.chopped)
+ if len(packet) < exp {
+ f.chopped = append(f.chopped, packet...)
+ return nil
+ }
+ // chopped: 81 FF 0 0 0 1 0 0 = 10 - 8) = 2
+ // exp: 0 0
+ f.chopped = append(f.chopped, packet[:exp]...)
+ f.len = binary.BigEndian.Uint64(f.chopped[2:10])
+ packet = packet[exp:]
+ }
+ case frameMediumPayload:
+ if len(f.chopped) < 4 {
+ exp := 4 - len(f.chopped)
+ if len(packet) < exp {
+ f.chopped = append(f.chopped, packet...)
+ return nil
+ }
+ f.chopped = append(f.chopped, packet[:exp]...)
+ f.len = uint64(binary.BigEndian.Uint16(f.chopped[2:4]))
+ packet = packet[exp:]
+ }
+ }
+ isHaveLen = true
+ }
+ }
+ if len(packet) == 0 {
+ return nil
+ }
+ if f.masked == frameIsMasked && len(f.maskKey) != 4 {
+ exp := 4 - len(f.maskKey)
+ if len(packet) < exp {
+ f.maskKey = append(f.maskKey, packet...)
+ return nil
+ }
+
+ f.maskKey = append(f.maskKey, packet[:exp]...)
+
+ packet = packet[exp:]
+ }
+ if f.opcode == opcodeClose && len(f.codes) != 2 {
+ exp := 2 - len(f.codes)
+ if len(packet) < exp {
+ f.codes = append(f.codes, packet...)
+ return nil
+ }
+ f.codes = append(f.codes, packet[:exp]...)
+ f.closeCode = CloseCode(binary.BigEndian.Uint16(f.codes))
+ packet = packet[exp:]
+ }
+ if f.len > 0 && cap(f.payload) == 0 {
+ f.payload = make([]byte, 0, f.len)
+
+ if len(packet) > 0 {
+ paclen := len(packet)
+ if uint64(paclen) > f.len {
+ paclen = int(f.len)
+ }
+
+ f.payload = append(f.payload, packet[:paclen]...)
+
+ if f.masked == frameIsMasked {
+ for x := 0; x < paclen; x++ {
+ f.payload[x] ^= f.maskKey[x%4]
+ }
+ }
+
+ packet = packet[paclen:]
+ }
+ }
+ f.chopped = nil
+
+ return packet
+}
diff --git a/lib/websocket/frames.go b/lib/websocket/frames.go
index 66274ca3..1fa40df2 100644
--- a/lib/websocket/frames.go
+++ b/lib/websocket/frames.go
@@ -95,6 +95,16 @@ func (frames *Frames) Len() int {
}
//
+// Opcode return the operation code of the first frame.
+//
+func (frames *Frames) Opcode() opcode {
+ if len(frames.v) == 0 {
+ return opcodeCont
+ }
+ return frames.v[0].opcode
+}
+
+//
// Payload return the concatenation of continuous data frame's payload.
//
// The first frame must be a data frame, either text or binary, otherwise it
diff --git a/lib/websocket/server.go b/lib/websocket/server.go
index cadbad05..ea8108da 100644
--- a/lib/websocket/server.go
+++ b/lib/websocket/server.go
@@ -311,7 +311,7 @@ func (serv *Server) upgrader() {
}
//
-// handleChopped handle possible chopped payload.
+// handleChopped handle possible chopped packet or payload.
//
// There are three possible cases that will returned from this function.
// First, there is no continuous frame. Packet is new frame, return it as
@@ -332,10 +332,22 @@ func (serv *Server) upgrader() {
//
func (serv *Server) handleChopped(x, conn int, packet []byte) (rest []byte, isClosing bool) {
frame, _ := serv.Clients.Frame(conn)
+ frames, _ := serv.Clients.Frames(conn)
if frame == nil {
return packet, false
}
+
+ // Check if frame contains chopped packet.
+ if len(frame.chopped) > 0 {
+ packet = frame.continueUnpack(packet)
+ if len(packet) == 0 {
+ serv.Clients.SetFrame(conn, frame)
+ serv.epollRegisterRead(x, conn)
+ return nil, false
+ }
+ }
+
if frame.len == uint64(len(frame.payload)) {
// Connection contains continuous frame, but its already
// filled.
@@ -356,18 +368,62 @@ func (serv *Server) handleChopped(x, conn int, packet []byte) (rest []byte, isCl
frame.payload = append(frame.payload, packet...)
if uint64(len(frame.payload)) < frame.len {
- // We still got unfinished payload.
+ // We got frame with unfinished payload.
serv.Clients.SetFrame(conn, frame)
serv.epollRegisterRead(x, conn)
return rest, false
}
if frame.fin == 0 {
- serv.Clients.SetFrame(conn, frame)
+ if frames == nil {
+ if frame.opcode == 0 {
+ serv.handleBadRequest(conn)
+ isClosing = true
+ return
+ }
+ }
+ // We got frame with finished payload but with unfinished
+ // frames.
+ serv.Clients.AddFrame(conn, frame)
serv.epollRegisterRead(x, conn)
return rest, false
}
- serv.Clients.SetFrame(conn, nil)
+ // FIN:0x80
+
+ if frame.opcode == 0 {
+ if frames == nil {
+ // Got frame with opcode CONT in non fragmented
+ // frames.
+ serv.handleBadRequest(conn)
+ isClosing = true
+ return
+ }
+
+ frame = serv.Clients.finFrames(conn, frame)
+ } else {
+ serv.Clients.SetFrame(conn, nil)
+ }
+
+ if frame.masked != frameIsMasked {
+ serv.handleBadRequest(conn)
+ isClosing = true
+ return
+ }
+ if frame.rsv1 > 0 && !serv.allowRsv1 {
+ serv.handleBadRequest(conn)
+ isClosing = true
+ return
+ }
+ if frame.rsv2 > 0 && !serv.allowRsv2 {
+ serv.handleBadRequest(conn)
+ isClosing = true
+ return
+ }
+ if frame.rsv3 > 0 && !serv.allowRsv3 {
+ serv.handleBadRequest(conn)
+ isClosing = true
+ return
+ }
switch frame.opcode {
case opcodeText:
@@ -432,53 +488,41 @@ func (serv *Server) handleChopped(x, conn int, packet []byte) (rest []byte, isCl
// (RFC 6455 Section 5.4 Page 34)
//
func (serv *Server) handleFragment(conn int, req *Frame) (isInvalid bool) {
- frame, ok := serv.Clients.Frame(conn)
+ frames, ok := serv.Clients.Frames(conn)
if debug.Value >= 3 {
- log.Printf("websocket: Server.handleFragment: frame: {fin:%d opcode:%d len:%d, payload.len:%d}\n",
- req.fin, req.opcode, req.len, len(req.payload))
+ log.Printf("websocket: Server.handleFragment: frame: %+v\n", req)
}
- if frame == nil {
+ if frames == nil {
// If a connection does not have continuous frame, then
// current frame opcode must not be 0.
if req.opcode == opcodeCont {
return true
}
-
- frame = req
+ frames = new(Frames)
} else {
// If a connection have continuous frame, the next frame
// opcode must be 0.
if req.opcode != opcodeCont {
return true
}
-
- frame.payload = append(frame.payload, req.payload...)
- if req.len > 0 {
- frame.len += req.len
- }
}
if req.fin == 0 {
- serv.Clients.SetFrame(conn, frame)
+ frames.Append(req)
+ serv.Clients.SetFrame(conn, req)
+ serv.Clients.SetFrames(conn, frames)
return false
}
- // Frame with fin set with chopped payload.
- if debug.Value >= 3 {
- log.Printf("handleFragment: payload.len:%d frame.len:%d\n",
- len(frame.payload), frame.len)
- }
-
- if uint64(len(frame.payload)) < frame.len {
- serv.Clients.SetFrame(conn, frame)
+ if !ok && uint64(len(req.payload)) < req.len {
+ // Finished frame with unfinished payload.
+ serv.Clients.SetFrame(conn, req)
return false
}
- if ok {
- serv.Clients.SetFrame(conn, nil)
- }
+ frame := serv.Clients.finFrames(conn, req)
if frame.opcode == opcodeText {
if !utf8.Valid(frame.payload) {
@@ -724,7 +768,7 @@ func (serv *Server) reader() {
len(packet), packet)
}
- // Handle chopped, unfinished payload.
+ // Handle chopped, unfinished packet or payload.
packet, isClosing = serv.handleChopped(x, conn, packet)
if isClosing || len(packet) == 0 {
continue
@@ -744,6 +788,11 @@ func (serv *Server) reader() {
isClosing = false
for _, frame := range frames.v {
+ if len(frame.chopped) > 0 {
+ serv.Clients.SetFrame(conn, frame)
+ continue
+ }
+
if frame.masked != frameIsMasked {
serv.handleBadRequest(conn)
isClosing = true
diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go
index f6a66c70..86815518 100644
--- a/lib/websocket/websocket_test.go
+++ b/lib/websocket/websocket_test.go
@@ -25,7 +25,7 @@ var (
_testWSAddr string //nolint: gochecknoglobals
_testHdrValWSAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" //nolint: gochecknoglobals
_testHdrValWSKey = "dGhlIHNhbXBsZSBub25jZQ==" //nolint: gochecknoglobals
- _testMaskKey = [4]byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals
+ _testMaskKey = []byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals
)
var (