diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/websocket/client.go | 29 | ||||
| -rw-r--r-- | lib/websocket/client_test.go | 51 |
2 files changed, 68 insertions, 12 deletions
diff --git a/lib/websocket/client.go b/lib/websocket/client.go index f19dcefe..455a4b78 100644 --- a/lib/websocket/client.go +++ b/lib/websocket/client.go @@ -28,6 +28,11 @@ const ( "Connection: Upgrade\r\n" + "Sec-Websocket-Key: %s\r\n" + "Sec-Websocket-Version: 13\r\n" + + schemeWSS = "wss" + schemeHTTPS = "https" + defTLSPort = "443" + defPort = "80" ) var ( @@ -295,10 +300,10 @@ func (cl *Client) init() (err error) { } // -// parseURI parse WebSocket connection URI from "endpoint" and get the remote +// parseURI parse WebSocket connection URI from Endpoint and get the remote // URL (for checking up scheme) and remote address. -// By default, if no port is given, it will set to 80 for URL with any scheme -// or 443 for "wss" scheme. +// By default, if no port is given, it will be set to 80 or 443 for "wss" or +// "https" scheme. // // On success it will set the remote address that can be used on open(). // On fail it will return an error. @@ -310,24 +315,24 @@ func (cl *Client) parseURI() (err error) { return err } + serverAddress := cl.remoteURL.Hostname() serverPort := cl.remoteURL.Port() - if len(serverPort) != 0 { - cl.remoteAddr = cl.remoteURL.Host - return nil - } - switch cl.remoteURL.Scheme { - case "wss": - serverPort = "443" + case schemeWSS, schemeHTTPS: + if len(serverPort) == 0 { + serverPort = defTLSPort + } if cl.TLSConfig == nil { cl.TLSConfig = &tls.Config{} } default: - serverPort = "80" + if len(serverPort) == 0 { + serverPort = defPort + } } - cl.remoteAddr = cl.remoteURL.Hostname() + ":" + serverPort + cl.remoteAddr = serverAddress + ":" + serverPort return nil } diff --git a/lib/websocket/client_test.go b/lib/websocket/client_test.go index 9aae9cfc..466e0932 100644 --- a/lib/websocket/client_test.go +++ b/lib/websocket/client_test.go @@ -5,6 +5,7 @@ package websocket import ( + "crypto/tls" "net/http" "sync" "testing" @@ -60,6 +61,56 @@ func TestConnect(t *testing.T) { } } +func TestClient_parseURI(t *testing.T) { + cl := &Client{} + + cases := []struct { + endpoint string + expRemoteAddress string + expTLSConfig *tls.Config + expError string + }{{ + endpoint: "ws://127.0.0.1:8080", + expRemoteAddress: "127.0.0.1:8080", + }, { + endpoint: "wss://127.0.0.1", + expRemoteAddress: "127.0.0.1:443", + expTLSConfig: new(tls.Config), + }, { + endpoint: "wss://127.0.0.1:8000", + expRemoteAddress: "127.0.0.1:8000", + expTLSConfig: new(tls.Config), + }, { + endpoint: "http://127.0.0.1", + expRemoteAddress: "127.0.0.1:80", + }, { + endpoint: "https://127.0.0.1", + expRemoteAddress: "127.0.0.1:443", + expTLSConfig: new(tls.Config), + }, { + endpoint: "https://127.0.0.1:8443", + expRemoteAddress: "127.0.0.1:8443", + expTLSConfig: new(tls.Config), + }} + + for _, c := range cases { + t.Log("parseURI", c.endpoint) + + cl.remoteAddr = "" + cl.TLSConfig = nil + cl.Endpoint = c.endpoint + + err := cl.parseURI() + if err != nil { + test.Assert(t, "error", c.expError, err.Error(), true) + continue + } + + test.Assert(t, "remote address", c.expRemoteAddress, cl.remoteAddr, true) + test.Assert(t, "TLS config", c.expTLSConfig, cl.TLSConfig, true) + } +} + func TestClientPing(t *testing.T) { if _testServer == nil { runTestServer() |
