diff options
Diffstat (limited to 'src/context')
| -rw-r--r-- | src/context/context.go | 10 | ||||
| -rw-r--r-- | src/context/context_test.go | 65 | ||||
| -rw-r--r-- | src/context/withtimeout_test.go | 4 |
3 files changed, 51 insertions, 28 deletions
diff --git a/src/context/context.go b/src/context/context.go index 21dc8676bf..da294b1292 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -39,6 +39,7 @@ package context import ( "errors" "fmt" + "reflect" "sync" "time" ) @@ -66,7 +67,7 @@ type Context interface { // // // Stream generates values with DoSomething and sends them to out // // until DoSomething returns an error or ctx.Done is closed. - // func Stream(ctx context.Context, out <-chan Value) error { + // func Stream(ctx context.Context, out chan<- Value) error { // for { // v, err := DoSomething(ctx) // if err != nil { @@ -424,7 +425,12 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { // // Use context Values only for request-scoped data that transits processes and // APIs, not for passing optional parameters to functions. -func WithValue(parent Context, key interface{}, val interface{}) Context { +// +// The provided key must be comparable. +func WithValue(parent Context, key, val interface{}) Context { + if !reflect.TypeOf(key).Comparable() { + panic("key is not comparable") + } return &valueCtx{parent, key, val} } diff --git a/src/context/context_test.go b/src/context/context_test.go index 05345fc5e5..aa26161d2b 100644 --- a/src/context/context_test.go +++ b/src/context/context_test.go @@ -229,55 +229,55 @@ func TestChildFinishesFirst(t *testing.T) { } } -func testDeadline(c Context, wait time.Duration, t *testing.T) { +func testDeadline(c Context, name string, failAfter time.Duration, t *testing.T) { select { - case <-time.After(wait): - t.Fatalf("context should have timed out") + case <-time.After(failAfter): + t.Fatalf("%s: context should have timed out", name) case <-c.Done(): } if e := c.Err(); e != DeadlineExceeded { - t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) + t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded) } } func TestDeadline(t *testing.T) { - c, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + c, _ := WithDeadline(Background(), time.Now().Add(50*time.Millisecond)) if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { t.Errorf("c.String() = %q want prefix %q", got, prefix) } - testDeadline(c, 200*time.Millisecond, t) + testDeadline(c, "WithDeadline", time.Second, t) - c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + c, _ = WithDeadline(Background(), time.Now().Add(50*time.Millisecond)) o := otherContext{c} - testDeadline(o, 200*time.Millisecond, t) + testDeadline(o, "WithDeadline+otherContext", time.Second, t) - c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + c, _ = WithDeadline(Background(), time.Now().Add(50*time.Millisecond)) o = otherContext{c} - c, _ = WithDeadline(o, time.Now().Add(300*time.Millisecond)) - testDeadline(c, 200*time.Millisecond, t) + c, _ = WithDeadline(o, time.Now().Add(4*time.Second)) + testDeadline(c, "WithDeadline+otherContext+WithDeadline", 2*time.Second, t) } func TestTimeout(t *testing.T) { - c, _ := WithTimeout(Background(), 100*time.Millisecond) + c, _ := WithTimeout(Background(), 50*time.Millisecond) if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { t.Errorf("c.String() = %q want prefix %q", got, prefix) } - testDeadline(c, 200*time.Millisecond, t) + testDeadline(c, "WithTimeout", time.Second, t) - c, _ = WithTimeout(Background(), 100*time.Millisecond) + c, _ = WithTimeout(Background(), 50*time.Millisecond) o := otherContext{c} - testDeadline(o, 200*time.Millisecond, t) + testDeadline(o, "WithTimeout+otherContext", time.Second, t) - c, _ = WithTimeout(Background(), 100*time.Millisecond) + c, _ = WithTimeout(Background(), 50*time.Millisecond) o = otherContext{c} - c, _ = WithTimeout(o, 300*time.Millisecond) - testDeadline(c, 200*time.Millisecond, t) + c, _ = WithTimeout(o, 3*time.Second) + testDeadline(c, "WithTimeout+otherContext+WithTimeout", 2*time.Second, t) } func TestCanceledTimeout(t *testing.T) { - c, _ := WithTimeout(Background(), 200*time.Millisecond) + c, _ := WithTimeout(Background(), time.Second) o := otherContext{c} - c, cancel := WithTimeout(o, 400*time.Millisecond) + c, cancel := WithTimeout(o, 2*time.Second) cancel() time.Sleep(100 * time.Millisecond) // let cancelation propagate select { @@ -388,9 +388,9 @@ func TestAllocs(t *testing.T) { gccgoLimit: 8, }, { - desc: "WithTimeout(bg, 100*time.Millisecond)", + desc: "WithTimeout(bg, 5*time.Millisecond)", f: func() { - c, cancel := WithTimeout(bg, 100*time.Millisecond) + c, cancel := WithTimeout(bg, 5*time.Millisecond) cancel() <-c.Done() }, @@ -404,7 +404,11 @@ func TestAllocs(t *testing.T) { // TOOD(iant): Remove this when gccgo does do escape analysis. limit = test.gccgoLimit } - if n := testing.AllocsPerRun(100, test.f); n > limit { + numRuns := 100 + if testing.Short() { + numRuns = 10 + } + if n := testing.AllocsPerRun(numRuns, test.f); n > limit { t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) } } @@ -536,7 +540,7 @@ func testLayers(t *testing.T, seed int64, testTimeout bool) { if testTimeout { select { case <-ctx.Done(): - case <-time.After(timeout + 100*time.Millisecond): + case <-time.After(timeout + time.Second): errorf("ctx should have timed out") } checkValues("after timeout") @@ -573,3 +577,16 @@ func TestCancelRemoves(t *testing.T) { cancel() checkChildren("after cancelling WithTimeout child", ctx, 0) } + +func TestWithValueChecksKey(t *testing.T) { + panicVal := recoveredValue(func() { WithValue(Background(), []byte("foo"), "bar") }) + if panicVal == nil { + t.Error("expected panic") + } +} + +func recoveredValue(fn func()) (v interface{}) { + defer func() { v = recover() }() + fn() + return +} diff --git a/src/context/withtimeout_test.go b/src/context/withtimeout_test.go index 3ab6fc347f..2aea303bed 100644 --- a/src/context/withtimeout_test.go +++ b/src/context/withtimeout_test.go @@ -13,9 +13,9 @@ import ( func ExampleWithTimeout() { // Pass a context with a timeout to tell a blocking function that it // should abandon its work after the timeout elapses. - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) select { - case <-time.After(200 * time.Millisecond): + case <-time.After(1 * time.Second): fmt.Println("overslept") case <-ctx.Done(): fmt.Println(ctx.Err()) // prints "context deadline exceeded" |
