diff options
| author | Shulhan <ms@kilabit.info> | 2019-03-20 23:44:27 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-03-20 23:44:27 +0700 |
| commit | 1102dae593a3d9673d23b612da6aed52e97841af (patch) | |
| tree | 2a33228cde523101174cd7167ea8c0b09adcf011 | |
| parent | 6740da2369bf6324ddc0f7303ea2dd4b008fbfe4 (diff) | |
| download | pakakeh.go-1102dae593a3d9673d23b612da6aed52e97841af.tar.xz | |
websocket: use custom HTTP parser for client
The problem with using the standard http.ReadResponse is that the
received packet may contains WebSocket frame, not just HTTP response.
This cause the initial frame lost, which may required by client (for
example, as a response for authentication).
| -rw-r--r-- | lib/websocket/client.go | 122 | ||||
| -rw-r--r-- | lib/websocket/server_test.go | 2 |
2 files changed, 78 insertions, 46 deletions
diff --git a/lib/websocket/client.go b/lib/websocket/client.go index 6a354ec1..7cdbaac3 100644 --- a/lib/websocket/client.go +++ b/lib/websocket/client.go @@ -5,7 +5,6 @@ package websocket import ( - "bufio" "bytes" "crypto/tls" "fmt" @@ -18,6 +17,7 @@ import ( "unicode/utf8" "github.com/shuLhan/share/lib/debug" + libhttp "github.com/shuLhan/share/lib/http" ) const ( @@ -164,13 +164,25 @@ func (cl *Client) Connect() (err error) { return fmt.Errorf("websocket: Connect: " + err.Error()) } - err = cl.handshake() + var rest []byte + + rest, err = cl.handshake() if err != nil { _ = cl.conn.Close() cl.conn = nil return fmt.Errorf("websocket: Connect: " + err.Error()) } + // At this point client successfully connected to server, but the + // response from server may include WebSocket frame, not just HTTP + // response. + if len(rest) > 0 { + isClosing := cl.handleRaw(rest) + if isClosing { + return nil + } + } + go cl.serve() return nil @@ -218,7 +230,7 @@ func (cl *Client) init() (err error) { } // -// parseURI parse websocket connection URI from "endpoint" and get the remote +// parseURI parse WebSocket connection URI from "endpoint" and get the remote // URL (for checking up scheme) and remote address. // By default, if no port is given, it will set to 80 for URL with any scheme // or 443 for "wss" scheme. @@ -254,7 +266,7 @@ func (cl *Client) parseURI() (err error) { } // -// open TCP connection to websocket remote address. +// open TCP connection to WebSocket remote address. // If client "isTLS" field is true, the connection is opened with TLS protocol // and the remote name MUST have a valid certificate. // @@ -284,9 +296,9 @@ func (cl *Client) open() (err error) { } // -// handshake send the websocket opening handshake. +// handshake send the WebSocket opening handshake. // -func (cl *Client) handshake() (err error) { +func (cl *Client) handshake() (rest []byte, err error) { var bb bytes.Buffer path := cl.remoteURL.EscapedPath() @@ -303,13 +315,13 @@ func (cl *Client) handshake() (err error) { _, err = fmt.Fprintf(&bb, _handshakeReqFormat, path, cl.remoteURL.Host, key) if err != nil { - return err + return nil, err } if len(cl.Headers) > 0 { err = cl.Headers.Write(&bb) if err != nil { - return err + return nil, err } } @@ -320,21 +332,31 @@ func (cl *Client) handshake() (err error) { fmt.Printf("websocket: Client.handshake:\n%s\n--\n", req) } - return cl.doHandshake(keyAccept, req) + rest, err = cl.doHandshake(keyAccept, req) + if err != nil { + return nil, err + } + + return rest, nil } -func (cl *Client) doHandshake(keyAccept string, req []byte) (err error) { +func (cl *Client) doHandshake(keyAccept string, req []byte) (rest []byte, err error) { err = cl.send(req) if err != nil { - return err + return nil, err } resp, err := cl.recv() if err != nil { - return err + return nil, err } - return cl.handleHandshake(keyAccept, resp) + rest, err = cl.handleHandshake(keyAccept, resp) + if err != nil { + return nil, err + } + + return rest, nil } // @@ -520,7 +542,7 @@ func (cl *Client) handleFrame(frame *Frame) (isClosing bool) { return isClosing } -func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) { +func (cl *Client) handleHandshake(keyAccept string, resp []byte) (rest []byte, err error) { if debug.Value >= 3 { max := 512 if len(resp) < 512 { @@ -529,25 +551,23 @@ func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) { fmt.Printf("websocket: Client.handleHandshake:\n%s\n--\n", resp[:max]) } - httpBuf := bufio.NewReader(bytes.NewBuffer(resp)) + var httpRes *http.Response - httpRes, err := http.ReadResponse(httpBuf, nil) + httpRes, rest, err = libhttp.ParseResponseHeader(resp) if err != nil { - return err + return nil, err } - httpRes.Body.Close() - if httpRes.StatusCode != http.StatusSwitchingProtocols { - return fmt.Errorf(httpRes.Status) + return nil, fmt.Errorf(httpRes.Status) } gotAccept := httpRes.Header.Get(_hdrKeyWSAccept) if keyAccept != gotAccept { - return fmt.Errorf("invalid server accept key") + return nil, fmt.Errorf("invalid server accept key") } - return nil + return rest, nil } // @@ -563,6 +583,30 @@ func (cl *Client) handleInvalidData() { } // +// handleRaw packet from server. +// +func (cl *Client) handleRaw(packet []byte) (isClosing bool) { + frames := Unpack(packet) + if frames == nil { + log.Println("websocket: Client.handleRaw: incomplete frames received") + return false + } + + for _, f := range frames.v { + if !f.isComplete { + cl.frame = f + continue + } + isClosing = cl.handleFrame(f) + if isClosing { + return true + } + } + + return false +} + +// // SendBin send data frame as binary to server. // If handler is nil, no response will be read from server. // @@ -629,14 +673,6 @@ func (cl *Client) serve() { break } - if debug.Value >= 3 { - max := len(packet) - if max > 16 { - max = 16 - } - fmt.Printf("websocket: Client.serve: packet: len:%d % x\n", len(packet), packet[:max]) - } - if cl.frame != nil { packet = cl.frame.unpack(packet) if cl.frame.isComplete { @@ -652,21 +688,9 @@ func (cl *Client) serve() { } } - frames := Unpack(packet) - if frames == nil { - log.Println("websocket: client.serve: uncomplete frames received") - continue - } - - for _, f := range frames.v { - if !f.isComplete { - cl.frame = f - continue - } - isClosing := cl.handleFrame(f) - if isClosing { - return - } + isClosing := cl.handleRaw(packet) + if isClosing { + return } } cl.Quit() @@ -732,6 +756,14 @@ func (cl *Client) recv() (packet []byte, err error) { } } + if debug.Value >= 3 { + max := len(packet) + if max > 16 { + max = 16 + } + fmt.Printf("websocket: Client.recv: packet: len:%d % x\n", len(packet), packet[:max]) + } + return packet, err } diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go index 8df648f4..94d6e7f4 100644 --- a/lib/websocket/server_test.go +++ b/lib/websocket/server_test.go @@ -266,7 +266,7 @@ func TestServerHandshake(t *testing.T) { fmt.Fprintf(&bb, "\r\n") - err = cl.doHandshake("", bb.Bytes()) + _, err = cl.doHandshake("", bb.Bytes()) if err != nil { test.Assert(t, "error", c.expError, err.Error(), true) } |
