diff options
| -rw-r--r-- | src/cmd/compile/internal/test/switch_test.go | 132 | ||||
| -rw-r--r-- | src/cmd/compile/internal/test/testdata/ctl_test.go | 19 | ||||
| -rw-r--r-- | src/cmd/compile/internal/walk/switch.go | 327 | ||||
| -rw-r--r-- | test/codegen/switch.go | 32 |
4 files changed, 508 insertions, 2 deletions
diff --git a/src/cmd/compile/internal/test/switch_test.go b/src/cmd/compile/internal/test/switch_test.go index 1d12361cbb..0442cdc8fc 100644 --- a/src/cmd/compile/internal/test/switch_test.go +++ b/src/cmd/compile/internal/test/switch_test.go @@ -294,3 +294,135 @@ func (r rng) next(predictable bool) rng { func (r rng) value() uint64 { return uint64(r) } + +// Benchmarks for switch-to-lookup-table optimization. +// These use functions that return constants, which is the pattern +// the lookup table optimization targets. + +//go:noinline +func switchLookup8(x int) int { + switch x { + case 0: + return 1 + case 1: + return 2 + case 2: + return 3 + case 3: + return 5 + case 4: + return 8 + case 5: + return 13 + case 6: + return 21 + case 7: + return 34 + default: + return 0 + } +} + +//go:noinline +func switchLookup32(x int) int { + switch x { + case 0: + return 10 + case 1: + return 20 + case 2: + return 30 + case 3: + return 40 + case 4: + return 50 + case 5: + return 60 + case 6: + return 70 + case 7: + return 80 + case 8: + return 90 + case 9: + return 100 + case 10: + return 110 + case 11: + return 120 + case 12: + return 130 + case 13: + return 140 + case 14: + return 150 + case 15: + return 160 + case 16: + return 170 + case 17: + return 180 + case 18: + return 190 + case 19: + return 200 + case 20: + return 210 + case 21: + return 220 + case 22: + return 230 + case 23: + return 240 + case 24: + return 250 + case 25: + return 260 + case 26: + return 270 + case 27: + return 280 + case 28: + return 290 + case 29: + return 300 + case 30: + return 310 + case 31: + return 320 + default: + return 0 + } +} + +func BenchmarkSwitchLookup8Predictable(b *testing.B) { + benchmarkSwitchLookup8(b, true) +} +func BenchmarkSwitchLookup8Unpredictable(b *testing.B) { + benchmarkSwitchLookup8(b, false) +} +func benchmarkSwitchLookup8(b *testing.B, predictable bool) { + n := 0 + rng := newRNG() + for i := 0; i < b.N; i++ { + rng = rng.next(predictable) + n += 
switchLookup8(int(rng.value() & 7)) + } + sink = n +} + +func BenchmarkSwitchLookup32Predictable(b *testing.B) { + benchmarkSwitchLookup32(b, true) +} +func BenchmarkSwitchLookup32Unpredictable(b *testing.B) { + benchmarkSwitchLookup32(b, false) +} +func benchmarkSwitchLookup32(b *testing.B, predictable bool) { + n := 0 + rng := newRNG() + for i := 0; i < b.N; i++ { + rng = rng.next(predictable) + n += switchLookup32(int(rng.value() & 31)) + } + sink = n +} diff --git a/src/cmd/compile/internal/test/testdata/ctl_test.go b/src/cmd/compile/internal/test/testdata/ctl_test.go index 501f79eee1..608bdcf6b9 100644 --- a/src/cmd/compile/internal/test/testdata/ctl_test.go +++ b/src/cmd/compile/internal/test/testdata/ctl_test.go @@ -72,6 +72,22 @@ func switch_ssa(a int) int { return ret } +func switch_nonconstcase(n, c int) int { + switch n { + case 1: + return 1 + case 2: + return 2 + case c: + return 9 + case 3: + return 3 + case 4: + return 4 + } + return 0 +} + func fallthrough_ssa(a int) int { ret := 0 switch a { @@ -107,6 +123,9 @@ func testSwitch(t *testing.T) { t.Errorf("switch_ssa(i) = %d, wanted %d", got, i) } } + if got := switch_nonconstcase(3, 3); got != 9 { + t.Errorf("switch_nonconstcase(3, 3) = %d, wanted 9", got) + } } type junk struct { diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go index cbe38b54bc..b7f0dcab5b 100644 --- a/src/cmd/compile/internal/walk/switch.go +++ b/src/cmd/compile/internal/walk/switch.go @@ -9,6 +9,7 @@ import ( "fmt" "go/constant" "go/token" + "math" "math/bits" "slices" "sort" @@ -75,6 +76,8 @@ func walkSwitchExpr(sw *ir.SwitchStmt) { base.Pos = lno + tryLookupTable(sw, cond) + s := exprSwitch{ pos: lno, exprname: cond, @@ -334,6 +337,330 @@ func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool { return true } +// tryLookupTable attempts to replace constant-returning cases of an integer +// switch with a static lookup table. 
Cases whose bodies are a single "return +// <int constant>" are served from a read-only array, eliminating branching. +// Remaining cases (non-constant bodies, default) are left in sw.Cases for +// normal switch compilation. +// +// For example: +// +// switch x { +// case 0: return 10 +// case 1: return 20 +// case 2, 3: return 30 +// default: return -1 +// } +// +// Becomes: +// +// var table = [4]int{10, 20, 30, 30} +// if uint(x) < 4 { return table[x] } +// // remaining switch for default +// +// Partial optimization also works when some cases have non-constant bodies: +// +// switch x { +// case 1: return 1 +// case 2: return 4 +// case 3: sideEffect(); return 9 +// ... +// default: return x * x +// } +// +// Becomes: +// +// var table = [8]int{1, 4, 0, ...} +// var mask = [8]uint8{1, 1, 0, ...} +// if uint(x-1) <= 7 && mask[x-1] != 0 { return table[x-1] } +// // remaining switch for case 3 + default +func tryLookupTable(sw *ir.SwitchStmt, cond ir.Node) { + const minCases = 4 // need enough cases to justify a table + + if base.Flag.N != 0 { + return // optimizations disabled + } + if !cond.Type().IsInteger() { + return + } + if cond.Type().Size() > int64(types.PtrSize) { + return // 64-bit switches on 32-bit archs + } + + // Bail out if any case uses fallthrough. Removing cases from the switch + // would break fallthrough chains between adjacent cases. + // TODO: we could still optimize cases that don't fall through, even if some cases do. + for _, ncase := range sw.Cases { + if fall, _ := endsInFallthrough(ncase.Body); fall { + return + } + } + + fn := ir.CurFunc + if fn == nil || fn.Type().NumResults() != 1 { + return // only handle single return value + } + resultType := fn.Type().Results()[0].Type + if !resultType.IsInteger() { + // Only handle integer return types for now. + // TODO: generalize to other constant types, e.g. strings and bools. + return + } + + // Classify each case as const-returning or not. 
+ // TODO: support more complex bodies, like local variable assignments. + // For example: + // + // var n int + // switch x { + // case 1: n = 1 + // case 2: n = 4 + // case 3: n = 9 + // case 4: n = 16 + // } + // return n + // + // Could be optimized to: + // + // var table = [4]int{1, 4, 9, 16} + // var n int + // if uint(x-1) < 4 { n = table[x-1] } + // return n + constSet := make(map[int64]constant.Value) // case value → return constant + constCaseSet := make(map[int]bool) // indices of const-returning non-default cases + var defaultVal constant.Value + var hasConstDefault bool + excludeSet := make(map[int64]bool) // case values with non-const bodies + minVal, maxVal := int64(math.MaxInt64), int64(math.MinInt64) + + for i, ncase := range sw.Cases { + if len(ncase.List) == 0 { + // Default case: check if it returns a constant (for gap filling). + if isConstIntReturn(ncase) { + hasConstDefault = true + defaultVal = ncase.Body[0].(*ir.ReturnStmt).Results[0].Val() + } + continue + } + + vals, ok := constIntCaseVals(ncase) + if !ok { + // Case has a non-constant case expression (e.g. a variable). + // Bail out: we can't determine overlap with the table range. + // For example: + // case 1: return 1 // const → would go to table + // case c: return 9 // c is a variable, not a constant + // case 3: return 3 // const → would go to table + // At runtime, if c==3 then Go evaluates case c before case 3, + // returning 9. But if we put cases 1 and 3 in a table, n==3 + // would return 3 from the table, skipping the case c check. + return + } + + if !isConstIntReturn(ncase) { + // Non-const case body: exclude these values from the table + // so the mask redirects them to the normal switch, preserving + // Go's top-to-bottom case evaluation order. 
For example: + // case 3: sideEffect(); return 30 → exclude slot 3 + for _, v := range vals { + excludeSet[v] = true + } + continue // will be handled by normal switch + } + + retVal := ncase.Body[0].(*ir.ReturnStmt).Results[0].Val() + for _, v := range vals { + constSet[v] = retVal + minVal = min(minVal, v) + maxVal = max(maxVal, v) + } + constCaseSet[i] = true + } + + if len(constSet) < minCases { + return + } + + tableSize := maxVal - minVal + 1 + if tableSize <= 0 || !isSwitchDense(int64(len(constSet)), tableSize) { + return // too sparse + } + + // Build static lookup table and determine which slots are valid. + // Also build the bitmask inline if the table is small enough. + tabType := types.NewArray(resultType, tableSize) + tabName := readonlystaticname(tabType) + lsym := tabName.Linksym() + elemSize := int(resultType.Size()) + maxBitmaskSize := int64(types.PtrSize * 8) // 32 or 64 + + needMask := false + var bitmask uint64 + validSlots := make([]bool, tableSize) + for i := range tableSize { + caseVal := minVal + i + var v int64 + switch { + case excludeSet[caseVal]: + // Non-const case in range: must fall through to normal switch. + needMask = true + case constSet[caseVal] != nil: + v = ir.IntVal(resultType, constSet[caseVal]) + validSlots[i] = true + bitmask |= 1 << uint(i) + case hasConstDefault: + // Gap filled with default constant value. + v = ir.IntVal(resultType, defaultVal) + validSlots[i] = true + bitmask |= 1 << uint(i) + default: + // Gap with no const default: must fall through. + needMask = true + } + lsym.WriteInt(base.Ctxt, i*int64(elemSize), elemSize, v) + } + + // Build mask if some slots must fall through to normal switch. + // When the table fits in a register-width bitmask (≤32 entries on 32-bit, + // ≤64 on 64-bit), use the bitmask computed above. For larger tables, + // fall back to a byte array. 
+ var maskName *ir.Name + useBitmask := needMask && tableSize <= maxBitmaskSize + if needMask && !useBitmask { + maskType := types.NewArray(types.Types[types.TUINT8], tableSize) + maskName = readonlystaticname(maskType) + maskSym := maskName.Linksym() + for i := range tableSize { + var v uint8 + if validSlots[i] { + v = 1 + } + maskSym.WriteInt(base.Ctxt, i, 1, int64(v)) + } + } + + // Generate code: + // idx := uint(int(cond) - minVal) + // if idx <= uint(maxVal-minVal) [&& bitmask>>idx&1 != 0] { return table[idx] } + pos := sw.Pos() + + // Widen cond to int to avoid overflow in small integer types. + intType := types.Types[types.TINT] + wideCond := typecheck.Conv(cond, intType) + + // Compute idx = int(cond) - minVal. + var idx ir.Node + if minVal != 0 { + minLit := ir.NewBasicLit(pos, intType, constant.MakeInt64(minVal)) + idx = typecheck.Expr(ir.NewBinaryExpr(pos, ir.OSUB, wideCond, minLit)) + } else { + idx = wideCond + } + + // Convert to uint for the one-sided bounds check and store in a temp + // so the index can be shared across the bounds check, table, and mask. + uintType := types.Types[types.TUINT] + uidx := typecheck.Conv(idx, uintType) + uidx = copyExpr(uidx, uintType, &sw.Compiled) + + // Bounds check: uint(idx) <= uint(maxVal - minVal). + rangeLit := ir.NewBasicLit(pos, uintType, constant.MakeUint64(uint64(maxVal-minVal))) + boundsCheck := typecheck.Expr(ir.NewBinaryExpr(pos, ir.OLE, uidx, rangeLit)) + boundsCheck = typecheck.DefaultLit(boundsCheck, nil) + + // Table lookup: table[idx] with bounds elided (already checked). + lookup := ir.NewIndexExpr(pos, tabName, uidx) + lookup.SetBounded(true) + lookup = typecheck.Expr(lookup).(*ir.IndexExpr) + + retStmt := ir.NewReturnStmt(pos, []ir.Node{lookup}) + + var ifBody []ir.Node + if needMask { + var maskCheck ir.Node + if useBitmask { + // Bitmask check: (bitmask >> idx) & 1 != 0. + // Use uintptr so the operation is register-width on all architectures. 
+ bitmaskType := types.Types[types.TUINTPTR] + bitmaskLit := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(bitmask)) + shifted := typecheck.Expr(ir.NewBinaryExpr(pos, ir.ORSH, bitmaskLit, uidx)) + one := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(1)) + masked := typecheck.Expr(ir.NewBinaryExpr(pos, ir.OAND, shifted, one)) + zero := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(0)) + maskCheck = typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, masked, zero)) + } else { + // Mask array check: mask[idx] != 0. + maskLookup := ir.NewIndexExpr(pos, maskName, uidx) + maskLookup.SetBounded(true) + maskLookup = typecheck.Expr(maskLookup).(*ir.IndexExpr) + zero := ir.NewBasicLit(pos, types.Types[types.TUINT8], constant.MakeInt64(0)) + maskCheck = typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, maskLookup, zero)) + } + maskCheck = typecheck.DefaultLit(maskCheck, nil) + + innerIf := ir.NewIfStmt(pos, maskCheck, []ir.Node{retStmt}, nil) + ifBody = []ir.Node{innerIf} + } else { + ifBody = []ir.Node{retStmt} + } + + outerIf := ir.NewIfStmt(pos, boundsCheck, ifBody, nil) + sw.Compiled.Append(outerIf) + + // Remove handled const cases from sw.Cases. + // Keep default and non-const cases for normal switch processing. + newCases := make([]*ir.CaseClause, 0, len(sw.Cases)-len(constCaseSet)) + for i, ncase := range sw.Cases { + if !constCaseSet[i] { + newCases = append(newCases, ncase) + } + } + sw.Cases = newCases +} + +// isSwitchDense reports whether a lookup table with tableSize entries +// for numCases cases is dense enough to be worth building. +// It requires at least 40% of table slots to be used, matching the +// density threshold used by LLVM's SimplifyCFG. 
+func isSwitchDense(numCases, tableSize int64) bool { + const minDensity = 40 + if tableSize >= math.MaxInt64/100 { + return false // avoid multiplication overflow below + } + return numCases*100 >= tableSize*minDensity +} + +// isConstIntReturn reports whether ncase has a body that is a single +// return statement returning one integer constant. +func isConstIntReturn(ncase *ir.CaseClause) bool { + if len(ncase.Body) != 1 { + return false + } + ret, ok := ncase.Body[0].(*ir.ReturnStmt) + if !ok || len(ret.Results) != 1 { + return false + } + r := ret.Results[0] + return r.Op() == ir.OLITERAL && r.Val().Kind() == constant.Int +} + +// constIntCaseVals returns the int64 values of all case expressions in +// ncase, if they are all integer constants. Returns ok=false if any +// case expression is not a constant integer. +func constIntCaseVals(ncase *ir.CaseClause) (vals []int64, ok bool) { + for _, n1 := range ncase.List { + if n1.Op() != ir.OLITERAL || n1.Val().Kind() != constant.Int { + return nil, false + } + v, fit := constant.Int64Val(n1.Val()) + if !fit { + return nil, false + } + vals = append(vals, v) + } + return vals, true +} + func (c *exprClause) test(exprname ir.Node) ir.Node { // Integer range. if c.hi != c.lo { diff --git a/test/codegen/switch.go b/test/codegen/switch.go index d59ef4f2eb..07850522ab 100644 --- a/test/codegen/switch.go +++ b/test/codegen/switch.go @@ -21,13 +21,41 @@ func f(x string) int { } } -// use jump tables for 8+ int cases -func square(x int) int { +// use jump tables for 8+ int cases +// Using multiple return values prevents lookup tables. 
+func squareJump(x int) (int, int) { // amd64:`JMP \(.*\)\(.*\)$` // arm64:`MOVD \(R.*\)\(R.*<<3\)` `JMP \(R.*\)$` // loong64: `ALSLV` `MOVV` `JMP` switch x { case 1: + return 1, 1 + case 2: + return 4, 2 + case 3: + return 9, 3 + case 4: + return 16, 4 + case 5: + return 25, 5 + case 6: + return 36, 6 + case 7: + return 49, 7 + case 8: + return 64, 8 + default: + return x * x, x + } +} + +// use lookup tables for 8+ int cases returning constants +func squareLookup(x int) int { + // amd64:`LEAQ .*\(SB\)` `MOVQ .*\(.*\)\(.*\*8\)` -`JMP \(.*\)\(.*\)$` + // arm64:`MOVD \(R.*\)\(R.*<<3\)` -`JMP \(R.*\)$` + // loong64:`SLLV` `MOVV \(R.*\)\(R.*\)` -`ALSLV` + switch x { + case 1: return 1 case 2: return 4 |
