aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2019-03-16 19:20:50 +0700
committerShulhan <ms@kilabit.info>2019-03-17 02:15:15 +0700
commitfee7ce1ebb6c80b11c6002de69c701d524da6f06 (patch)
treeb396afb23ff6ddc5381a09d8f520293a2554c32b
parent62808cb99e31636577758999ee0163ba0973f398 (diff)
downloadpakakeh.go-fee7ce1ebb6c80b11c6002de69c701d524da6f06.tar.xz
websocket: simplify client handshake process without handler and context
User of library that want to create WebSocket client should only focus to the Endpoint and/or Headers when creating client. While at it, change handshake test to check for error instead of status code.
-rw-r--r--lib/websocket/client.go65
-rw-r--r--lib/websocket/contextkey.go3
-rw-r--r--lib/websocket/handler.go6
-rw-r--r--lib/websocket/server_test.go53
-rw-r--r--lib/websocket/websocket_test.go17
5 files changed, 48 insertions, 96 deletions
diff --git a/lib/websocket/client.go b/lib/websocket/client.go
index a6fd5e1a..12c24966 100644
--- a/lib/websocket/client.go
+++ b/lib/websocket/client.go
@@ -7,7 +7,6 @@ package websocket
import (
"bufio"
"bytes"
- "context"
"crypto/tls"
"fmt"
"log"
@@ -311,10 +310,27 @@ func (cl *Client) handshake() (err error) {
}
bb.WriteString("\r\n")
+ req := bb.Bytes()
- ctx := context.WithValue(context.Background(), ctxKeyWSAccept, keyAccept)
+ if debug.Value >= 3 {
+ fmt.Printf("websocket: Client.handshake:\n%s\n--\n", req)
+ }
- return cl.sendWithHandler(ctx, bb.Bytes(), cl.handleHandshake)
+ return cl.doHandshake(keyAccept, req)
+}
+
+func (cl *Client) doHandshake(keyAccept string, req []byte) (err error) {
+ err = cl.send(req)
+ if err != nil {
+ return err
+ }
+
+ resp, err := cl.recv()
+ if err != nil {
+ return err
+ }
+
+ return cl.handleHandshake(keyAccept, resp)
}
//
@@ -504,7 +520,7 @@ func (cl *Client) handleFrame(frame *Frame) (isClosing bool) {
return isClosing
}
-func (cl *Client) handleHandshake(ctx context.Context, resp []byte) (err error) {
+func (cl *Client) handleHandshake(keyAccept string, resp []byte) (err error) {
if debug.Value >= 3 {
max := 512
if len(resp) < 512 {
@@ -523,16 +539,12 @@ func (cl *Client) handleHandshake(ctx context.Context, resp []byte) (err error)
httpRes.Body.Close()
if httpRes.StatusCode != http.StatusSwitchingProtocols {
- fmt.Printf("websocket: Client.handleHandshake: status code: %d\n", httpRes.StatusCode)
- err = fmt.Errorf(httpRes.Status)
- return err
+ return fmt.Errorf(httpRes.Status)
}
- expAccept := ctx.Value(ctxKeyWSAccept)
gotAccept := httpRes.Header.Get(_hdrKeyWSAccept)
- if expAccept != gotAccept {
- err = fmt.Errorf("websocket: client.handleHandshake: invalid server accept key")
- return err
+ if keyAccept != gotAccept {
+ return fmt.Errorf("invalid server accept key")
}
return nil
@@ -739,34 +751,3 @@ func (cl *Client) send(packet []byte) (err error) {
return nil
}
-
-//
-// sendWithHandler send message to server, read the response, and pass it to
-// handler.
-//
-func (cl *Client) sendWithHandler(ctx context.Context, req []byte, handleRaw clientRawHandler) (err error) {
- if cl.conn == nil {
- return ErrConnClosed
- }
-
- err = cl.conn.SetWriteDeadline(time.Now().Add(defaultTimeout))
- if err != nil {
- return err
- }
-
- _, err = cl.conn.Write(req)
- if err != nil {
- return err
- }
-
- if handleRaw != nil {
- var resp []byte
- resp, err = cl.recv()
- if err != nil {
- return err
- }
- return handleRaw(ctx, resp)
- }
-
- return nil
-}
diff --git a/lib/websocket/contextkey.go b/lib/websocket/contextkey.go
index f90a254f..93e020d5 100644
--- a/lib/websocket/contextkey.go
+++ b/lib/websocket/contextkey.go
@@ -12,7 +12,4 @@ const (
CtxKeyExternalJWT ContextKey = 1 << iota
CtxKeyInternalJWT
CtxKeyUID
-
- // Internal context keys used by client.
- ctxKeyWSAccept // ctxKeyWSAccept context key for WebSocket accept key.
)
diff --git a/lib/websocket/handler.go b/lib/websocket/handler.go
index ec2356a1..d87cf4cd 100644
--- a/lib/websocket/handler.go
+++ b/lib/websocket/handler.go
@@ -9,12 +9,6 @@ import (
)
//
-// clientRawHandler define a callback type for handling raw packet from
-// send().
-//
-type clientRawHandler func(ctx context.Context, resp []byte) (err error)
-
-//
// ClientHandler define a callback type for client to handle packet from
// server (either broadcast or from response of request) in the form of frame.
//
diff --git a/lib/websocket/server_test.go b/lib/websocket/server_test.go
index 2d8a91ef..8df648f4 100644
--- a/lib/websocket/server_test.go
+++ b/lib/websocket/server_test.go
@@ -5,9 +5,7 @@
package websocket
import (
- "bufio"
"bytes"
- "context"
"fmt"
"net/http"
"net/url"
@@ -65,11 +63,10 @@ func TestServerHandshake(t *testing.T) {
}
cases := []struct {
- desc string
- req *http.Request
- query url.Values
- expKey string
- expRespCode int
+ desc string
+ req *http.Request
+ query url.Values
+ expError string
}{{
desc: "With valid request and authorization",
req: &http.Request{
@@ -86,8 +83,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expKey: _testHdrValWSAccept,
- expRespCode: http.StatusSwitchingProtocols,
+ expError: "invalid server accept key",
}, {
desc: "Without GET",
req: &http.Request{
@@ -104,7 +100,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 invalid HTTP method",
}, {
desc: "Without HTTP header Host",
req: &http.Request{
@@ -119,7 +115,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 bad request: header length is less than minimum",
}, {
desc: "Without HTTP header Connection",
req: &http.Request{
@@ -135,7 +131,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 bad request: header length is less than minimum",
}, {
desc: "With invalid HTTP header Connection",
req: &http.Request{
@@ -152,7 +148,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 invalid Connection header",
}, {
desc: "Without HTTP header Upgrade",
req: &http.Request{
@@ -168,7 +164,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 bad request: header length is less than minimum",
}, {
desc: "Without HTTP header 'Sec-Websocket-Key'",
req: &http.Request{
@@ -184,7 +180,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 bad request: header length is less than minimum",
}, {
desc: "Without HTTP header 'Sec-Websocket-Version'",
req: &http.Request{
@@ -200,7 +196,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 bad request: header length is less than minimum",
}, {
desc: "With unsupported websocket version",
req: &http.Request{
@@ -217,7 +213,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
_qKeyTicket: []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 unsupported Sec-WebSocket-Version",
}, {
desc: "Without authorization",
req: &http.Request{
@@ -231,7 +227,7 @@ func TestServerHandshake(t *testing.T) {
_hdrKeyWSVersion: []string{_hdrValWSVersion},
},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 Missing authorization",
}, {
desc: "Without invalid HTTP header 'Authorization'",
req: &http.Request{
@@ -248,7 +244,7 @@ func TestServerHandshake(t *testing.T) {
query: url.Values{
"Basic": []string{_testExternalJWT},
},
- expRespCode: http.StatusBadRequest,
+ expError: "400 Missing authorization",
}}
var bb bytes.Buffer
@@ -270,24 +266,9 @@ func TestServerHandshake(t *testing.T) {
fmt.Fprintf(&bb, "\r\n")
- c := c
- handleHandshake := func(ctx context.Context, resp []byte) (err error) {
- httpBuf := bufio.NewReader(bytes.NewBuffer(resp))
-
- httpRes, err := http.ReadResponse(httpBuf, nil)
- if err != nil {
- t.Fatal(err)
- return
- }
-
- test.Assert(t, "expRespCode", c.expRespCode, httpRes.StatusCode, true)
-
- return
- }
-
- err = cl.sendWithHandler(context.Background(), bb.Bytes(), handleHandshake)
+ err = cl.doHandshake("", bb.Bytes())
if err != nil {
- t.Fatal(err)
+ test.Assert(t, "error", c.expError, err.Error(), true)
}
}
}
diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go
index 86815518..85d816c2 100644
--- a/lib/websocket/websocket_test.go
+++ b/lib/websocket/websocket_test.go
@@ -17,15 +17,14 @@ import (
var (
_testExternalJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MzA0NjU2MDYsImhhc2giOiJiYmJiYmJiYi1iYmJiLWJiYmItYmJiYi1iYmJiYmJiYmJiYmIiLCJpYXQiOjE1MzAyMDY0MDYsIm5hZiI6MTUzMjc5ODQwNn0.15quj_gkeo9cWkLN98_2rXjtjihQym16Kn_9BQjYC14" //nolint: lll, gochecknoglobals
- _testEndpointAuth string //nolint: gochecknoglobals
- _testInternalJWT = _testExternalJWT //nolint: gochecknoglobals
- _testUID = uint64(100) //nolint: gochecknoglobals
- _testPort = 9001 //nolint: gochecknoglobals
- _testServer *Server //nolint: gochecknoglobals
- _testWSAddr string //nolint: gochecknoglobals
- _testHdrValWSAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" //nolint: gochecknoglobals
- _testHdrValWSKey = "dGhlIHNhbXBsZSBub25jZQ==" //nolint: gochecknoglobals
- _testMaskKey = []byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals
+ _testEndpointAuth string //nolint: gochecknoglobals
+ _testInternalJWT = _testExternalJWT //nolint: gochecknoglobals
+ _testUID = uint64(100) //nolint: gochecknoglobals
+ _testPort = 9001 //nolint: gochecknoglobals
+ _testServer *Server //nolint: gochecknoglobals
+ _testWSAddr string //nolint: gochecknoglobals
+ _testHdrValWSKey = "dGhlIHNhbXBsZSBub25jZQ==" //nolint: gochecknoglobals
+ _testMaskKey = []byte{'7', 'ú', '!', '='} //nolint: gochecknoglobals
)
var (