diff options
| author | Shulhan <ms@kilabit.info> | 2023-11-22 03:23:21 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2023-11-26 21:32:46 +0700 |
| commit | c170365f09a38b0d9b09b74716dacd7a0114cd7d (patch) | |
| tree | 5fe386cfb764a53560a50e814ee670a9ab88dfa8 /lib/http | |
| parent | c64476c10a8a6d331278fdc96f896559d3939f17 (diff) | |
| download | pakakeh.go-c170365f09a38b0d9b09b74716dacd7a0114cd7d.tar.xz | |
lib/http: implement Server-Sent Events (SSE)
For server SSE, we add new type SSEEndpoint, for registering endpoint
that can handle SSE.
For client SSE, we add it in new sub package "sseclient".
Implements: https://todo.sr.ht/~shulhan/share/1
Implements: https://todo.sr.ht/~shulhan/share/2
Signed-off-by: Shulhan <ms@kilabit.info>
Diffstat (limited to 'lib/http')
| -rw-r--r-- | lib/http/http.go | 7 | ||||
| -rw-r--r-- | lib/http/route.go | 29 | ||||
| -rw-r--r-- | lib/http/server.go | 47 | ||||
| -rw-r--r-- | lib/http/sse_endpoint.go | 228 | ||||
| -rw-r--r-- | lib/http/sse_endpoint_test.go | 66 | ||||
| -rw-r--r-- | lib/http/sseclient/event.go | 31 | ||||
| -rw-r--r-- | lib/http/sseclient/sseclient.go | 400 | ||||
| -rw-r--r-- | lib/http/sseclient/sseclient_test.go | 158 |
8 files changed, 959 insertions, 7 deletions
diff --git a/lib/http/http.go b/lib/http/http.go index e6a3254a..8136ffa9 100644 --- a/lib/http/http.go +++ b/lib/http/http.go @@ -16,6 +16,8 @@ // multipart/form-data, application/json with POST or PUT methods in // Client. // - Add support for [HTTP Range] in Server and Client +// - Add support for [Server-Sent Events] (SSE) in Server. +// For client see the sub package [sseclient]. // // # Problems // @@ -201,6 +203,7 @@ // the last one is static path to "y". // // [HTTP Range]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests +// [Server-Sent Events]: https://html.spec.whatwg.org/multipage/server-sent-events.html package http import ( @@ -225,6 +228,7 @@ const ( ContentEncodingDeflate = "deflate" // Using zlib. ContentTypeBinary = "application/octet-stream" + ContentTypeEventStream = `text/event-stream` ContentTypeForm = "application/x-www-form-urlencoded" ContentTypeMultipartByteRanges = `multipart/byteranges` ContentTypeMultipartForm = "multipart/form-data" @@ -241,11 +245,13 @@ const ( HeaderACMaxAge = "Access-Control-Max-Age" HeaderACRequestHeaders = "Access-Control-Request-Headers" HeaderACRequestMethod = "Access-Control-Request-Method" + HeaderAccept = `Accept` HeaderAcceptEncoding = "Accept-Encoding" HeaderAcceptRanges = `Accept-Ranges` HeaderAllow = "Allow" HeaderAuthKeyBearer = "Bearer" HeaderAuthorization = "Authorization" + HeaderCacheControl = `Cache-Control` HeaderContentEncoding = "Content-Encoding" HeaderContentLength = "Content-Length" HeaderContentRange = `Content-Range` @@ -255,6 +261,7 @@ const ( HeaderETag = "Etag" HeaderHost = "Host" HeaderIfNoneMatch = "If-None-Match" + HeaderLastEventID = `Last-Event-ID` HeaderLocation = "Location" HeaderOrigin = "Origin" HeaderRange = `Range` diff --git a/lib/http/route.go b/lib/http/route.go index feaa80cf..2a06e458 100644 --- a/lib/http/route.go +++ b/lib/http/route.go @@ -5,15 +5,26 @@ package http import ( + "path" "strings" ) +// List of kind for route. +const ( + routeKindHttp int = iota // Normal routing. + routeKindSSE // Routing for Server-Sent Events (SSE). +) + // route represent the route to endpoint. type route struct { - endpoint *Endpoint // endpoint of route. - path string // path contains Endpoint's path that has been cleaned up. - nodes []*node // nodes contains sub-path. - nkey int // nkey contains the number of keys in nodes. + endpoint *Endpoint // endpoint of route. + endpointSSE *SSEEndpoint // Endpoint for SSE. + + path string // path contains Endpoint's path that has been cleaned up. + nodes []*node // nodes contains sub-path. + nkey int // nkey contains the number of keys in nodes. + + kind int } // newRoute parse the Endpoint's path, store the key(s) in path if available @@ -69,6 +80,16 @@ func newRoute(ep *Endpoint) (rute *route, err error) { return rute, nil } +// newRouteSSE create and initialize new route for SSE. +func newRouteSSE(ep *SSEEndpoint) (rute *route) { + rute = &route{ + endpointSSE: ep, + path: path.Clean(ep.Path), + kind: routeKindSSE, + } + return rute +} + // isKeyExist will return true if the key already exist in nodes; otherwise it // will return false. func (rute *route) isKeyExist(key string) bool { diff --git a/lib/http/server.go b/lib/http/server.go index 96a2a902..f0f6f409 100644 --- a/lib/http/server.go +++ b/lib/http/server.go @@ -127,6 +127,34 @@ func (srv *Server) RegisterEndpoint(ep *Endpoint) (err error) { return err } +// RegisterSSE register Server-Sent Events endpoint. +// It will return an error if the Call field is not set or +// [ErrEndpointAmbiguous], if the same path is already registered. +func (srv *Server) RegisterSSE(ep *SSEEndpoint) (err error) { + var logp = `RegisterSSE` + + if ep.Call == nil { + return fmt.Errorf(`%s: Call field not set`, logp) + } + + // Check if the same GET path already registered. + var ( + rute *route + exist bool + ) + for _, rute = range srv.routeGets { + _, exist = rute.parse(ep.Path) + if exist { + return fmt.Errorf(`%s: %w`, logp, ErrEndpointAmbiguous) + } + } + + rute = newRouteSSE(ep) + srv.routeGets = append(srv.routeGets, rute) + + return nil +} + // registerDelete register HTTP method DELETE with specific endpoint to handle // it. func (srv *Server) registerDelete(ep *Endpoint) (err error) { @@ -561,12 +589,25 @@ func (srv *Server) HandleFS(res http.ResponseWriter, req *http.Request) { // handleGet handle the GET request by searching the registered route and // calling the endpoint. func (srv *Server) handleGet(res http.ResponseWriter, req *http.Request) { - for _, rute := range srv.routeGets { - vals, ok := rute.parse(req.URL.Path) - if ok { + var ( + rute *route + vals map[string]string + ok bool + ) + for _, rute = range srv.routeGets { + vals, ok = rute.parse(req.URL.Path) + if !ok { + continue + } + if rute.kind == routeKindHttp { rute.endpoint.call(res, req, srv.evals, vals) return } + if rute.kind == routeKindSSE { + rute.endpointSSE.call(res, req, srv.evals, vals) + return + } + // Unknown kind will be handled by HandleFS. } srv.HandleFS(res, req) diff --git a/lib/http/sse_endpoint.go b/lib/http/sse_endpoint.go new file mode 100644 index 00000000..a896c60c --- /dev/null +++ b/lib/http/sse_endpoint.go @@ -0,0 +1,228 @@ +// Copyright 2023, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" + + liberrors "github.com/shuLhan/share/lib/errors" +) + +// SSECallback define the handler for Server-Sent Events (SSE). +// +// SSECallback type pass original HTTP request. +// This allow the server to check for header "Last-Event-ID" and/or for +// authentication. +// Remember that "the original Request.Body must not be used" according to +// [http.Hijacker] documentation. +type SSECallback func(sse *SSEEndpoint, req *http.Request) + +// SSEEndpoint endpoint to create Server-Sent Events (SSE) on server. +// +// For creating the SSE client see subpackage [sseclient]. +type SSEEndpoint struct { + bufrw *bufio.ReadWriter + conn net.Conn + + // Path where server accept the request for SSE. + Path string + + // Call handler that will called when request to Path accepted. + Call SSECallback +} + +// WriteEvent write message with event type to client. +// +// The event parameter must not be empty, otherwise it will not be sent. +// +// The msg parameter must not be empty, otherwise it will not be sent +// If msg contains new line character ('\n'), the message will be split into +// multiple "data:". +// +// The id parameter is optional, can be empty. +// +// It will return an error if its failed to write to peer connection. +func (ep *SSEEndpoint) WriteEvent(event, msg, id string) (err error) { + event = strings.TrimSpace(event) + if len(event) == 0 { + return nil + } + if len(msg) == 0 { + return nil + } + + var buf bytes.Buffer + + buf.WriteString(`event:`) + buf.WriteString(event) + buf.WriteByte('\n') + + ep.writeData(&buf, msg, id) + + _, err = ep.bufrw.Write(buf.Bytes()) + if err != nil { + return fmt.Errorf(`WriteMessage: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +// WriteMessage write a message with optional id to client. +// +// The msg parameter must not be empty, otherwise it will not be sent +// If msg contains new line character ('\n'), the message will be split into +// multiple "data:". +// +// The id parameter is optional, can be empty. +// +// It will return an error if its failed to write to peer connection. +func (ep *SSEEndpoint) WriteMessage(msg, id string) (err error) { + if len(msg) == 0 { + return nil + } + + var buf bytes.Buffer + + ep.writeData(&buf, msg, id) + + _, err = ep.bufrw.Write(buf.Bytes()) + if err != nil { + return fmt.Errorf(`WriteMessage: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +// WriteRetry inform user how long they should wait, after disconnect, +// before re-connecting back to server. +// +// The duration must be in millisecond. +func (ep *SSEEndpoint) WriteRetry(retry time.Duration) (err error) { + _, err = fmt.Fprintf(ep.bufrw, "retry:%d\n\n", retry.Milliseconds()) + if err != nil { + return fmt.Errorf(`WriteRetry: %w`, err) + } + ep.bufrw.Flush() + return nil +} + +func (ep *SSEEndpoint) writeData(buf *bytes.Buffer, msg, id string) { + var ( + lines = strings.Split(msg, "\n") + line string + ) + for _, line = range lines { + buf.WriteString(`data:`) + buf.WriteString(line) + buf.WriteByte('\n') + } + if len(id) != 0 { + buf.WriteString(`id:`) + buf.WriteString(id) + buf.WriteByte('\n') + } + buf.WriteByte('\n') +} + +func (ep *SSEEndpoint) call( + res http.ResponseWriter, + req *http.Request, + evaluators []Evaluator, + vals map[string]string, +) { + var err error + + err = req.ParseForm() + if err != nil { + http.Error(res, err.Error(), http.StatusBadRequest) + return + } + + // Fill the form with path binding. + if len(vals) > 0 { + if req.Form == nil { + req.Form = make(url.Values, len(vals)) + } + var k, v string + for k, v = range vals { + if len(k) > 0 && len(v) > 0 { + req.Form.Set(k, v) + } + } + } + + err = ep.doEvals(res, req, evaluators) + if err != nil { + return + } + + err = ep.hijack(res) + if err != nil { + http.Error(res, err.Error(), http.StatusInternalServerError) + return + } + + ep.handshake() + ep.Call(ep, req) + ep.conn.Close() +} + +func (ep *SSEEndpoint) doEvals( + res http.ResponseWriter, + req *http.Request, + evaluators []Evaluator, +) (err error) { + var eval Evaluator + + for _, eval = range evaluators { + err = eval(req, nil) + if err != nil { + var errInternal = &liberrors.E{} + if !errors.As(err, &errInternal) { + errInternal.Code = http.StatusUnprocessableEntity + } + http.Error(res, err.Error(), errInternal.Code) + return err + } + } + return nil +} + +func (ep *SSEEndpoint) hijack(res http.ResponseWriter) (err error) { + var ( + hijack http.Hijacker + ok bool + ) + + hijack, ok = res.(http.Hijacker) + if !ok { + return errors.New(`http.ResponseWriter is not http.Hijacker`) + } + + ep.conn, ep.bufrw, err = hijack.Hijack() + if err != nil { + return err + } + + return nil +} + +// handshake write the last HTTP response to indicate the connection is +// accepted. +func (ep *SSEEndpoint) handshake() { + ep.bufrw.WriteString("HTTP/1.1 200 OK\r\n") + ep.bufrw.WriteString("content-type: text/event-stream\r\n") + ep.bufrw.WriteString("cache-control: no-cache\r\n") + ep.bufrw.WriteString("\r\n") + ep.bufrw.Flush() +} diff --git a/lib/http/sse_endpoint_test.go b/lib/http/sse_endpoint_test.go new file mode 100644 index 00000000..2b70ffe7 --- /dev/null +++ b/lib/http/sse_endpoint_test.go @@ -0,0 +1,66 @@ +// Copyright 2023, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/http" + "testing" + + "github.com/shuLhan/share/lib/test" +) + +func TestSSEEndpoint(t *testing.T) { + var opts = &ServerOptions{ + Address: `127.0.0.1:24168`, + } + + var ( + httpd *Server + err error + ) + httpd, err = NewServer(opts) + if err != nil { + t.Fatal(err) + } + + t.Run(`EmptyCall`, func(tt *testing.T) { + testSSEEndpointEmptyCall(tt, httpd) + }) + + t.Run(`DuplicatePath`, func(tt *testing.T) { + testSSEEndpointDuplicatePath(tt, httpd) + }) +} + +func testSSEEndpointEmptyCall(t *testing.T, httpd *Server) { + var sse = &SSEEndpoint{ + Path: `/sse`, + } + + var err = httpd.RegisterSSE(sse) + + test.Assert(t, `error`, `RegisterSSE: Call field not set`, err.Error()) +} + +func testSSEEndpointDuplicatePath(t *testing.T, httpd *Server) { + var ep = &Endpoint{ + Path: `/sse`, + Call: func(req *EndpointRequest) ([]byte, error) { return nil, nil }, + } + + var err = httpd.RegisterEndpoint(ep) + if err != nil { + t.Fatal(err) + } + + var sse = &SSEEndpoint{ + Path: `/sse`, + Call: func(ep *SSEEndpoint, req *http.Request) {}, + } + + err = httpd.RegisterSSE(sse) + + test.Assert(t, `error`, `RegisterSSE: ambigous endpoint`, err.Error()) +} diff --git a/lib/http/sseclient/event.go b/lib/http/sseclient/event.go new file mode 100644 index 00000000..d191cf32 --- /dev/null +++ b/lib/http/sseclient/event.go @@ -0,0 +1,31 @@ +// Copyright 2023, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sseclient + +// List of system event type. +const ( + // EventTypeOpen is set when connection succesfully established. + // The passed [Event.Data] is empty. + EventTypeOpen = `open` + + // EventTypeMessage is set when client received message from server, + // possibly with new ID. + EventTypeMessage = `message` + + EventTypeError = `error` +) + +// Event contains SSE message from server or client status. +type Event struct { + Type string + Data string + ID string +} + +func (ev *Event) reset(id string) { + ev.Type = EventTypeMessage + ev.Data = `` + ev.ID = id +} diff --git a/lib/http/sseclient/sseclient.go b/lib/http/sseclient/sseclient.go new file mode 100644 index 00000000..9cfeddd2 --- /dev/null +++ b/lib/http/sseclient/sseclient.go @@ -0,0 +1,400 @@ +// Copyright 2023, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sseclient implement HTTP client for Server-Sent Events (SSE). +// +// References, +// - [whatwg.org Server-sent events] +// +// [whatwg.org Server-sent events]: https://html.spec.whatwg.org/multipage/server-sent-events.html +package sseclient + +import ( + "bytes" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/shuLhan/share" + libhttp "github.com/shuLhan/share/lib/http" + libnet "github.com/shuLhan/share/lib/net" +) + +const defTimeout = 10 * time.Second + +// defEventBuffer define maximum event buffered in channel. +const defEventBuffer = 1024 + +// Client for SSE. +// Once the Client filled, user need only to call Connect to start receiving +// message from channel C. +type Client struct { + C <-chan Event + event chan Event + + conn net.Conn + closeq chan struct{} + + // Endpoint define the HTTP server URL to connect. + Endpoint string + + // LastEventID define the last event ID to be sent during handshake. + // Once the handshake success, this field will be reset and may set + // with next ID from server. + // This field is optional. + LastEventID string + + // Timeout define the read and write timeout when reading and + // writing from/to server. + // This field is optional default to 10 seconds. + Timeout time.Duration + + retry time.Duration + + // Insecure allow connect to HTTPS Endpoint with invalid + // certificate. + Insecure bool +} + +// Close the connection and release all resources. +func (cl *Client) Close() (err error) { + // Close the connection, wait until it catched by consume goroutine. + if cl.conn != nil { + err = cl.conn.Close() + select { + case cl.closeq <- struct{}{}: + // Tell the consume goroutine we initiate the close. + cl.conn = nil + default: + // The consume goroutine may already end or end at + // the same time. + } + } + return err +} + +// Connect to server and start consume the message and propagate to each +// registered handlers. +// +// The header parameter define custom, optional HTTP header to be sent +// during handshake. +// The following header cannot be set: Host, User-Agent, and Accept. +func (cl *Client) Connect(header http.Header) (err error) { + var ( + logp = `Connect` + serverUrl *url.URL + ) + + serverUrl, err = cl.init() + if err != nil { + return fmt.Errorf(`%s: %w`, logp, err) + } + + err = cl.dial(serverUrl) + if err != nil { + return fmt.Errorf(`%s: %w`, logp, err) + } + + var packet []byte + + packet, err = cl.handshake(serverUrl, header) + if err != nil { + return fmt.Errorf(`%s: %w`, logp, err) + } + + // Reset the ID to store the ID from server. + cl.LastEventID = `` + + select { + case cl.event <- Event{Type: EventTypeOpen}: + default: + } + + if len(packet) != 0 { + // The HTTP response may contains events in the body, + // consume it. + cl.parseEvent(packet) + } + + go cl.consume() + + return nil +} + +// init validate and set default field values. +func (cl *Client) init() (serverUrl *url.URL, err error) { + serverUrl, err = url.Parse(cl.Endpoint) + if err != nil { + return nil, err + } + + var host, port string + + host, port, err = net.SplitHostPort(serverUrl.Host) + if err != nil { + return nil, err + } + if len(port) == 0 { + switch serverUrl.Scheme { + case `http`: + port = `80` + case `https`: + port = `443` + default: + return nil, fmt.Errorf(`unknown scheme %q`, serverUrl.Scheme) + } + } + serverUrl.Host = net.JoinHostPort(host, port) + + if cl.Timeout <= 0 { + cl.Timeout = defTimeout + } + + cl.event = make(chan Event, defEventBuffer) + cl.C = cl.event + cl.closeq = make(chan struct{}) + + return serverUrl, nil +} + +func (cl *Client) dial(serverUrl *url.URL) (err error) { + if serverUrl.Scheme == `https` { + var tlsConfig = &tls.Config{ + InsecureSkipVerify: cl.Insecure, + } + cl.conn, err = tls.Dial(`tcp`, serverUrl.Host, tlsConfig) + } else { + cl.conn, err = net.Dial(`tcp`, serverUrl.Host) + } + if err != nil { + return err + } + return nil +} + +// handshake send the HTTP request and check for the response. +// The response must be HTTP status code 200 with Content-Type +// "text/event-stream". +// +// If the response is not empty, it contains event message, return it. +func (cl *Client) handshake(serverUrl *url.URL, header http.Header) (packet []byte, err error) { + err = cl.handshakeRequest(serverUrl, header) + if err != nil { + return nil, err + } + + packet, err = libnet.Read(cl.conn, 0, cl.Timeout) + if err != nil { + return nil, err + } + + var httpRes *http.Response + + httpRes, packet, err = libhttp.ParseResponseHeader(packet) + if err != nil { + return nil, err + } + + if httpRes.StatusCode != http.StatusOK { + return nil, fmt.Errorf(`handshake failed with response status %q`, httpRes.Status) + } + + var contentType = httpRes.Header.Get(libhttp.HeaderContentType) + if contentType != libhttp.ContentTypeEventStream { + return nil, fmt.Errorf(`handshake failed with unknown Content-Type %q`, contentType) + } + + return packet, nil +} + +func (cl *Client) handshakeRequest(serverUrl *url.URL, header http.Header) (err error) { + var buf bytes.Buffer + + fmt.Fprintf(&buf, `GET %s`, serverUrl.Path) + if len(serverUrl.RawQuery) != 0 { + buf.WriteByte('?') + buf.WriteString(serverUrl.RawQuery) + } + buf.WriteString(" HTTP/1.1\r\n") + + // Write the known values to prevent user overwrite our default + // values. + + if header == nil { + header = http.Header{} + } + header.Set(libhttp.HeaderHost, serverUrl.Host) + header.Set(libhttp.HeaderUserAgent, `libhttp/`+share.Version) + header.Set(libhttp.HeaderAccept, libhttp.ContentTypeEventStream) + if len(cl.LastEventID) != 0 { + header.Set(libhttp.HeaderLastEventID, cl.LastEventID) + } + + var ( + hkey string + hvals []string + val string + ) + for hkey, hvals = range header { + if len(hvals) == 0 { + continue + } + if len(hvals) == 1 { + val = hvals[0] + } else { + val = strings.Join(hvals, `,`) + } + fmt.Fprintf(&buf, "%s: %s\r\n", hkey, val) + } + buf.WriteString("\r\n") + + var deadline = time.Now().Add(cl.Timeout) + + cl.conn.SetWriteDeadline(deadline) + + var ( + buflen = buf.Len() + n int + ) + + n, err = cl.conn.Write(buf.Bytes()) + if err != nil { + return err + } + if n != buflen { + return fmt.Errorf(`handshake write error, %d out of %d`, n, buflen) + } + return nil +} + +func (cl *Client) consume() { + var ( + data []byte + err error + ) + for { + data, err = libnet.Read(cl.conn, 0, cl.Timeout) + if err != nil { + // TODO: retry? + select { + case <-cl.closeq: + // User call Close. + default: + // Peer initiated close or connection error. + // At the same time user may also call + // Close, to prevent data race, let the user + // clear it out. + } + return + } + cl.parseEvent(data) + } +} + +// parseEvent parse the raw event and publish it when ready. +func (cl *Client) parseEvent(raw []byte) { + if len(raw) == 0 { + return + } + + // Normalize the line ending into "\n" only. + var lineEnd = []byte{'\n'} + raw = bytes.ReplaceAll(raw, []byte{'\r', '\n'}, lineEnd) + raw = bytes.ReplaceAll(raw, []byte{'\r'}, lineEnd) + + var ( + fieldSep = []byte{':'} + lines = bytes.Split(raw, lineEnd) + + ev Event + data bytes.Buffer + line []byte + fname []byte + fval []byte + err error + + // counter count each passing "data:" event. + // When receiving empty line, the counter will reset to 0. + counter int + ) + + ev.reset(cl.LastEventID) + + for _, line = range lines { + if len(line) == 0 { + // An empty line trigger dispatching the message. + if counter == 0 { + // Skip continuous empty line. + continue + } + + ev.Data = data.String() + if len(ev.Data) != 0 { + select { + case cl.event <- ev: + default: + } + data.Reset() + if ev.ID != cl.LastEventID { + // Only set LastEventID if event + // message is complete. + cl.LastEventID = ev.ID + } + } + ev.reset(cl.LastEventID) + counter = 0 + continue + } + + if line[0] == ':' { + continue + } + + // ABNF syntax for field: + // + // 1*name-char [ colon [ space ] *any-char ] end-of-line + // + // - There is no space in field name. + // So, field line like "event :E" will be ignored. + // - There is only one space allowed after colon. + + fname, fval, _ = bytes.Cut(line, fieldSep) + + if len(fval) != 0 { + if fval[0] == ' ' { + fval = fval[1:] + } + } + + switch string(fname) { + case `event`: + fval = bytes.TrimSpace(fval) + if len(fval) != 0 { + ev.Type = string(fval) + } + case `data`: + if counter > 0 { + data.WriteByte('\n') + } + data.Write(fval) + counter++ + case `id`: + ev.ID = string(fval) + case `retry`: + var retry int64 + retry, err = strconv.ParseInt(string(fval), 10, 64) + if err == nil { + cl.retry = time.Duration(retry) * time.Millisecond + } + default: + // Ignore the field. + } + } + // Ignore incomplete event that does end with empty line. +} diff --git a/lib/http/sseclient/sseclient_test.go b/lib/http/sseclient/sseclient_test.go new file mode 100644 index 00000000..45333d90 --- /dev/null +++ b/lib/http/sseclient/sseclient_test.go @@ -0,0 +1,158 @@ +// Copyright 2023, Shulhan <ms@kilabit.info>. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sseclient + +import ( + "fmt" + "math/rand" + "net/http" + "testing" + "time" + + libhttp "github.com/shuLhan/share/lib/http" + libnet "github.com/shuLhan/share/lib/net" + "github.com/shuLhan/share/lib/test" +) + +func TestClient(t *testing.T) { + var expEvents = []Event{{ + Type: EventTypeOpen, + }, { + Type: EventTypeMessage, + Data: `Hello, world`, + }, { + Type: EventTypeMessage, + Data: "Hello\nmulti\nline\nworld", + }, { + Type: `join`, + Data: `John join the event`, + }, { + Type: `join`, + Data: `Jane join the event`, + ID: `1`, + }} + + var expq = make(chan Event) + + var servercb = func(ep *libhttp.SSEEndpoint, _ *http.Request) { + var ( + ev Event + ewrite error + x int + ) + for x, ev = range expEvents { + switch ev.Type { + case EventTypeOpen: + // NO-OP, this is sent during connect. + case EventTypeError: + // NO-OP, this is sent when client has an + // error. + case EventTypeMessage: + ewrite = ep.WriteMessage(ev.Data, ev.ID) + if ewrite != nil { + t.Fatal(`WriteMessage`, ewrite) + } + default: + // Named type. + ewrite = ep.WriteEvent(ev.Type, ev.Data, ev.ID) + if ewrite != nil { + t.Fatalf(`WriteEvent #%d: %s`, x, ewrite) + } + } + expq <- ev + } + } + + var ( + serverAddress string + err error + ) + + serverAddress, err = testRunSSEServer(t, servercb) + if err != nil { + t.Fatal(`testRunSSEServer:`, err) + } + + var cl = Client{ + Endpoint: fmt.Sprintf(`http://%s/sse`, serverAddress), + } + + err = cl.Connect(nil) + if err != nil { + t.Fatal(`Connect:`, err) + } + + var ( + timeout = 3 * time.Second + ticker = time.NewTicker(timeout) + + expEvent Event + gotEvent Event + tag string + x int + ) + for x, expEvent = range expEvents { + tag = fmt.Sprintf(`expEvent #%d`, x) + select { + case <-ticker.C: + t.Fatalf(`%s: timeout`, tag) + + case gotEvent = <-cl.C: + expEvent = <-expq + test.Assert(t, tag, expEvent, gotEvent) + } + ticker.Reset(timeout) + } + + _ = cl.Close() + + test.Assert(t, `LastEventID`, cl.LastEventID, `1`) +} + +// testGenerateAddress generate random port for server address. +func testGenerateAddress() (addr string) { + var port = rand.Int() % 60000 + if port < 1024 { + port += 1024 + } + return fmt.Sprintf(`127.0.0.1:%d`, port) +} + +func testRunSSEServer(t *testing.T, cb libhttp.SSECallback) (address string, err error) { + address = testGenerateAddress() + + var ( + serverOpts = &libhttp.ServerOptions{ + Address: address, + } + srv *libhttp.Server + ) + srv, err = libhttp.NewServer(serverOpts) + if err != nil { + return ``, err + } + + var sse = &libhttp.SSEEndpoint{ + Path: `/sse`, + Call: cb, + } + + err = srv.RegisterSSE(sse) + if err != nil { + return ``, err + } + + go srv.Start() + + err = libnet.WaitAlive(`tcp`, address, 1*time.Second) + if err != nil { + // Server may not run due to address already in use. + t.Skip(err) + } + + t.Cleanup(func() { srv.Stop(1 * time.Second) }) + + return address, nil +} |
