diff options
| author | Shulhan <ms@kilabit.info> | 2021-06-21 18:10:05 +0700 |
|---|---|---|
| committer | Shulhan <ms@kilabit.info> | 2021-06-21 18:11:07 +0700 |
| commit | d996cc8a4e629a1044b7d9372ee976e3ee3206f7 (patch) | |
| tree | 95b71bb273e4820f448bdc06dcc9740095182b87 /lib/websocket | |
| parent | 6993647b706d488886f9b0665e73d2fc2c251cb2 (diff) | |
| download | pakakeh.go-d996cc8a4e629a1044b7d9372ee976e3ee3206f7.tar.xz | |
websocket: store all the handshake headers to Handshake.Header
Previously only non-required headers are stored in the Handshake Header
field, while the required header value stored on their respective fields.
This commit store all request header key and values into the Header field.
Diffstat (limited to 'lib/websocket')
| -rw-r--r-- | lib/websocket/handshake.go | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/lib/websocket/handshake.go b/lib/websocket/handshake.go index 42674f84..2a3f56a3 100644 --- a/lib/websocket/handshake.go +++ b/lib/websocket/handshake.go @@ -350,6 +350,8 @@ func (h *Handshake) parse() (err error) { k, v []byte ) + h.Header = make(http.Header) + for h.start < len(h.raw) { k, v, err = h.parseHeader() if err != nil { @@ -359,8 +361,12 @@ func (h *Handshake) parse() (err error) { break } - switch { - case bytes.Equal(k, []byte(_hdrKeyHost)): + headerKey := string(bytes.TrimSpace(k)) + headerValue := string(bytes.TrimSpace(v)) + h.Header.Set(headerKey, headerValue) + + switch headerKey { + case _hdrKeyHost: if h.headerFlags&_hdrFlagHost == _hdrFlagHost { return ErrInvalidHeaderHost } @@ -370,7 +376,7 @@ func (h *Handshake) parse() (err error) { h.Host = v h.headerFlags |= _hdrFlagHost - case bytes.Equal(k, []byte(_hdrKeyConnection)): + case _hdrKeyConnection: if h.headerFlags&_hdrFlagConn == _hdrFlagConn { return ErrInvalidHeaderConn } @@ -379,7 +385,7 @@ func (h *Handshake) parse() (err error) { } h.headerFlags |= _hdrFlagConn - case bytes.Equal(k, []byte(_hdrKeyUpgrade)): + case _hdrKeyUpgrade: if h.headerFlags&_hdrFlagUpgrade == _hdrFlagUpgrade { return ErrInvalidHeaderUpgrade } @@ -388,7 +394,7 @@ func (h *Handshake) parse() (err error) { } h.headerFlags |= _hdrFlagUpgrade - case bytes.Equal(k, []byte(_hdrKeyWSKey)): + case _hdrKeyWSKey: if h.headerFlags&_hdrFlagWSKey == _hdrFlagWSKey { return ErrInvalidHeaderWSKey } @@ -401,39 +407,31 @@ func (h *Handshake) parse() (err error) { } h.headerFlags |= _hdrFlagWSKey - case bytes.Equal(k, []byte(_hdrKeyWSVersion)): + case _hdrKeyWSVersion: if h.headerFlags&_hdrFlagWSVersion == _hdrFlagWSVersion { return ErrInvalidHeaderWSVersion } if len(v) == 0 { return ErrInvalidHeaderWSVersion } - if !bytes.Equal(v, []byte(_hdrValWSVersion)) { + if headerValue != _hdrValWSVersion { return ErrUnsupportedWSVersion } h.headerFlags |= _hdrFlagWSVersion - case bytes.Equal(k, []byte(_hdrKeyWSExtensions)): + case _hdrKeyWSExtensions: if h.headerFlags&_hdrFlagWSExtensions == _hdrFlagWSExtensions { return ErrInvalidHeaderWSExtensions } h.Extensions = v h.headerFlags |= _hdrFlagWSExtensions - case bytes.Equal(k, []byte(_hdrKeyWSProtocol)): + case _hdrKeyWSProtocol: if h.headerFlags&_hdrFlagWSProtocol == _hdrFlagWSProtocol { return ErrInvalidHeaderWSProtocol } h.Protocol = v h.headerFlags |= _hdrFlagWSProtocol - - default: - if h.Header == nil { - h.Header = make(http.Header) - h.Header.Add(string(k), string(v)) - } else { - h.Header.Add(string(k), string(v)) - } } } |
