diff options
| author | Shulhan <ms@kilabit.info> | 2019-03-16 19:20:50 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2019-03-17 02:15:15 +0700 |
| commit | fee7ce1ebb6c80b11c6002de69c701d524da6f06 (patch) | |
| tree | b396afb23ff6ddc5381a09d8f520293a2554c32b | |
| parent | 62808cb99e31636577758999ee0163ba0973f398 (diff) | |
| download | pakakeh.go-fee7ce1ebb6c80b11c6002de69c701d524da6f06.tar.xz | |
websocket: simplify client handshake process without handler and context
User of library that want to create WebSocket client should only focus
to the Endpoint and/or Headers when creating client.
While at it, change handshake test to check for error instead of status
code.
| -rw-r--r-- | lib/websocket/client.go | 65 | ||||
| -rw-r--r-- | lib/websocket/contextkey.go | 3 | ||||
| -rw-r--r-- | lib/websocket/handler.go | 6 | ||||
| -rw-r--r-- | lib/websocket/server_test.go | 53 | ||||
| -rw-r--r-- | lib/websocket/websocket_test.go | 17 |
5 files changed, 48 insertions, 96 deletions
diff --git a/lib/websocket/client.go b/lib/websocket/client.go index a6fd5e1a..12c24966 100644 --- a/lib/websocket/client.go +++ b/lib/websocket/client.go @@ -7,7 +7,6 @@ package websocket import ( "bufio" "bytes" - "context" "crypto/tls" "fmt" "log" @@ -311,10 +310,27 @@ func (cl *Client) handshake() (err error) { } bb.WriteString("\r\n") + req := bb.Bytes() - ctx := context.WithValue(context.Background(), ctxKeyWSAccept, keyAccept) + if debug.Value >= 3 { + fmt.Printf("websocket: Client.handshake:\n%s\n--\n", req) + } - return cl.sendWithHandler(ctx, bb.Bytes(), cl.handleHandshake) + return cl.doHandshake(keyAccept, req) +} + +func (cl *Client) doHandshake(keyAccept string, req []byte) (err error) { + err = cl.send(req) + if err != nil { + return err + } + + resp, err := cl.recv() + if err != nil { + return err + } + + return cl.handleHandshake(keyAccept, resp) } // @@ -504,7 +520,7 @@ func (cl *Client) handleFrame(frame *Frame) (isClosing bool) { return isClosing } -func (cl *Client) handleHandshake(ctx context.Context, resp []byte) (err error) { +func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) { if debug.Value >= 3 { max := 512 if len(resp) < 512 { @@ -523,16 +539,12 @@ func (cl *Client) handleHandshake(ctx context.Context, resp []byte) (err error) httpRes.Body.Close() if httpRes.StatusCode != http.StatusSwitchingProtocols { - fmt.Printf("websocket: Client.handleHandshake: status code: %d\n", httpRes.StatusCode) - err = fmt.Errorf(httpRes.Status) - return err + return fmt.Errorf(httpRes.Status) } - expAccept := ctx.Value(ctxKeyWSAccept) gotAccept := httpRes.Header.Get(_hdrKeyWSAccept) - if expAccept != gotAccept { - err = fmt.Errorf("websocket: client.handleHandshake: invalid server accept key") - return err + if keyAccept != gotAccept { + return fmt.Errorf("invalid server accept key") } return nil @@ -739,34 +751,3 @@ func (cl *Client) send(packet []byte) (err error) { return nil } - -// -// sendWithHandler send message to server, read the response, and pass it to -// handler. -// -func (cl *Client) sendWithHandler(ctx context.Context, req []byte, handleRaw clientRawHandler) (err error) { - if cl.conn == nil { - return ErrConnClosed - } - - err = cl.conn.SetWriteDeadline(time.Now().Add(defaultTimeout)) - if err != nil { - return err - } - - _, err = cl.conn.Write(req) - if err != nil { - return err - } - - if handleRaw != nil { - var resp []byte - resp, err = cl.recv() - if err != nil { - return err - } - return handleRaw(ctx, resp) - } - - return nil -} diff --git a/lib/websocket/contextkey.go b/lib/websocket/contextkey.go index f90a254f..93e020d5 100644 --- a/lib/websocket/contextkey.go +++ b/lib/websocket/contextkey.go @@ -12,7 +12,4 @@ const ( CtxKeyExternalJWT ContextKey = 1 << iota CtxKeyInternalJWT CtxKeyUID - - // Internal context keys used by client. - ctxKeyWSAccept // ctxKeyWSAccept context key for WebSocket accept key. ) diff --git a/lib/websocket/handler.go b/lib/websocket/handler.go index ec2356a1..d87cf4cd 100644 --- a/lib/websocket/handler.go +++ b/lib/websocket/handler.go @@ -9,12 +9,6 @@ import ( ) // -// clientRawHandler define a callback type for handling raw packet from -// send(). -// -type clientRawHandler func(ctx context.Context, resp []byte) (err error) - -// // ClientHandler define a callback type for client to handle packet from // server (either broadcast or from response of request) in the form of frame. // diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go index 2d8a91ef..8df648f4 100644 --- a/lib/websocket/server_test.go +++ b/lib/websocket/server_test.go @@ -5,9 +5,7 @@ package websocket import ( - "bufio" "bytes" - "context" "fmt" "net/http" "net/url" @@ -65,11 +63,10 @@ func TestServerHandshake(t *testing.T) { } cases := []struct { - desc string - req *http.Request - query url.Values - expKey string - expRespCode int + desc string + req *http.Request + query url.Values + expError string }{{ desc: "With valid request and authorization", req: &http.Request{ @@ -86,8 +83,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expKey: _testHdrValWSAccept, - expRespCode: http.StatusSwitchingProtocols, + expError: "invalid server accept key", }, { desc: "Without GET", req: &http.Request{ @@ -104,7 +100,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 invalid HTTP method", }, { desc: "Without HTTP header Host", req: &http.Request{ @@ -119,7 +115,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 bad request: header length is less than minimum", }, { desc: "Without HTTP header Connection", req: &http.Request{ @@ -135,7 +131,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 bad request: header length is less than minimum", }, { desc: "With invalid HTTP header Connection", req: &http.Request{ @@ -152,7 +148,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 invalid Connection header", }, { desc: "Without HTTP header Upgrade", req: &http.Request{ @@ -168,7 +164,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 bad request: header length is less than minimum", }, { desc: "Without HTTP header 'Sec-Websocket-Key'", req: &http.Request{ @@ -184,7 +180,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 bad request: header length is less than minimum", }, { desc: "Without HTTP header 'Sec-Websocket-Version'", req: &http.Request{ @@ -200,7 +196,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 bad request: header length is less than minimum", }, { desc: "With unsupported websocket version", req: &http.Request{ @@ -217,7 +213,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ _qKeyTicket: []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 unsupported Sec-WebSocket-Version", }, { desc: "Without authorization", req: &http.Request{ @@ -231,7 +227,7 @@ func TestServerHandshake(t *testing.T) { _hdrKeyWSVersion: []string{_hdrValWSVersion}, }, }, - expRespCode: http.StatusBadRequest, + expError: "400 Missing authorization", }, { desc: "Without invalid HTTP header 'Authorization'", req: &http.Request{ @@ -248,7 +244,7 @@ func TestServerHandshake(t *testing.T) { query: url.Values{ "Basic": []string{_testExternalJWT}, }, - expRespCode: http.StatusBadRequest, + expError: "400 Missing authorization", }} var bb bytes.Buffer @@ -270,24 +266,9 @@ func TestServerHandshake(t *testing.T) { fmt.Fprintf(&bb, "\r\n") - c := c - handleHandshake := func(ctx context.Context, resp []byte) (err error) { - httpBuf := bufio.NewReader(bytes.NewBuffer(resp)) - - httpRes, err := http.ReadResponse(httpBuf, nil) - if err != nil { - t.Fatal(err) - return - } - - test.Assert(t, "expRespCode", c.expRespCode, httpRes.StatusCode, true) - - return - } - - err = cl.sendWithHandler(context.Background(), bb.Bytes(), handleHandshake) + err = cl.doHandshake("", bb.Bytes()) if err != nil { - t.Fatal(err) + test.Assert(t, "error", c.expError, err.Error(), true) } } } diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go index 86815518..85d816c2 100644 --- a/lib/websocket/websocket_test.go +++ b/lib/websocket/websocket_test.go @@ -17,15 +17,14 @@ import ( var ( _testExternalJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MzA0NjU2MDYsImhhc2giOiJiYmJiYmJiYi1iYmJiLWJiYmItYmJiYi1iYmJiYmJiYmJiYmIiLCJpYXQiOjE1MzAyMDY0MDYsIm5hZiI6MTUzMjc5ODQwNn0.15quj_gkeo9cWkLN98_2rXjtjihQym16Kn_9BQjYC14" //nolint: lll, gochecknoglobals - _testEndpointAuth string //nolint: gochecknoglobals - _testInternalJWT = _testExternalJWT //nolint: gochecknoglobals - _testUID = uint64(100) //nolint: gochecknoglobals - _testPort = 9001 //nolint: gochecknoglobals - _testServer *Server //nolint: gochecknoglobals - _testWSAddr string //nolint: gochecknoglobals - _testHdrValWSAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" //nolint: gochecknoglobals - _testHdrValWSKey = "dGhlIHNhbXBsZSBub25jZQ==" //nolint: gochecknoglobals - _testMaskKey = []byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals + _testEndpointAuth string //nolint: gochecknoglobals + _testInternalJWT = _testExternalJWT //nolint: gochecknoglobals + _testUID = uint64(100) //nolint: gochecknoglobals + _testPort = 9001 //nolint: gochecknoglobals + _testServer *Server //nolint: gochecknoglobals + _testWSAddr string //nolint: gochecknoglobals + _testHdrValWSKey = "dGhlIHNhbXBsZSBub25jZQ==" //nolint: gochecknoglobals + _testMaskKey = []byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals ) var ( |
