From 1102dae593a3d9673d23b612da6aed52e97841af Mon Sep 17 00:00:00 2001 From: Shulhan Date: Wed, 20 Mar 2019 23:44:27 +0700 Subject: websocket: use custom HTTP parser for client The problem with using the standard http.ReadResponse is that the received packet may contains WebSocket frame, not just HTTP response. This cause the initial frame lost, which may required by client (for example, as a response for authentication). --- lib/websocket/client.go | 122 +++++++++++++++++++++++++++---------------- lib/websocket/server_test.go | 2 +- 2 files changed, 78 insertions(+), 46 deletions(-) diff --git a/lib/websocket/client.go b/lib/websocket/client.go index 6a354ec1..7cdbaac3 100644 --- a/lib/websocket/client.go +++ b/lib/websocket/client.go @@ -5,7 +5,6 @@ package websocket import ( - "bufio" "bytes" "crypto/tls" "fmt" @@ -18,6 +17,7 @@ import ( "unicode/utf8" "github.com/shuLhan/share/lib/debug" + libhttp "github.com/shuLhan/share/lib/http" ) const ( @@ -164,13 +164,25 @@ func (cl *Client) Connect() (err error) { return fmt.Errorf("websocket: Connect: " + err.Error()) } - err = cl.handshake() + var rest []byte + + rest, err = cl.handshake() if err != nil { _ = cl.conn.Close() cl.conn = nil return fmt.Errorf("websocket: Connect: " + err.Error()) } + // At this point client successfully connected to server, but the + // response from server may include WebSocket frame, not just HTTP + // response. + if len(rest) > 0 { + isClosing := cl.handleRaw(rest) + if isClosing { + return nil + } + } + go cl.serve() return nil @@ -218,7 +230,7 @@ func (cl *Client) init() (err error) { } // -// parseURI parse websocket connection URI from "endpoint" and get the remote +// parseURI parse WebSocket connection URI from "endpoint" and get the remote // URL (for checking up scheme) and remote address. // By default, if no port is given, it will set to 80 for URL with any scheme // or 443 for "wss" scheme. @@ -254,7 +266,7 @@ func (cl *Client) parseURI() (err error) { } // -// open TCP connection to websocket remote address. +// open TCP connection to WebSocket remote address. // If client "isTLS" field is true, the connection is opened with TLS protocol // and the remote name MUST have a valid certificate. // @@ -284,9 +296,9 @@ func (cl *Client) open() (err error) { } // -// handshake send the websocket opening handshake. +// handshake send the WebSocket opening handshake. // -func (cl *Client) handshake() (err error) { +func (cl *Client) handshake() (rest []byte, err error) { var bb bytes.Buffer path := cl.remoteURL.EscapedPath() @@ -303,13 +315,13 @@ func (cl *Client) handshake() (err error) { _, err = fmt.Fprintf(&bb, _handshakeReqFormat, path, cl.remoteURL.Host, key) if err != nil { - return err + return nil, err } if len(cl.Headers) > 0 { err = cl.Headers.Write(&bb) if err != nil { - return err + return nil, err } } @@ -320,21 +332,31 @@ func (cl *Client) handshake() (err error) { fmt.Printf("websocket: Client.handshake:\n%s\n--\n", req) } - return cl.doHandshake(keyAccept, req) + rest, err = cl.doHandshake(keyAccept, req) + if err != nil { + return nil, err + } + + return rest, nil } -func (cl *Client) doHandshake(keyAccept string, req []byte) (err error) { +func (cl *Client) doHandshake(keyAccept string, req []byte) (rest []byte, err error) { err = cl.send(req) if err != nil { - return err + return nil, err } resp, err := cl.recv() if err != nil { - return err + return nil, err } - return cl.handleHandshake(keyAccept, resp) + rest, err = cl.handleHandshake(keyAccept, resp) + if err != nil { + return nil, err + } + + return rest, nil } // @@ -520,7 +542,7 @@ func (cl *Client) handleFrame(frame *Frame) (isClosing bool) { return isClosing } -func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) { +func (cl *Client) handleHandshake(keyAccept string, resp []byte) (rest []byte, err error) { if debug.Value >= 3 { max := 512 if len(resp) < 512 { @@ -529,25 +551,23 @@ func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) { fmt.Printf("websocket: Client.handleHandshake:\n%s\n--\n", resp[:max]) } - httpBuf := bufio.NewReader(bytes.NewBuffer(resp)) + var httpRes *http.Response - httpRes, err := http.ReadResponse(httpBuf, nil) + httpRes, rest, err = libhttp.ParseResponseHeader(resp) if err != nil { - return err + return nil, err } - httpRes.Body.Close() - if httpRes.StatusCode != http.StatusSwitchingProtocols { - return fmt.Errorf(httpRes.Status) + return nil, fmt.Errorf(httpRes.Status) } gotAccept := httpRes.Header.Get(_hdrKeyWSAccept) if keyAccept != gotAccept { - return fmt.Errorf("invalid server accept key") + return nil, fmt.Errorf("invalid server accept key") } - return nil + return rest, nil } // @@ -562,6 +582,30 @@ func (cl *Client) handleInvalidData() { } } +// +// handleRaw packet from server. +// +func (cl *Client) handleRaw(packet []byte) (isClosing bool) { + frames := Unpack(packet) + if frames == nil { + log.Println("websocket: Client.handleRaw: incomplete frames received") + return false + } + + for _, f := range frames.v { + if !f.isComplete { + cl.frame = f + continue + } + isClosing = cl.handleFrame(f) + if isClosing { + return true + } + } + + return false +} + // // SendBin send data frame as binary to server. // If handler is nil, no response will be read from server. @@ -629,14 +673,6 @@ func (cl *Client) serve() { break } - if debug.Value >= 3 { - max := len(packet) - if max > 16 { - max = 16 - } - fmt.Printf("websocket: Client.serve: packet: len:%d % x\n", len(packet), packet[:max]) - } - if cl.frame != nil { packet = cl.frame.unpack(packet) if cl.frame.isComplete { @@ -652,21 +688,9 @@ func (cl *Client) serve() { } } - frames := Unpack(packet) - if frames == nil { - log.Println("websocket: client.serve: uncomplete frames received") - continue - } - - for _, f := range frames.v { - if !f.isComplete { - cl.frame = f - continue - } - isClosing := cl.handleFrame(f) - if isClosing { - return - } + isClosing := cl.handleRaw(packet) + if isClosing { + return } } cl.Quit() @@ -732,6 +756,14 @@ func (cl *Client) recv() (packet []byte, err error) { } } + if debug.Value >= 3 { + max := len(packet) + if max > 16 { + max = 16 + } + fmt.Printf("websocket: Client.recv: packet: len:%d % x\n", len(packet), packet[:max]) + } + return packet, err } diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go index 8df648f4..94d6e7f4 100644 --- a/lib/websocket/server_test.go +++ b/lib/websocket/server_test.go @@ -266,7 +266,7 @@ func TestServerHandshake(t *testing.T) { fmt.Fprintf(&bb, "\r\n") - err = cl.doHandshake("", bb.Bytes()) + _, err = cl.doHandshake("", bb.Bytes()) if err != nil { test.Assert(t, "error", c.expError, err.Error(), true) } -- cgit v1.3