summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2022-10-08 19:24:17 +0700
committerShulhan <ms@kilabit.info>2022-10-10 00:48:56 +0700
commit1edecf8f5a09a175cb30bb9939edd8f4497de9a1 (patch)
tree8741be0b10494b5da5b80ea7de12b9aba438f59b
parent76df3428683054771930983681721318abfd80d8 (diff)
downloadpakakeh.go-1edecf8f5a09a175cb30bb9939edd8f4497de9a1.tar.xz
lib/websocket: fix possible data race on Client
The Client have method send that check if the underlying connection (conn) has been closed or not. Since the conn can be closed anytime, for example server send to the control CLOSE frame: recv -> handleFrame -> handleClose -> Quit we need to guard the conn with Mutex before calling send to prevent data race.
-rw-r--r--lib/websocket/client.go74
-rw-r--r--lib/websocket/client_test.go541
-rw-r--r--lib/websocket/websocket_test.go2
3 files changed, 289 insertions, 328 deletions
diff --git a/lib/websocket/client.go b/lib/websocket/client.go
index b19c5843..9e9b707c 100644
--- a/lib/websocket/client.go
+++ b/lib/websocket/client.go
@@ -178,35 +178,38 @@ func (cl *Client) Close() (err error) {
}
var (
+ logp = `websocket: Close`
packet []byte = NewFrameClose(true, StatusNormal, nil)
- timer *time.Timer
+
+ timer *time.Timer
+ wait bool
)
cl.gracefulClose = make(chan bool, 1)
err = cl.send(packet)
if err != nil {
- return fmt.Errorf("websocket: Close: %w", err)
+ return fmt.Errorf(`%s: %w`, logp, err)
}
// Wait for server to response with CLOSE.
timer = time.NewTimer(defaultTimeout)
-loop:
- for {
+ wait = true
+ for wait {
select {
case <-timer.C:
// We did not receive server CLOSE frame in timely
// manner.
- break loop
+ wait = false
case <-cl.gracefulClose:
timer.Stop()
- break loop
+ wait = false
}
}
err = cl.conn.Close()
if err != nil {
- err = fmt.Errorf("websocket: Close: %w", err)
+ err = fmt.Errorf(`%s: %w`, logp, err)
}
cl.conn = nil
@@ -215,6 +218,9 @@ loop:
// Connect to endpoint.
func (cl *Client) Connect() (err error) {
+ cl.Lock()
+ defer cl.Unlock()
+
err = cl.init()
if err != nil {
return fmt.Errorf("websocket: Connect: " + err.Error())
@@ -353,8 +359,7 @@ func (cl *Client) open() (err error) {
}
if cl.TLSConfig != nil {
- cl.conn, err = tls.DialWithDialer(dialer, "tcp",
- cl.remoteAddr, cl.TLSConfig)
+ cl.conn, err = tls.DialWithDialer(dialer, "tcp", cl.remoteAddr, cl.TLSConfig)
} else {
cl.conn, err = dialer.Dial("tcp", cl.remoteAddr)
}
@@ -493,7 +498,9 @@ func clientOnClose(cl *Client, frame *Frame) (err error) {
fmt.Printf("websocket: clientOnClose: payload: %s\n", frame.payload)
}
+ cl.Lock()
err = cl.send(packet)
+ cl.Unlock()
if err != nil {
log.Println("websocket: clientOnClose: send: " + err.Error())
}
@@ -686,35 +693,50 @@ func (cl *Client) handleRaw(packet []byte) (isClosing bool) {
// SendBin send data frame as binary to server.
// If handler is nil, no response will be read from server.
-func (cl *Client) SendBin(payload []byte) error {
+func (cl *Client) SendBin(payload []byte) (err error) {
+ cl.Lock()
var packet []byte = NewFrameBin(true, payload)
- return cl.send(packet)
+ err = cl.send(packet)
+ cl.Unlock()
+ return err
}
// sendClose send the control CLOSE frame to server with optional payload.
func (cl *Client) sendClose(status CloseCode, payload []byte) (err error) {
+ cl.Lock()
var packet []byte = NewFrameClose(true, status, payload)
- return cl.send(packet)
+ err = cl.send(packet)
+ cl.Unlock()
+ return err
}
// SendPing send control PING frame to server, expecting PONG as response.
-func (cl *Client) SendPing(payload []byte) error {
+func (cl *Client) SendPing(payload []byte) (err error) {
+ cl.Lock()
var packet []byte = NewFramePing(true, payload)
- return cl.send(packet)
+ err = cl.send(packet)
+ cl.Unlock()
+ return err
}
// SendPong send the control frame PONG to server, by using payload from PING
// frame.
-func (cl *Client) SendPong(payload []byte) error {
+func (cl *Client) SendPong(payload []byte) (err error) {
+ cl.Lock()
var packet []byte = NewFramePong(true, payload)
- return cl.send(packet)
+ err = cl.send(packet)
+ cl.Unlock()
+ return err
}
// SendText send data frame as text to server.
// If handler is nil, no response will be read from server.
func (cl *Client) SendText(payload []byte) (err error) {
+ cl.Lock()
var packet []byte = NewFrameText(true, payload)
- return cl.send(packet)
+ err = cl.send(packet)
+ cl.Unlock()
+ return err
}
// serve read one data frame at a time from server and propagated to handler.
@@ -731,19 +753,20 @@ func (cl *Client) serve() {
isClosing bool
)
- for {
+ for !isClosing {
packet, err = cl.recv()
if err != nil {
log.Println("websocket: Client.serve: " + err.Error())
- break
+ isClosing = true
+ continue
}
if len(packet) == 0 {
// Empty packet may indicated that server has closed
// the connection abnormally.
log.Println("websocket: Client.serve: empty packet received, closing")
- break
+ isClosing = true
+ continue
}
-
if cl.frame != nil {
packet = cl.frame.unpack(packet)
if cl.frame.isComplete {
@@ -751,18 +774,14 @@ func (cl *Client) serve() {
cl.frame = nil
isClosing = cl.handleFrame(frame)
if isClosing {
- return
+ continue
}
}
if len(packet) == 0 {
continue
}
}
-
isClosing = cl.handleRaw(packet)
- if isClosing {
- return
- }
}
cl.Quit()
}
@@ -876,7 +895,8 @@ func (cl *Client) send(packet []byte) (err error) {
// pinger send the PING control frame every 10 seconds.
func (cl *Client) pinger() {
var (
- t *time.Ticker = time.NewTicker(cl.PingInterval)
+ t *time.Ticker = time.NewTicker(cl.PingInterval)
+
err error
)
diff --git a/lib/websocket/client_test.go b/lib/websocket/client_test.go
index 36a512f9..b851b260 100644
--- a/lib/websocket/client_test.go
+++ b/lib/websocket/client_test.go
@@ -7,8 +7,6 @@ package websocket
import (
"crypto/tls"
"net/http"
- "strings"
- "sync"
"testing"
"github.com/shuLhan/share/lib/test"
@@ -127,35 +125,19 @@ func TestClient_parseURI(t *testing.T) {
func TestClientPing(t *testing.T) {
type testCase struct {
- exp *Frame
- expClose *Frame
- desc string
- req []byte
- reconnect bool
+ exp *Frame
+ desc string
+ req []byte
}
if _testServer == nil {
runTestServer()
}
- var (
- testClient = &Client{
- Endpoint: _testEndpointAuth,
- }
-
- wg sync.WaitGroup
- err error
- )
-
- err = testClient.Connect()
- if err != nil {
- t.Fatal("TestClientPing: " + err.Error())
- }
-
var cases = []testCase{{
- desc: "Without payload, unmasked",
+ desc: `Without payload, unmasked`,
req: NewFramePing(false, nil),
- expClose: &Frame{
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -164,10 +146,9 @@ func TestClientPing(t *testing.T) {
isComplete: true,
},
}, {
- desc: "With payload, unmasked",
- reconnect: true,
- req: NewFramePing(false, []byte("Hello")),
- expClose: &Frame{
+ desc: `With payload, unmasked`,
+ req: NewFramePing(false, []byte(`Hello`)),
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -176,102 +157,79 @@ func TestClientPing(t *testing.T) {
isComplete: true,
},
}, {
- desc: "With payload, masked",
- reconnect: true,
- req: NewFramePing(true, []byte("Hello")),
+ desc: `With payload, masked`,
+ req: NewFramePing(true, []byte(`Hello`)),
exp: &Frame{
fin: frameIsFinished,
opcode: OpcodePong,
len: 5,
- payload: []byte("Hello"),
+ payload: []byte(`Hello`),
isComplete: true,
},
}}
var (
- c testCase
+ gotFrame = make(chan *Frame)
+
+ cl *Client
+ got *Frame
+ c testCase
+ err error
)
for _, c = range cases {
- var c testCase = c
t.Log(c.desc)
- if c.reconnect {
- err = testClient.Connect()
- if err != nil {
- t.Fatal(err)
- }
- }
-
- testClient.handleClose = func(cl *Client, got *Frame) error {
- var exp *Frame = c.expClose
-
- test.Assert(t, "close", exp, got)
-
- if len(got.payload) >= 2 {
- got.payload = got.payload[2:]
- }
-
- cl.sendClose(got.closeCode, got.payload)
- cl.Quit()
- wg.Done()
- return nil
+ cl = &Client{
+ Endpoint: _testEndpointAuth,
+ handleClose: func(cl *Client, f *Frame) error {
+ cl.sendClose(f.closeCode, nil)
+ cl.Quit()
+ gotFrame <- f
+ return nil
+ },
+ handlePong: func(cl *Client, f *Frame) (err error) {
+ gotFrame <- f
+ return nil
+ },
}
- testClient.handlePong = func(cl *Client, got *Frame) (err error) {
- var exp *Frame = c.exp
-
- test.Assert(t, "handlePong", exp, got)
-
- wg.Done()
- return nil
+ err = cl.Connect()
+ if err != nil {
+ t.Fatal(err)
}
- wg.Add(1)
- testClient.Lock()
- err = testClient.send(c.req)
- testClient.Unlock()
+ cl.Lock()
+ err = cl.send(c.req)
+ cl.Unlock()
if err != nil {
t.Fatal(err)
}
- wg.Wait()
- }
+ got = <-gotFrame
+ test.Assert(t, `response`, c.exp, got)
- testClient.Quit()
+ if got.opcode != OpcodeClose {
+ cl.Close()
+ }
+ }
}
-func TestClientText(t *testing.T) {
+func TestClient_send_FrameText(t *testing.T) {
type testCase struct {
- exp *Frame
- expClose *Frame
- desc string
- req []byte
- reconnect bool
+ exp *Frame
+ desc string
+ req []byte
}
if _testServer == nil {
runTestServer()
}
- var (
- testClient = &Client{
- Endpoint: _testEndpointAuth,
- }
-
- wg sync.WaitGroup
- err error
- )
-
- err = testClient.Connect()
- if err != nil {
- t.Fatal("TestClientText: " + err.Error())
- }
-
var cases = []testCase{{
- desc: "Small payload, unmasked",
- req: NewFrameText(false, []byte("Hello")),
- expClose: &Frame{
+ desc: `Small payload, unmasked`,
+ req: NewFrameText(false, []byte(`Hello`)),
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -280,20 +238,19 @@ func TestClientText(t *testing.T) {
isComplete: true,
},
}, {
- desc: "Small payload, masked",
- reconnect: true,
- req: NewFrameText(true, []byte("Hello")),
+ desc: `Small payload, masked`,
+ req: NewFrameText(true, []byte(`Hello`)),
exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeText,
len: 5,
- payload: []byte("Hello"),
+ payload: []byte(`Hello`),
isComplete: true,
},
}, {
- desc: "Medium payload 256, unmasked",
+ desc: `Medium payload 256, unmasked`,
req: NewFrameText(false, _dummyPayload256),
- expClose: &Frame{
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -302,9 +259,8 @@ func TestClientText(t *testing.T) {
isComplete: true,
},
}, {
- desc: "Medium payload 256, masked",
- reconnect: true,
- req: NewFrameText(true, _dummyPayload256),
+ desc: `Medium payload 256, masked`,
+ req: NewFrameText(true, _dummyPayload256),
exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeText,
@@ -313,9 +269,9 @@ func TestClientText(t *testing.T) {
isComplete: true,
},
}, {
- desc: "Large payload 65536, unmasked",
+ desc: `Large payload 65536, unmasked`,
req: NewFrameText(false, _dummyPayload65536),
- expClose: &Frame{
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -324,9 +280,8 @@ func TestClientText(t *testing.T) {
isComplete: true,
},
}, {
- desc: "Large payload 65536, masked",
- reconnect: true,
- req: NewFrameText(true, _dummyPayload65536),
+ desc: `Large payload 65536, masked`,
+ req: NewFrameText(true, _dummyPayload65536),
exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeText,
@@ -337,62 +292,63 @@ func TestClientText(t *testing.T) {
}}
var (
- c testCase
+ gotFrame = make(chan *Frame)
+
+ cl *Client
+ got *Frame
+ c testCase
+ err error
)
for _, c = range cases {
- var c testCase = c
t.Log(c.desc)
- if c.reconnect {
- err = testClient.Connect()
- if err != nil {
- t.Fatal(err)
- }
- }
-
- testClient.handleClose = func(cl *Client, got *Frame) error {
- var exp *Frame = c.expClose
- test.Assert(t, "close", exp, got)
- cl.sendClose(got.closeCode, got.payload)
- cl.Quit()
- wg.Done()
- return nil
+ cl = &Client{
+ Endpoint: _testEndpointAuth,
+ handleClose: func(cl *Client, f *Frame) error {
+ cl.sendClose(f.closeCode, nil)
+ cl.Quit()
+ gotFrame <- f
+ return nil
+ },
+ HandleText: func(cl *Client, f *Frame) error {
+ gotFrame <- f
+ return nil
+ },
}
- testClient.HandleText = func(cl *Client, got *Frame) error {
- var exp *Frame = c.exp
- test.Assert(t, "text", exp, got)
- wg.Done()
- return nil
+ err = cl.Connect()
+ if err != nil {
+ t.Fatal(err.Error())
}
- wg.Add(1)
- err = testClient.send(c.req)
+ cl.Lock()
+ err = cl.send(c.req)
if err != nil {
t.Fatal(err)
}
+ cl.Unlock()
+
+ got = <-gotFrame
+ test.Assert(t, `response`, c.exp, got)
- wg.Wait()
+ if got.opcode != OpcodeClose {
+ cl.Close()
+ }
}
}
func TestClientFragmentation(t *testing.T) {
type testCase struct {
- exp *Frame
- expClose *Frame
- desc string
- frames []Frame
+ exp *Frame
+ desc string
+ frames []Frame
}
if _testServer == nil {
runTestServer()
}
- var (
- wg sync.WaitGroup
- )
-
var cases = []testCase{{
desc: "Two text frames, unmasked",
frames: []Frame{{
@@ -404,7 +360,7 @@ func TestClientFragmentation(t *testing.T) {
opcode: OpcodeCont,
payload: []byte{'l', 'o'},
}},
- expClose: &Frame{
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -427,7 +383,7 @@ func TestClientFragmentation(t *testing.T) {
opcode: OpcodeCont,
payload: []byte("Shulhan"),
}},
- expClose: &Frame{
+ exp: &Frame{
fin: frameIsFinished,
opcode: OpcodeClose,
closeCode: StatusBadRequest,
@@ -463,76 +419,90 @@ func TestClientFragmentation(t *testing.T) {
}}
var (
- c testCase
- testClient *Client
- err error
- req []byte
- x int
- brokenPipe bool
+ gotFrame = make(chan *Frame)
+
+ cl *Client
+ frame Frame
+ got *Frame
+ c testCase
+ err error
+ x int
+ req []byte
)
for _, c = range cases {
- testClient = &Client{
+ t.Log(c.desc)
+
+ cl = &Client{
Endpoint: _testEndpointAuth,
+ handleClose: func(cl *Client, f *Frame) error {
+ cl.sendClose(f.closeCode, nil)
+ cl.Quit()
+ gotFrame <- f
+ return nil
+ },
+ HandleText: func(cl *Client, f *Frame) error {
+ gotFrame <- f
+ return nil
+ },
}
- err = testClient.Connect()
+ err = cl.Connect()
if err != nil {
t.Fatal(err)
}
- testClient.handleClose = func(desc string, exp *Frame) ClientHandler {
- return func(cl *Client, got *Frame) (err error) {
- test.Assert(t, desc+": close", exp, got)
- cl.sendClose(got.closeCode, got.payload)
- cl.Quit()
- wg.Done()
- return nil
- }
- }(c.desc, c.expClose)
-
- testClient.HandleText = func(desc string, exp *Frame) ClientHandler {
- return func(cl *Client, got *Frame) error {
- test.Assert(t, desc+": text", exp, got)
- wg.Done()
- return nil
- }
- }(c.desc, c.exp)
-
- wg.Add(1)
- for x = 0; x < len(c.frames); x++ {
- req = c.frames[x].pack()
+ for x, frame = range c.frames {
+ req = frame.pack()
- testClient.Lock()
- err = testClient.send(req)
- testClient.Unlock()
+ cl.Lock()
+ err = cl.send(req)
+ cl.Unlock()
if err != nil {
- // If the client send unmasked frame, the
+ // If the client send unmasked frame,
// server may close the connection before we
// can test send the second frame.
- brokenPipe = strings.Contains(err.Error(), "write: broken pipe")
- if !brokenPipe {
- t.Fatalf("expecting broken pipe, got %s", err)
- }
- break
+ t.Logf(`send frame %d: %s`, x, err)
}
}
- wg.Wait()
+
+ got = <-gotFrame
+ test.Assert(t, `response`, c.exp, got)
+
+ if got.opcode != OpcodeClose {
+ cl.Close()
+ }
}
}
+// TestClientFragmentation2 We are sending two requests, first request split
+// into 3 frames in between second request (PING):
+//
+// F1->F2->PING->F3
func TestClientFragmentation2(t *testing.T) {
if _testServer == nil {
runTestServer()
}
var (
+ gotFrame = make(chan *Frame)
testClient = &Client{
Endpoint: _testEndpointAuth,
+ handlePong: func(cl *Client, f *Frame) error {
+ gotFrame <- f
+ return nil
+ },
+ HandleText: func(cl *Client, f *Frame) error {
+ gotFrame <- f
+ return nil
+ },
}
- wg sync.WaitGroup
+ exp *Frame
+ got *Frame
err error
+ x int
+ req []byte
)
err = testClient.Connect()
@@ -562,39 +532,6 @@ func TestClientFragmentation2(t *testing.T) {
payload: []byte("Shulhan"),
}}
- testClient.handlePong = func(cl *Client, got *Frame) error {
- var exp = &Frame{
- fin: frameIsFinished,
- opcode: OpcodePong,
- len: 4,
- payload: []byte("PING"),
- isComplete: true,
- }
- test.Assert(t, "handlePong", exp, got)
- wg.Done()
- return nil
- }
-
- testClient.HandleText = func(cl *Client, got *Frame) error {
- var exp = &Frame{
- fin: frameIsFinished,
- opcode: OpcodeText,
- len: 14,
- payload: []byte("Hello, Shulhan"),
- isComplete: true,
- }
- test.Assert(t, "handlePong", exp, got)
- wg.Done()
- return nil
- }
-
- wg.Add(2)
-
- var (
- x int
- req []byte
- )
-
for x = 0; x < len(frames); x++ {
req = frames[x].pack()
@@ -606,36 +543,39 @@ func TestClientFragmentation2(t *testing.T) {
}
}
- wg.Wait()
+ // The first response should be PONG.
+ exp = &Frame{
+ fin: frameIsFinished,
+ opcode: OpcodePong,
+ len: 4,
+ payload: []byte(`PING`),
+ isComplete: true,
+ }
+ got = <-gotFrame
+ test.Assert(t, `response PONG`, exp, got)
+
+ exp = &Frame{
+ fin: frameIsFinished,
+ opcode: OpcodeText,
+ len: 14,
+ payload: []byte(`Hello, Shulhan`),
+ isComplete: true,
+ }
+ got = <-gotFrame
+ test.Assert(t, `response TEXT`, exp, got)
}
func TestClientSendBin(t *testing.T) {
type testCase struct {
- exp *Frame
- desc string
- payload []byte
- reconnect bool
+ exp *Frame
+ desc string
+ payload []byte
}
if _testServer == nil {
runTestServer()
}
- var (
- testClient = &Client{
- Endpoint: _testEndpointAuth,
- }
-
- c testCase
- wg sync.WaitGroup
- err error
- )
-
- err = testClient.Connect()
- if err != nil {
- t.Fatal("TestSendBin: Connect: " + err.Error())
- }
-
var cases = []testCase{{
desc: "Single bin frame",
payload: []byte("Hello"),
@@ -648,32 +588,40 @@ func TestClientSendBin(t *testing.T) {
},
}}
- for _, c = range cases {
- var cc testCase = c
+ var (
+ gotFrame = make(chan *Frame)
- t.Log(cc.desc)
+ cl *Client
+ got *Frame
+ c testCase
+ err error
+ )
- if cc.reconnect {
- err = testClient.Connect()
- if err != nil {
- t.Fatal(err)
- }
+ for _, c = range cases {
+ t.Log(c.desc)
+
+ cl = &Client{
+ Endpoint: _testEndpointAuth,
+ HandleBin: func(cl *Client, f *Frame) error {
+ gotFrame <- f
+ return nil
+ },
}
- testClient.HandleBin = func(cl *Client, got *Frame) error {
- var exp *Frame = cc.exp
- test.Assert(t, "HandleBin", exp, got)
- wg.Done()
- return nil
+ err = cl.Connect()
+ if err != nil {
+ t.Fatal(err)
}
- wg.Add(1)
- err = testClient.SendBin(cc.payload)
+ err = cl.SendBin(c.payload)
if err != nil {
- t.Fatal("TestSendBin: " + err.Error())
+ t.Fatal(err.Error())
}
- wg.Wait()
+ got = <-gotFrame
+ test.Assert(t, `response`, c.exp, got)
+
+ cl.Close()
}
}
@@ -688,20 +636,6 @@ func TestClientSendPing(t *testing.T) {
runTestServer()
}
- var (
- testClient = &Client{
- Endpoint: _testEndpointAuth,
- }
-
- wg sync.WaitGroup
- err error
- )
-
- err = testClient.Connect()
- if err != nil {
- t.Fatal("TestSendBin: Connect: " + err.Error())
- }
-
var cases = []testCase{{
desc: "Without payload",
exp: &Frame{
@@ -723,27 +657,35 @@ func TestClientSendPing(t *testing.T) {
}}
var (
- c testCase
+ gotFrame = make(chan *Frame)
+ testClient = &Client{
+ Endpoint: _testEndpointAuth,
+ handlePong: func(cl *Client, f *Frame) error {
+ gotFrame <- f
+ return nil
+ },
+ }
+
+ got *Frame
+ err error
+ c testCase
)
- for _, c = range cases {
- var cc testCase = c
- t.Log(cc.desc)
+ err = testClient.Connect()
+ if err != nil {
+ t.Fatal(err.Error())
+ }
- testClient.handlePong = func(cl *Client, got *Frame) error {
- var exp *Frame = cc.exp
- test.Assert(t, "handlePong", exp, got)
- wg.Done()
- return nil
- }
+ for _, c = range cases {
+ t.Log(c.desc)
- wg.Add(1)
- err = testClient.SendPing(cc.payload)
+ err = testClient.SendPing(c.payload)
if err != nil {
- t.Fatal("TestSendPing: " + err.Error())
+ t.Fatal(err.Error())
}
- wg.Wait()
+ got = <-gotFrame
+ test.Assert(t, `response`, c.exp, got)
}
}
@@ -753,43 +695,42 @@ func TestClient_sendClose(t *testing.T) {
}
var (
- testClient = &Client{
+ gotFrame = make(chan *Frame)
+ cl = &Client{
Endpoint: _testEndpointAuth,
+ handleClose: func(cl *Client, f *Frame) error {
+ cl.sendClose(f.closeCode, nil)
+ cl.Quit()
+ gotFrame <- f
+ return nil
+ },
}
- wg sync.WaitGroup
+ got *Frame
err error
)
- err = testClient.Connect()
+ err = cl.Connect()
if err != nil {
t.Fatal("TestClient_sendClose: Connect: " + err.Error())
}
- testClient.handleClose = func(cl *Client, got *Frame) error {
- var exp = &Frame{
- fin: frameIsFinished,
- opcode: OpcodeClose,
- closeCode: StatusNormal,
- len: 8,
- payload: []byte{0x03, 0xE8, 'n', 'o', 'r', 'm', 'a', 'l'},
- isComplete: true,
- }
- test.Assert(t, "handleClose", exp, got)
- cl.Quit()
- wg.Done()
- return nil
- }
-
- wg.Add(1)
- err = testClient.sendClose(StatusNormal, []byte("normal"))
+ err = cl.sendClose(StatusNormal, []byte("normal"))
if err != nil {
t.Fatal("TestClient_sendClose: " + err.Error())
}
- wg.Wait()
-
- err = testClient.SendPing(nil)
+ got = <-gotFrame
+ var exp = &Frame{
+ fin: frameIsFinished,
+ opcode: OpcodeClose,
+ closeCode: StatusNormal,
+ len: 8,
+ payload: []byte{0x03, 0xE8, 'n', 'o', 'r', 'm', 'a', 'l'},
+ isComplete: true,
+ }
+ test.Assert(t, `sendClose response`, exp, got)
- test.Assert(t, "error", ErrConnClosed, err)
+ err = cl.SendPing(nil)
+ test.Assert(t, `SendPing should error`, ErrConnClosed, err)
}
diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go
index 3fcc4f73..d2fcc404 100644
--- a/lib/websocket/websocket_test.go
+++ b/lib/websocket/websocket_test.go
@@ -121,7 +121,7 @@ func runTestServer() {
}
}()
- time.Sleep(1 * time.Second)
+ time.Sleep(500 * time.Millisecond)
}
func TestMain(m *testing.M) {