diff options
| author | Shulhan <ms@kilabit.info> | 2023-11-25 00:04:54 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2023-11-26 21:32:52 +0700 |
| commit | 53cf77b77d36404cf720ec286bb7830b9dd6bc10 (patch) | |
| tree | b3b2af859b9c6a4fb0b620b1232fc5df68bcee61 | |
| parent | 93c7f9f18eafb68ff892a2ecb3de25d8a56d3e84 (diff) | |
| download | pakakeh.go-53cf77b77d36404cf720ec286bb7830b9dd6bc10.tar.xz | |
http/sseclient: implement Client retry
This changes everything.
On the server we split the SSEEndpoint to new type SSEConn, so each
callback use different instance of conn.
On the Client, we need to store the parsed serverUrl and the passed
header so it can be reused.
| -rw-r--r-- | lib/http/sse_conn.go | 149 | ||||
| -rw-r--r-- | lib/http/sse_endpoint.go | 163 | ||||
| -rw-r--r-- | lib/http/sse_endpoint_test.go | 3 | ||||
| -rw-r--r-- | lib/http/sseclient/sseclient.go | 149 | ||||
| -rw-r--r-- | lib/http/sseclient/sseclient_test.go | 163 |
5 files changed, 400 insertions, 227 deletions
diff --git a/lib/http/sse_conn.go b/lib/http/sse_conn.go new file mode 100644 index 00000000..ad14d337 --- /dev/null +++ b/lib/http/sse_conn.go @@ -0,0 +1,149 @@ +// Copyright 2023, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "fmt" + "net" + "net/http" + "strings" + "time" +) + +// SSECallback define the handler for Server-Sent Events (SSE). +// +// SSECallback type pass SSEConn that contains original HTTP request. +// This allow the server to check for header "Last-Event-ID" and/or for +// authentication. +// Remember that "the original Request.Body must not be used" according to +// [http.Hijacker] documentation. +type SSECallback func(sse *SSEConn) + +// SSEConn define the connection when the SSE request accepted by server. +type SSEConn struct { + HttpRequest *http.Request + + bufrw *bufio.ReadWriter + conn net.Conn +} + +// WriteEvent write message with event type to client. +// +// The event parameter must not be empty, otherwise it will not be sent. +// +// The msg parameter must not be empty, otherwise it will not be sent +// If msg contains new line character ('\n'), the message will be split into +// multiple "data:". +// +// The id parameter is optional. +// If its nil, it will be ignored. +// if its non-nil and empty, it will be send as empty ID. +// +// It will return an error if its failed to write to peer connection. +func (ep *SSEConn) WriteEvent(event, msg string, id *string) (err error) { + event = strings.TrimSpace(event) + if len(event) == 0 { + return nil + } + if len(msg) == 0 { + return nil + } + + var buf bytes.Buffer + + buf.WriteString(`event:`) + buf.WriteString(event) + buf.WriteByte('\n') + + ep.writeData(&buf, msg, id) + + _, err = ep.bufrw.Write(buf.Bytes()) + if err != nil { + return fmt.Errorf(`WriteMessage: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +// WriteMessage write a message with optional id to client. +// +// The msg parameter must not be empty, otherwise it will not be sent +// If msg contains new line character ('\n'), the message will be split into +// multiple "data:". +// +// The id parameter is optional. +// If its nil, it will be ignored. +// if its non-nil and empty, it will be send as empty ID. +// +// It will return an error if its failed to write to peer connection. +func (ep *SSEConn) WriteMessage(msg string, id *string) (err error) { + if len(msg) == 0 { + return nil + } + + var buf bytes.Buffer + + ep.writeData(&buf, msg, id) + + _, err = ep.bufrw.Write(buf.Bytes()) + if err != nil { + return fmt.Errorf(`WriteMessage: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +// WriteRaw write raw event message directly, without any parsing. +func (ep *SSEConn) WriteRaw(msg []byte) (err error) { + _, err = ep.bufrw.Write(msg) + if err != nil { + return fmt.Errorf(`WriteRaw: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +// WriteRetry inform user how long they should wait, after disconnect, +// before re-connecting back to server. +// +// The duration must be in millisecond. +func (ep *SSEConn) WriteRetry(retry time.Duration) (err error) { + _, err = fmt.Fprintf(ep.bufrw, "retry:%d\n\n", retry.Milliseconds()) + if err != nil { + return fmt.Errorf(`WriteRetry: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +func (ep *SSEConn) writeData(buf *bytes.Buffer, msg string, id *string) { + var ( + lines = strings.Split(msg, "\n") + line string + ) + for _, line = range lines { + buf.WriteString(`data:`) + buf.WriteString(line) + buf.WriteByte('\n') + } + if id != nil { + buf.WriteString(`id:`) + buf.WriteString(*id) + buf.WriteByte('\n') + } + buf.WriteByte('\n') +} + +// handshake write the last HTTP response to indicate the connection is +// accepted. +func (ep *SSEConn) handshake() { + ep.bufrw.WriteString("HTTP/1.1 200 OK\r\n") + ep.bufrw.WriteString("content-type: text/event-stream\r\n") + ep.bufrw.WriteString("cache-control: no-cache\r\n") + ep.bufrw.WriteString("\r\n") + ep.bufrw.Flush() +} diff --git a/lib/http/sse_endpoint.go b/lib/http/sse_endpoint.go index b453883a..b6804d8d 100644 --- a/lib/http/sse_endpoint.go +++ b/lib/http/sse_endpoint.go @@ -5,147 +5,22 @@ package http import ( - "bufio" - "bytes" "errors" - "fmt" - "net" "net/http" "net/url" - "strings" - "time" liberrors "github.com/shuLhan/share/lib/errors" ) -// SSECallback define the handler for Server-Sent Events (SSE). -// -// SSECallback type pass original HTTP request. -// This allow the server to check for header "Last-Event-ID" and/or for -// authentication. -// Remember that "the original Request.Body must not be used" according to -// [http.Hijacker] documentation. -type SSECallback func(sse *SSEEndpoint, req *http.Request) - // SSEEndpoint endpoint to create Server-Sent Events (SSE) on server. // // For creating the SSE client see subpackage [sseclient]. type SSEEndpoint struct { - bufrw *bufio.ReadWriter - conn net.Conn - - // Path where server accept the request for SSE. - Path string - // Call handler that will called when request to Path accepted. Call SSECallback -} - -// WriteEvent write message with event type to client. -// -// The event parameter must not be empty, otherwise it will not be sent. -// -// The msg parameter must not be empty, otherwise it will not be sent -// If msg contains new line character ('\n'), the message will be split into -// multiple "data:". -// -// The id parameter is optional. -// If its nil, it will be ignored. -// if its non-nil and empty, it will be send as empty ID. -// -// It will return an error if its failed to write to peer connection. -func (ep *SSEEndpoint) WriteEvent(event, msg string, id *string) (err error) { - event = strings.TrimSpace(event) - if len(event) == 0 { - return nil - } - if len(msg) == 0 { - return nil - } - - var buf bytes.Buffer - - buf.WriteString(`event:`) - buf.WriteString(event) - buf.WriteByte('\n') - - ep.writeData(&buf, msg, id) - - _, err = ep.bufrw.Write(buf.Bytes()) - if err != nil { - return fmt.Errorf(`WriteMessage: %w`, err) - } - ep.bufrw.Flush() - return nil -} - -// WriteMessage write a message with optional id to client. -// -// The msg parameter must not be empty, otherwise it will not be sent -// If msg contains new line character ('\n'), the message will be split into -// multiple "data:". -// -// The id parameter is optional. -// If its nil, it will be ignored. -// if its non-nil and empty, it will be send as empty ID. -// -// It will return an error if its failed to write to peer connection. -func (ep *SSEEndpoint) WriteMessage(msg string, id *string) (err error) { - if len(msg) == 0 { - return nil - } - - var buf bytes.Buffer - - ep.writeData(&buf, msg, id) - _, err = ep.bufrw.Write(buf.Bytes()) - if err != nil { - return fmt.Errorf(`WriteMessage: %w`, err) - } - ep.bufrw.Flush() - return nil -} - -// WriteRaw write raw event message directly, without any parsing. -func (ep *SSEEndpoint) WriteRaw(msg []byte) (err error) { - _, err = ep.bufrw.Write(msg) - if err != nil { - return fmt.Errorf(`WriteRaw: %w`, err) - } - ep.bufrw.Flush() - return nil -} - -// WriteRetry inform user how long they should wait, after disconnect, -// before re-connecting back to server. -// -// The duration must be in millisecond. -func (ep *SSEEndpoint) WriteRetry(retry time.Duration) (err error) { - _, err = fmt.Fprintf(ep.bufrw, "retry:%d\n\n", retry.Milliseconds()) - if err != nil { - return fmt.Errorf(`WriteRetry: %w`, err) - } - ep.bufrw.Flush() - return nil -} - -func (ep *SSEEndpoint) writeData(buf *bytes.Buffer, msg string, id *string) { - var ( - lines = strings.Split(msg, "\n") - line string - ) - for _, line = range lines { - buf.WriteString(`data:`) - buf.WriteString(line) - buf.WriteByte('\n') - } - if id != nil { - buf.WriteString(`id:`) - buf.WriteString(*id) - buf.WriteByte('\n') - } - buf.WriteByte('\n') + // Path where server accept the request for SSE. + Path string } func (ep *SSEEndpoint) call( @@ -180,15 +55,17 @@ func (ep *SSEEndpoint) call( return } - err = ep.hijack(res) + var sseconn *SSEConn + + sseconn, err = ep.hijack(res, req) if err != nil { http.Error(res, err.Error(), http.StatusInternalServerError) return } - ep.handshake() - ep.Call(ep, req) - ep.conn.Close() + sseconn.handshake() + ep.Call(sseconn) + sseconn.conn.Close() } func (ep *SSEEndpoint) doEvals( @@ -212,7 +89,7 @@ func (ep *SSEEndpoint) doEvals( return nil } -func (ep *SSEEndpoint) hijack(res http.ResponseWriter) (err error) { +func (ep *SSEEndpoint) hijack(res http.ResponseWriter, req *http.Request) (sseconn *SSEConn, err error) { var ( hijack http.Hijacker ok bool @@ -220,23 +97,17 @@ func (ep *SSEEndpoint) hijack(res http.ResponseWriter) (err error) { hijack, ok = res.(http.Hijacker) if !ok { - return errors.New(`http.ResponseWriter is not http.Hijacker`) + return nil, errors.New(`http.ResponseWriter is not http.Hijacker`) } - ep.conn, ep.bufrw, err = hijack.Hijack() - if err != nil { - return err + sseconn = &SSEConn{ + HttpRequest: req, } - return nil -} + sseconn.conn, sseconn.bufrw, err = hijack.Hijack() + if err != nil { + return nil, err + } -// handshake write the last HTTP response to indicate the connection is -// accepted. -func (ep *SSEEndpoint) handshake() { - ep.bufrw.WriteString("HTTP/1.1 200 OK\r\n") - ep.bufrw.WriteString("content-type: text/event-stream\r\n") - ep.bufrw.WriteString("cache-control: no-cache\r\n") - ep.bufrw.WriteString("\r\n") - ep.bufrw.Flush() + return sseconn, nil } diff --git a/lib/http/sse_endpoint_test.go b/lib/http/sse_endpoint_test.go index 2b70ffe7..106149b6 100644 --- a/lib/http/sse_endpoint_test.go +++ b/lib/http/sse_endpoint_test.go @@ -5,7 +5,6 @@ package http import ( - "net/http" "testing" "github.com/shuLhan/share/lib/test" @@ -57,7 +56,7 @@ func testSSEEndpointDuplicatePath(t *testing.T, httpd *Server) { var sse = &SSEEndpoint{ Path: `/sse`, - Call: func(ep *SSEEndpoint, req *http.Request) {}, + Call: func(sseconn *SSEConn) {}, } err = httpd.RegisterSSE(sse) diff --git a/lib/http/sseclient/sseclient.go b/lib/http/sseclient/sseclient.go index ed799114..681b9ea8 100644 --- a/lib/http/sseclient/sseclient.go +++ b/lib/http/sseclient/sseclient.go @@ -47,6 +47,9 @@ type Client struct { C <-chan Event event chan Event + serverUrl *url.URL + header http.Header + conn net.Conn closeq chan struct{} @@ -64,7 +67,12 @@ type Client struct { // This field is optional default to 10 seconds. Timeout time.Duration - retry time.Duration + // Retry define how long, in milliseconds, the client should wait + // before reconnecting back to server after disconnect. + // Zero or negative value disable it. + // + // This field is optional, default to 0, not retrying. + Retry time.Duration // Insecure allow connect to HTTPS Endpoint with invalid // certificate. @@ -76,14 +84,16 @@ func (cl *Client) Close() (err error) { // Close the connection, wait until it catched by consume goroutine. if cl.conn != nil { err = cl.conn.Close() + + var timeWait = time.NewTimer(50 * time.Millisecond) select { case cl.closeq <- struct{}{}: // Tell the consume goroutine we initiate the close. - cl.conn = nil - default: + case <-timeWait.C: // The consume goroutine may already end or end at // the same time. } + cl.conn = nil } return err } @@ -95,30 +105,38 @@ func (cl *Client) Close() (err error) { // during handshake. // The following header cannot be set: Host, User-Agent, and Accept. func (cl *Client) Connect(header http.Header) (err error) { - var ( - logp = `Connect` - serverUrl *url.URL - ) + var logp = `Connect` - serverUrl, err = cl.init() + err = cl.init(header) if err != nil { return fmt.Errorf(`%s: %w`, logp, err) } - err = cl.dial(serverUrl) + err = cl.connect() if err != nil { return fmt.Errorf(`%s: %w`, logp, err) } - var packet []byte + // Reset the ID to store the ID from server. + cl.LastEventID = `` + + go cl.consume() - packet, err = cl.handshake(serverUrl, header) + return nil +} + +func (cl *Client) connect() (err error) { + err = cl.dial() if err != nil { - return fmt.Errorf(`%s: %w`, logp, err) + return err } - // Reset the ID to store the ID from server. - cl.LastEventID = `` + var packet []byte + + packet, err = cl.handshake() + if err != nil { + return err + } select { case cl.event <- Event{Type: EventTypeOpen}: @@ -129,35 +147,41 @@ func (cl *Client) Connect(header http.Header) (err error) { // consume it. cl.parseEvent(packet) - go cl.consume() - return nil } // init validate and set default field values. -func (cl *Client) init() (serverUrl *url.URL, err error) { - serverUrl, err = url.Parse(cl.Endpoint) +func (cl *Client) init(header http.Header) (err error) { + cl.serverUrl, err = url.Parse(cl.Endpoint) if err != nil { - return nil, err + return err } var host, port string - host, port, err = net.SplitHostPort(serverUrl.Host) + host, port, err = net.SplitHostPort(cl.serverUrl.Host) if err != nil { - return nil, err + return err } if len(port) == 0 { - switch serverUrl.Scheme { + switch cl.serverUrl.Scheme { case `http`: port = `80` case `https`: port = `443` default: - return nil, fmt.Errorf(`unknown scheme %q`, serverUrl.Scheme) + return fmt.Errorf(`unknown scheme %q`, cl.serverUrl.Scheme) } } - serverUrl.Host = net.JoinHostPort(host, port) + cl.serverUrl.Host = net.JoinHostPort(host, port) + + cl.header = header + if cl.header == nil { + cl.header = http.Header{} + } + cl.header.Set(libhttp.HeaderHost, cl.serverUrl.Host) + cl.header.Set(libhttp.HeaderUserAgent, `libhttp/`+share.Version) + cl.header.Set(libhttp.HeaderAccept, libhttp.ContentTypeEventStream) if cl.Timeout <= 0 { cl.Timeout = defTimeout @@ -167,17 +191,17 @@ func (cl *Client) init() (serverUrl *url.URL, err error) { cl.C = cl.event cl.closeq = make(chan struct{}) - return serverUrl, nil + return nil } -func (cl *Client) dial(serverUrl *url.URL) (err error) { - if serverUrl.Scheme == `https` { +func (cl *Client) dial() (err error) { + if cl.serverUrl.Scheme == `https` { var tlsConfig = &tls.Config{ InsecureSkipVerify: cl.Insecure, } - cl.conn, err = tls.Dial(`tcp`, serverUrl.Host, tlsConfig) + cl.conn, err = tls.Dial(`tcp`, cl.serverUrl.Host, tlsConfig) } else { - cl.conn, err = net.Dial(`tcp`, serverUrl.Host) + cl.conn, err = net.Dial(`tcp`, cl.serverUrl.Host) } if err != nil { return err @@ -190,8 +214,8 @@ func (cl *Client) dial(serverUrl *url.URL) (err error) { // "text/event-stream". // // If the response is not empty, it contains event message, return it. -func (cl *Client) handshake(serverUrl *url.URL, header http.Header) (packet []byte, err error) { - err = cl.handshakeRequest(serverUrl, header) +func (cl *Client) handshake() (packet []byte, err error) { + err = cl.handshakeRequest() if err != nil { return nil, err } @@ -220,27 +244,21 @@ func (cl *Client) handshake(serverUrl *url.URL, header http.Header) (packet []by return packet, nil } -func (cl *Client) handshakeRequest(serverUrl *url.URL, header http.Header) (err error) { +func (cl *Client) handshakeRequest() (err error) { var buf bytes.Buffer - fmt.Fprintf(&buf, `GET %s`, serverUrl.Path) - if len(serverUrl.RawQuery) != 0 { + fmt.Fprintf(&buf, `GET %s`, cl.serverUrl.Path) + if len(cl.serverUrl.RawQuery) != 0 { buf.WriteByte('?') - buf.WriteString(serverUrl.RawQuery) + buf.WriteString(cl.serverUrl.RawQuery) } buf.WriteString(" HTTP/1.1\r\n") // Write the known values to prevent user overwrite our default // values. - if header == nil { - header = http.Header{} - } - header.Set(libhttp.HeaderHost, serverUrl.Host) - header.Set(libhttp.HeaderUserAgent, `libhttp/`+share.Version) - header.Set(libhttp.HeaderAccept, libhttp.ContentTypeEventStream) if len(cl.LastEventID) != 0 { - header.Set(libhttp.HeaderLastEventID, cl.LastEventID) + cl.header.Set(libhttp.HeaderLastEventID, cl.LastEventID) } var ( @@ -248,7 +266,7 @@ func (cl *Client) handshakeRequest(serverUrl *url.URL, header http.Header) (err hvals []string val string ) - for hkey, hvals = range header { + for hkey, hvals = range cl.header { if len(hvals) == 0 { continue } @@ -288,17 +306,44 @@ func (cl *Client) consume() { for { data, err = libnet.Read(cl.conn, 0, cl.Timeout) if err != nil { - // TODO: retry? + if cl.Retry <= 0 { + return + } + + // Check if this user Close or not. + var timeWait = time.NewTimer(50 * time.Millisecond) select { case <-cl.closeq: - // User call Close. - default: - // Peer initiated close or connection error. - // At the same time user may also call - // Close, to prevent data race, let the user - // clear it out. + // User initiated close. + if !timeWait.Stop() { + <-timeWait.C + } + return + case <-timeWait.C: + _ = cl.conn.Close() + cl.conn = nil + // Not from user, try to re-connect. + } + + var connected bool + timeWait = time.NewTimer(cl.Retry) + for !connected { + select { + case <-timeWait.C: + err = cl.connect() + if err != nil { + timeWait.Reset(cl.Retry) + continue + } + connected = true + case <-cl.closeq: + // User initiated close. + if !timeWait.Stop() { + <-timeWait.C + } + return + } } - return } cl.parseEvent(data) } @@ -397,7 +442,7 @@ func (cl *Client) parseEvent(raw []byte) { var retry int64 retry, err = strconv.ParseInt(string(fval), 10, 64) if err == nil { - cl.retry = time.Duration(retry) * time.Millisecond + cl.Retry = time.Duration(retry) } default: // Ignore the field. diff --git a/lib/http/sseclient/sseclient_test.go b/lib/http/sseclient/sseclient_test.go index cb1385f6..e7434105 100644 --- a/lib/http/sseclient/sseclient_test.go +++ b/lib/http/sseclient/sseclient_test.go @@ -7,7 +7,7 @@ package sseclient import ( "fmt" "math/rand" - "net/http" + "sync/atomic" "testing" "time" @@ -68,7 +68,7 @@ func TestClient(t *testing.T) { var expq = make(chan Event) - var servercb = func(ep *libhttp.SSEEndpoint, _ *http.Request) { + var servercb = func(sseconn *libhttp.SSEConn) { var ( c testCase err error @@ -82,13 +82,13 @@ func TestClient(t *testing.T) { // NO-OP, this is sent when client has an // error. case EventTypeMessage: - err = ep.WriteMessage(c.data, c.id()) + err = sseconn.WriteMessage(c.data, c.id()) if err != nil { t.Fatal(`WriteMessage`, err) } default: // Named type. - err = ep.WriteEvent(c.kind, c.data, c.id()) + err = sseconn.WriteEvent(c.kind, c.data, c.id()) if err != nil { t.Fatalf(`WriteEvent #%d: %s`, x, err) } @@ -98,17 +98,18 @@ func TestClient(t *testing.T) { } var ( - serverAddress string - err error + srv *libhttp.Server + err error ) - serverAddress, err = testRunSSEServer(t, servercb) + srv, err = testRunSSEServer(t, servercb) if err != nil { t.Fatal(`testRunSSEServer:`, err) } + t.Cleanup(func() { srv.Stop(1 * time.Second) }) var cl = Client{ - Endpoint: fmt.Sprintf(`http://%s/sse`, serverAddress), + Endpoint: fmt.Sprintf(`http://%s/sse`, srv.Options.Address), } err = cl.Connect(nil) @@ -211,17 +212,17 @@ func TestClient_raw(t *testing.T) { var expq = make(chan Event) - var servercb = func(ep *libhttp.SSEEndpoint, _ *http.Request) { + var servercb = func(sseconn *libhttp.SSEConn) { var ( - c testCase - ev Event - err error - x int + c testCase + ev Event + errw error + x int ) for x, c = range cases { - err = ep.WriteRaw([]byte(c.raw)) - if err != nil { - t.Fatalf(`WriteRaw #%d: %s`, x, err) + errw = sseconn.WriteRaw([]byte(c.raw)) + if errw != nil { + t.Fatalf(`WriteRaw #%d: %s`, x, errw) } for _, ev = range c.exp { expq <- ev @@ -229,15 +230,16 @@ func TestClient_raw(t *testing.T) { } } - var addr string + var srv *libhttp.Server - addr, err = testRunSSEServer(t, servercb) + srv, err = testRunSSEServer(t, servercb) if err != nil { t.Fatal(`testRunSSEServer:`, err) } + t.Cleanup(func() { srv.Stop(1 * time.Second) }) var cl = Client{ - Endpoint: fmt.Sprintf(`http://%s/sse`, addr), + Endpoint: fmt.Sprintf(`http://%s/sse`, srv.Options.Address), } err = cl.Connect(nil) @@ -274,6 +276,116 @@ func TestClient_raw(t *testing.T) { test.Assert(t, `LastEventID`, `2`, cl.LastEventID) } +func TestClientRetry(t *testing.T) { + const testKindClose = `close` + + type testCase struct { + kind string + raw []byte + exp []Event + } + + var cases = []testCase{{ + raw: []byte("retry: 100\n\n"), + exp: []Event{{ + Type: EventTypeOpen, // The first message always open. + }}, + }, { + // This is where server close its handler. + kind: testKindClose, + }, { + raw: []byte("data: after close\n\n"), + exp: []Event{{ + Type: EventTypeOpen, // The first message after reconnect. + }, { + Type: EventTypeMessage, + Data: `after close`, + }}, + }} + + var expq = make(chan Event) + + // Counter for test case between open and close in server + // callback. + var casenum atomic.Int64 + + var servercb = func(sseconn *libhttp.SSEConn) { + var ( + x = int(casenum.Load()) + c testCase + ev Event + err error + ) + for x < len(cases) { + c = cases[x] + x++ + casenum.Store(int64(x)) + + switch c.kind { + case testKindClose: + // Close the connection here, continue + // later. + return + default: + err = sseconn.WriteRaw([]byte(c.raw)) + if err != nil { + t.Fatalf(`WriteRaw #%d: %s`, x, err) + } + } + for _, ev = range c.exp { + expq <- ev + } + } + } + + var ( + srv *libhttp.Server + err error + ) + + srv, err = testRunSSEServer(t, servercb) + if err != nil { + t.Fatal(`testRunSSEServer:`, err) + } + t.Cleanup(func() { srv.Stop(1 * time.Second) }) + + var cl = Client{ + Endpoint: fmt.Sprintf(`http://%s/sse`, srv.Options.Address), + } + err = cl.Connect(nil) + if err != nil { + t.Fatal(`Connect:`, err) + } + + var ( + timeout = 1 * time.Second + ticker = time.NewTicker(timeout) + + c testCase + expEvent Event + gotEvent Event + tag string + x, y int + ) + for x, c = range cases { + for y = range c.exp { + tag = fmt.Sprintf(`Case #%d/#%d`, x, y) + + select { + case <-ticker.C: + t.Fatalf(`%s: timeout`, tag) + + case gotEvent = <-cl.C: + expEvent = <-expq + test.Assert(t, tag, expEvent, gotEvent) + } + ticker.Reset(timeout) + } + } + _ = cl.Close() + +} + // testGenerateAddress generate random port for server address. func testGenerateAddress() (addr string) { var port = rand.Int() % 60000 @@ -283,18 +395,17 @@ func testGenerateAddress() (addr string) { return fmt.Sprintf(`127.0.0.1:%d`, port) } -func testRunSSEServer(t *testing.T, cb libhttp.SSECallback) (address string, err error) { - address = testGenerateAddress() +func testRunSSEServer(t *testing.T, cb libhttp.SSECallback) (srv *libhttp.Server, err error) { + var address = testGenerateAddress() var ( serverOpts = &libhttp.ServerOptions{ Address: address, } - srv *libhttp.Server ) srv, err = libhttp.NewServer(serverOpts) if err != nil { - return ``, err + return nil, err } var sse = &libhttp.SSEEndpoint{ @@ -304,7 +415,7 @@ func testRunSSEServer(t *testing.T, cb libhttp.SSECallback) (address string, err err = srv.RegisterSSE(sse) if err != nil { - return ``, err + return nil, err } go srv.Start() @@ -315,7 +426,5 @@ func testRunSSEServer(t *testing.T, cb libhttp.SSECallback) (address string, err t.Skip(err) } - t.Cleanup(func() { srv.Stop(1 * time.Second) }) - - return address, nil + return srv, nil } |
