diff options
Diffstat (limited to 'src/net/http')
| -rw-r--r-- | src/net/http/serve_test.go | 47 | ||||
| -rw-r--r-- | src/net/http/server.go | 12 |
2 files changed, 59 insertions, 0 deletions
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index c603c201d5..5d2a29a6fc 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -1645,6 +1645,53 @@ func testTLSServer(t *testing.T, mode testMode) { } } +type fakeConnectionStateConn struct { + net.Conn +} + +func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState { + return tls.ConnectionState{ + ServerName: "example.com", + } +} + +func TestTLSServerWithoutTLSConn(t *testing.T) { + //set up + pr, pw := net.Pipe() + c := make(chan int) + listener := &oneConnListener{&fakeConnectionStateConn{pr}} + server := &Server{ + Handler: HandlerFunc(func(writer ResponseWriter, request *Request) { + if request.TLS == nil { + t.Fatal("request.TLS is nil, expected not nil") + } + if request.TLS.ServerName != "example.com" { + t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com") + } + writer.Header().Set("X-TLS-ServerName", "example.com") + }), + } + + // write request and read response + go func() { + req, _ := NewRequest(MethodGet, "https://example.com", nil) + req.Write(pw) + + resp, _ := ReadResponse(bufio.NewReader(pw), req) + if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" { + t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com") + } + close(c) + pw.Close() + }() + + server.Serve(listener) + + // oneConnListener returns error after one accept, wait util response is read + <-c + pr.Close() +} + func TestServeTLS(t *testing.T) { CondSkipHTTP2(t) // Not parallel: uses global test hooks. diff --git a/src/net/http/server.go b/src/net/http/server.go index 49a9d30207..f2bedb7d6a 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1924,6 +1924,10 @@ func isCommonNetReadError(err error) bool { return false } +type connectionStater interface { + ConnectionState() tls.ConnectionState +} + // Serve a new connection. func (c *conn) serve(ctx context.Context) { if ra := c.rwc.RemoteAddr(); ra != nil { @@ -1996,6 +2000,14 @@ func (c *conn) serve(ctx context.Context) { // HTTP/1.x from here on. + // Set Request.TLS if the conn is not a *tls.Conn, but implements ConnectionState. + if c.tlsState == nil { + if tc, ok := c.rwc.(connectionStater); ok { + c.tlsState = new(tls.ConnectionState) + *c.tlsState = tc.ConnectionState() + } + } + ctx, cancelCtx := context.WithCancel(ctx) c.cancelCtx = cancelCtx defer cancelCtx() |
