aboutsummaryrefslogtreecommitdiff
path: root/lib/websocket
diff options
context:
space:
mode:
authorShulhan <ms@kilabit.info>2021-06-21 18:10:05 +0700
committerShulhan <ms@kilabit.info>2021-06-21 18:11:07 +0700
commitd996cc8a4e629a1044b7d9372ee976e3ee3206f7 (patch)
tree95b71bb273e4820f448bdc06dcc9740095182b87 /lib/websocket
parent6993647b706d488886f9b0665e73d2fc2c251cb2 (diff)
downloadpakakeh.go-d996cc8a4e629a1044b7d9372ee976e3ee3206f7.tar.xz
websocket: store all the handshake headers to Handshake.Header
Previously only non-required headers are stored in the Handshake Header field, while the required header value stored on their respective fields. This commit store all request header key and values into the Header field.
Diffstat (limited to 'lib/websocket')
-rw-r--r--lib/websocket/handshake.go32
1 files changed, 15 insertions, 17 deletions
diff --git a/lib/websocket/handshake.go b/lib/websocket/handshake.go
index 42674f84..2a3f56a3 100644
--- a/lib/websocket/handshake.go
+++ b/lib/websocket/handshake.go
@@ -350,6 +350,8 @@ func (h *Handshake) parse() (err error) {
k, v []byte
)
+ h.Header = make(http.Header)
+
for h.start < len(h.raw) {
k, v, err = h.parseHeader()
if err != nil {
@@ -359,8 +361,12 @@ func (h *Handshake) parse() (err error) {
break
}
- switch {
- case bytes.Equal(k, []byte(_hdrKeyHost)):
+ headerKey := string(bytes.TrimSpace(k))
+ headerValue := string(bytes.TrimSpace(v))
+ h.Header.Set(headerKey, headerValue)
+
+ switch headerKey {
+ case _hdrKeyHost:
if h.headerFlags&_hdrFlagHost == _hdrFlagHost {
return ErrInvalidHeaderHost
}
@@ -370,7 +376,7 @@ func (h *Handshake) parse() (err error) {
h.Host = v
h.headerFlags |= _hdrFlagHost
- case bytes.Equal(k, []byte(_hdrKeyConnection)):
+ case _hdrKeyConnection:
if h.headerFlags&_hdrFlagConn == _hdrFlagConn {
return ErrInvalidHeaderConn
}
@@ -379,7 +385,7 @@ func (h *Handshake) parse() (err error) {
}
h.headerFlags |= _hdrFlagConn
- case bytes.Equal(k, []byte(_hdrKeyUpgrade)):
+ case _hdrKeyUpgrade:
if h.headerFlags&_hdrFlagUpgrade == _hdrFlagUpgrade {
return ErrInvalidHeaderUpgrade
}
@@ -388,7 +394,7 @@ func (h *Handshake) parse() (err error) {
}
h.headerFlags |= _hdrFlagUpgrade
- case bytes.Equal(k, []byte(_hdrKeyWSKey)):
+ case _hdrKeyWSKey:
if h.headerFlags&_hdrFlagWSKey == _hdrFlagWSKey {
return ErrInvalidHeaderWSKey
}
@@ -401,39 +407,31 @@ func (h *Handshake) parse() (err error) {
}
h.headerFlags |= _hdrFlagWSKey
- case bytes.Equal(k, []byte(_hdrKeyWSVersion)):
+ case _hdrKeyWSVersion:
if h.headerFlags&_hdrFlagWSVersion == _hdrFlagWSVersion {
return ErrInvalidHeaderWSVersion
}
if len(v) == 0 {
return ErrInvalidHeaderWSVersion
}
- if !bytes.Equal(v, []byte(_hdrValWSVersion)) {
+ if headerValue != _hdrValWSVersion {
return ErrUnsupportedWSVersion
}
h.headerFlags |= _hdrFlagWSVersion
- case bytes.Equal(k, []byte(_hdrKeyWSExtensions)):
+ case _hdrKeyWSExtensions:
if h.headerFlags&_hdrFlagWSExtensions == _hdrFlagWSExtensions {
return ErrInvalidHeaderWSExtensions
}
h.Extensions = v
h.headerFlags |= _hdrFlagWSExtensions
- case bytes.Equal(k, []byte(_hdrKeyWSProtocol)):
+ case _hdrKeyWSProtocol:
if h.headerFlags&_hdrFlagWSProtocol == _hdrFlagWSProtocol {
return ErrInvalidHeaderWSProtocol
}
h.Protocol = v
h.headerFlags |= _hdrFlagWSProtocol
-
- default:
- if h.Header == nil {
- h.Header = make(http.Header)
- h.Header.Add(string(k), string(v))
- } else {
- h.Header.Add(string(k), string(v))
- }
}
}