aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-03-20 23:44:27 +0700
committerShulhan <ms@kilabit.info>2019-03-20 23:44:27 +0700
commit1102dae593a3d9673d23b612da6aed52e97841af (patch)
tree2a33228cde523101174cd7167ea8c0b09adcf011
parent6740da2369bf6324ddc0f7303ea2dd4b008fbfe4 (diff)
downloadpakakeh.go-1102dae593a3d9673d23b612da6aed52e97841af.tar.xz
websocket: use custom HTTP parser for client
The problem with using the standard http.ReadResponse is that the received packet may contains WebSocket frame, not just HTTP response. This cause the initial frame lost, which may required by client (for example, as a response for authentication).
-rw-r--r--lib/websocket/client.go122
-rw-r--r--lib/websocket/server_test.go2
2 files changed, 78 insertions, 46 deletions
diff --git a/lib/websocket/client.go b/lib/websocket/client.go
index 6a354ec1..7cdbaac3 100644
--- a/lib/websocket/client.go
+++ b/lib/websocket/client.go
@@ -5,7 +5,6 @@
package websocket
import (
- "bufio"
"bytes"
"crypto/tls"
"fmt"
@@ -18,6 +17,7 @@ import (
"unicode/utf8"
"github.com/shuLhan/share/lib/debug"
+ libhttp "github.com/shuLhan/share/lib/http"
)
const (
@@ -164,13 +164,25 @@ func (cl *Client) Connect() (err error) {
return fmt.Errorf("websocket: Connect: " + err.Error())
}
- err = cl.handshake()
+ var rest []byte
+
+ rest, err = cl.handshake()
if err != nil {
_ = cl.conn.Close()
cl.conn = nil
return fmt.Errorf("websocket: Connect: " + err.Error())
}
+ // At this point client successfully connected to server, but the
+ // response from server may include WebSocket frame, not just HTTP
+ // response.
+ if len(rest) > 0 {
+ isClosing := cl.handleRaw(rest)
+ if isClosing {
+ return nil
+ }
+ }
+
go cl.serve()
return nil
@@ -218,7 +230,7 @@ func (cl *Client) init() (err error) {
}
//
-// parseURI parse websocket connection URI from "endpoint" and get the remote
+// parseURI parse WebSocket connection URI from "endpoint" and get the remote
// URL (for checking up scheme) and remote address.
// By default, if no port is given, it will set to 80 for URL with any scheme
// or 443 for "wss" scheme.
@@ -254,7 +266,7 @@ func (cl *Client) parseURI() (err error) {
}
//
-// open TCP connection to websocket remote address.
+// open TCP connection to WebSocket remote address.
// If client "isTLS" field is true, the connection is opened with TLS protocol
// and the remote name MUST have a valid certificate.
//
@@ -284,9 +296,9 @@ func (cl *Client) open() (err error) {
}
//
-// handshake send the websocket opening handshake.
+// handshake send the WebSocket opening handshake.
//
-func (cl *Client) handshake() (err error) {
+func (cl *Client) handshake() (rest []byte, err error) {
var bb bytes.Buffer
path := cl.remoteURL.EscapedPath()
@@ -303,13 +315,13 @@ func (cl *Client) handshake() (err error) {
_, err = fmt.Fprintf(&bb, _handshakeReqFormat, path, cl.remoteURL.Host, key)
if err != nil {
- return err
+ return nil, err
}
if len(cl.Headers) > 0 {
err = cl.Headers.Write(&bb)
if err != nil {
- return err
+ return nil, err
}
}
@@ -320,21 +332,31 @@ func (cl *Client) handshake() (err error) {
fmt.Printf("websocket: Client.handshake:\n%s\n--\n", req)
}
- return cl.doHandshake(keyAccept, req)
+ rest, err = cl.doHandshake(keyAccept, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return rest, nil
}
-func (cl *Client) doHandshake(keyAccept string, req []byte) (err error) {
+func (cl *Client) doHandshake(keyAccept string, req []byte) (rest []byte, err error) {
err = cl.send(req)
if err != nil {
- return err
+ return nil, err
}
resp, err := cl.recv()
if err != nil {
- return err
+ return nil, err
}
- return cl.handleHandshake(keyAccept, resp)
+ rest, err = cl.handleHandshake(keyAccept, resp)
+ if err != nil {
+ return nil, err
+ }
+
+ return rest, nil
}
//
@@ -520,7 +542,7 @@ func (cl *Client) handleFrame(frame *Frame) (isClosing bool) {
return isClosing
}
-func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) {
+func (cl *Client) handleHandshake(keyAccept string, resp []byte) (rest []byte, err error) {
if debug.Value >= 3 {
max := 512
if len(resp) < 512 {
@@ -529,25 +551,23 @@ func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) {
fmt.Printf("websocket: Client.handleHandshake:\n%s\n--\n", resp[:max])
}
- httpBuf := bufio.NewReader(bytes.NewBuffer(resp))
+ var httpRes *http.Response
- httpRes, err := http.ReadResponse(httpBuf, nil)
+ httpRes, rest, err = libhttp.ParseResponseHeader(resp)
if err != nil {
- return err
+ return nil, err
}
- httpRes.Body.Close()
-
if httpRes.StatusCode != http.StatusSwitchingProtocols {
- return fmt.Errorf(httpRes.Status)
+ return nil, fmt.Errorf(httpRes.Status)
}
gotAccept := httpRes.Header.Get(_hdrKeyWSAccept)
if keyAccept != gotAccept {
- return fmt.Errorf("invalid server accept key")
+ return nil, fmt.Errorf("invalid server accept key")
}
- return nil
+ return rest, nil
}
//
@@ -563,6 +583,30 @@ func (cl *Client) handleInvalidData() {
}
//
+// handleRaw packet from server.
+//
+func (cl *Client) handleRaw(packet []byte) (isClosing bool) {
+ frames := Unpack(packet)
+ if frames == nil {
+ log.Println("websocket: Client.handleRaw: incomplete frames received")
+ return false
+ }
+
+ for _, f := range frames.v {
+ if !f.isComplete {
+ cl.frame = f
+ continue
+ }
+ isClosing = cl.handleFrame(f)
+ if isClosing {
+ return true
+ }
+ }
+
+ return false
+}
+
+//
// SendBin send data frame as binary to server.
// If handler is nil, no response will be read from server.
//
@@ -629,14 +673,6 @@ func (cl *Client) serve() {
break
}
- if debug.Value >= 3 {
- max := len(packet)
- if max > 16 {
- max = 16
- }
- fmt.Printf("websocket: Client.serve: packet: len:%d % x\n", len(packet), packet[:max])
- }
-
if cl.frame != nil {
packet = cl.frame.unpack(packet)
if cl.frame.isComplete {
@@ -652,21 +688,9 @@ func (cl *Client) serve() {
}
}
- frames := Unpack(packet)
- if frames == nil {
- log.Println("websocket: client.serve: uncomplete frames received")
- continue
- }
-
- for _, f := range frames.v {
- if !f.isComplete {
- cl.frame = f
- continue
- }
- isClosing := cl.handleFrame(f)
- if isClosing {
- return
- }
+ isClosing := cl.handleRaw(packet)
+ if isClosing {
+ return
}
}
cl.Quit()
@@ -732,6 +756,14 @@ func (cl *Client) recv() (packet []byte, err error) {
}
}
+ if debug.Value >= 3 {
+ max := len(packet)
+ if max > 16 {
+ max = 16
+ }
+ fmt.Printf("websocket: Client.recv: packet: len:%d % x\n", len(packet), packet[:max])
+ }
+
return packet, err
}
diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go
index 8df648f4..94d6e7f4 100644
--- a/lib/websocket/server_test.go
+++ b/lib/websocket/server_test.go
@@ -266,7 +266,7 @@ func TestServerHandshake(t *testing.T) {
fmt.Fprintf(&bb, "\r\n")
- err = cl.doHandshake("", bb.Bytes())
+ _, err = cl.doHandshake("", bb.Bytes())
if err != nil {
test.Assert(t, "error", c.expError, err.Error(), true)
}