aboutsummaryrefslogtreecommitdiff
path: root/src/net/http
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2020-10-06 10:53:11 -0700
committerBrad Fitzpatrick <bradfitz@golang.org>2020-10-06 22:02:30 +0000
commit930fa890c9b6a75700bda3dc4043de81350749ea (patch)
treecd91c9ffd31d6440991b10b446ecf8e9ffcab1cf /src/net/http
parentdb428ad7b61ed757671162054252b4326045e96c (diff)
downloadgo-930fa890c9b6a75700bda3dc4043de81350749ea.tar.xz
net/http: add Transport.GetProxyConnectHeader
Fixes golang/go#41048 Change-Id: I38e01605bffb6f85100c098051b0c416dd77f261 Reviewed-on: https://go-review.googlesource.com/c/go/+/259917 Trust: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Go Bot <gobot@golang.org> Reviewed-by: Damien Neil <dneil@google.com>
Diffstat (limited to 'src/net/http')
-rw-r--r--src/net/http/transport.go23
-rw-r--r--src/net/http/transport_test.go52
2 files changed, 74 insertions, 1 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index b97c4268b5..4546166430 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -240,8 +240,18 @@ type Transport struct {
// ProxyConnectHeader optionally specifies headers to send to
// proxies during CONNECT requests.
+ // To set the header dynamically, see GetProxyConnectHeader.
ProxyConnectHeader Header
+ // GetProxyConnectHeader optionally specifies a func to return
+ // headers to send to proxyURL during a CONNECT request to the
+ // ip:port target.
+ // If it returns an error, the Transport's RoundTrip fails with
+ // that error. It can return (nil, nil) to not add headers.
+ // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
+ // ignored.
+ GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)
+
// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
@@ -313,6 +323,7 @@ func (t *Transport) Clone() *Transport {
ResponseHeaderTimeout: t.ResponseHeaderTimeout,
ExpectContinueTimeout: t.ExpectContinueTimeout,
ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
+ GetProxyConnectHeader: t.GetProxyConnectHeader,
MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
ForceAttemptHTTP2: t.ForceAttemptHTTP2,
WriteBufferSize: t.WriteBufferSize,
@@ -1623,7 +1634,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
}
case cm.targetScheme == "https":
conn := pconn.conn
- hdr := t.ProxyConnectHeader
+ var hdr Header
+ if t.GetProxyConnectHeader != nil {
+ var err error
+ hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+ } else {
+ hdr = t.ProxyConnectHeader
+ }
if hdr == nil {
hdr = make(Header)
}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index f4b7623630..a1c9e822b4 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -5174,6 +5174,57 @@ func TestTransportProxyConnectHeader(t *testing.T) {
}
}
+func TestTransportProxyGetConnectHeader(t *testing.T) {
+ defer afterTest(t)
+ reqc := make(chan *Request, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ }))
+ defer ts.Close()
+
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ }
+ // These should be ignored:
+ c.Transport.(*Transport).ProxyConnectHeader = Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ }
+ c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
+ return Header{
+ "User-Agent": {"foo2"},
+ "Other": {"bar2"},
+ }, nil
+ }
+
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+ select {
+ case <-time.After(3 * time.Second):
+ t.Fatal("timeout")
+ case r := <-reqc:
+ if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar2"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+ }
+}
+
var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func()
@@ -5842,6 +5893,7 @@ func TestTransportClone(t *testing.T) {
ResponseHeaderTimeout: time.Second,
ExpectContinueTimeout: time.Second,
ProxyConnectHeader: Header{},
+ GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
MaxResponseHeaderBytes: 1,
ForceAttemptHTTP2: true,
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{