about summary refs log tree commit diff
path: root/src/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd')
-rw-r--r--  src/cmd/compile/internal/test/switch_test.go        132
-rw-r--r--  src/cmd/compile/internal/test/testdata/ctl_test.go   19
-rw-r--r--  src/cmd/compile/internal/walk/switch.go             327
3 files changed, 478 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/test/switch_test.go b/src/cmd/compile/internal/test/switch_test.go
index 1d12361cbb..0442cdc8fc 100644
--- a/src/cmd/compile/internal/test/switch_test.go
+++ b/src/cmd/compile/internal/test/switch_test.go
@@ -294,3 +294,135 @@ func (r rng) next(predictable bool) rng {
func (r rng) value() uint64 {
return uint64(r)
}
+
+// Benchmarks for switch-to-lookup-table optimization.
+// These use functions that return constants, which is the pattern
+// the lookup table optimization targets.
+
+//go:noinline
+func switchLookup8(x int) int {
+ switch x {
+ case 0:
+ return 1
+ case 1:
+ return 2
+ case 2:
+ return 3
+ case 3:
+ return 5
+ case 4:
+ return 8
+ case 5:
+ return 13
+ case 6:
+ return 21
+ case 7:
+ return 34
+ default:
+ return 0
+ }
+}
+
+//go:noinline
+func switchLookup32(x int) int {
+ switch x {
+ case 0:
+ return 10
+ case 1:
+ return 20
+ case 2:
+ return 30
+ case 3:
+ return 40
+ case 4:
+ return 50
+ case 5:
+ return 60
+ case 6:
+ return 70
+ case 7:
+ return 80
+ case 8:
+ return 90
+ case 9:
+ return 100
+ case 10:
+ return 110
+ case 11:
+ return 120
+ case 12:
+ return 130
+ case 13:
+ return 140
+ case 14:
+ return 150
+ case 15:
+ return 160
+ case 16:
+ return 170
+ case 17:
+ return 180
+ case 18:
+ return 190
+ case 19:
+ return 200
+ case 20:
+ return 210
+ case 21:
+ return 220
+ case 22:
+ return 230
+ case 23:
+ return 240
+ case 24:
+ return 250
+ case 25:
+ return 260
+ case 26:
+ return 270
+ case 27:
+ return 280
+ case 28:
+ return 290
+ case 29:
+ return 300
+ case 30:
+ return 310
+ case 31:
+ return 320
+ default:
+ return 0
+ }
+}
+
+func BenchmarkSwitchLookup8Predictable(b *testing.B) {
+ benchmarkSwitchLookup8(b, true)
+}
+func BenchmarkSwitchLookup8Unpredictable(b *testing.B) {
+ benchmarkSwitchLookup8(b, false)
+}
+func benchmarkSwitchLookup8(b *testing.B, predictable bool) {
+ n := 0
+ rng := newRNG()
+ for i := 0; i < b.N; i++ {
+ rng = rng.next(predictable)
+ n += switchLookup8(int(rng.value() & 7))
+ }
+ sink = n
+}
+
+func BenchmarkSwitchLookup32Predictable(b *testing.B) {
+ benchmarkSwitchLookup32(b, true)
+}
+func BenchmarkSwitchLookup32Unpredictable(b *testing.B) {
+ benchmarkSwitchLookup32(b, false)
+}
+func benchmarkSwitchLookup32(b *testing.B, predictable bool) {
+ n := 0
+ rng := newRNG()
+ for i := 0; i < b.N; i++ {
+ rng = rng.next(predictable)
+ n += switchLookup32(int(rng.value() & 31))
+ }
+ sink = n
+}
diff --git a/src/cmd/compile/internal/test/testdata/ctl_test.go b/src/cmd/compile/internal/test/testdata/ctl_test.go
index 501f79eee1..608bdcf6b9 100644
--- a/src/cmd/compile/internal/test/testdata/ctl_test.go
+++ b/src/cmd/compile/internal/test/testdata/ctl_test.go
@@ -72,6 +72,22 @@ func switch_ssa(a int) int {
return ret
}
+func switch_nonconstcase(n, c int) int {
+ switch n {
+ case 1:
+ return 1
+ case 2:
+ return 2
+ case c:
+ return 9
+ case 3:
+ return 3
+ case 4:
+ return 4
+ }
+ return 0
+}
+
func fallthrough_ssa(a int) int {
ret := 0
switch a {
@@ -107,6 +123,9 @@ func testSwitch(t *testing.T) {
t.Errorf("switch_ssa(i) = %d, wanted %d", got, i)
}
}
+ if got := switch_nonconstcase(3, 3); got != 9 {
+ t.Errorf("switch_nonconstcase(3, 3) = %d, wanted 9", got)
+ }
}
type junk struct {
diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go
index cbe38b54bc..b7f0dcab5b 100644
--- a/src/cmd/compile/internal/walk/switch.go
+++ b/src/cmd/compile/internal/walk/switch.go
@@ -9,6 +9,7 @@ import (
"fmt"
"go/constant"
"go/token"
+ "math"
"math/bits"
"slices"
"sort"
@@ -75,6 +76,8 @@ func walkSwitchExpr(sw *ir.SwitchStmt) {
base.Pos = lno
+ tryLookupTable(sw, cond)
+
s := exprSwitch{
pos: lno,
exprname: cond,
@@ -334,6 +337,330 @@ func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool {
return true
}
+// tryLookupTable attempts to replace constant-returning cases of an integer
+// switch with a static lookup table. Cases whose bodies are a single "return
+// <int constant>" are served from a read-only array, eliminating branching.
+// Remaining cases (non-constant bodies, default) are left in sw.Cases for
+// normal switch compilation.
+//
+// For example:
+//
+// switch x {
+// case 0: return 10
+// case 1: return 20
+// case 2, 3: return 30
+// default: return -1
+// }
+//
+// Becomes:
+//
+// var table = [4]int{10, 20, 30, 30}
+// if uint(x) < 4 { return table[x] }
+// // remaining switch for default
+//
+// Partial optimization also works when some cases have non-constant bodies:
+//
+// switch x {
+// case 1: return 1
+// case 2: return 4
+// case 3: sideEffect(); return 9
+// ...
+// default: return x * x
+// }
+//
+// Becomes:
+//
+// var table = [8]int{1, 4, 0, ...}
+// var mask = [8]uint8{1, 1, 0, ...}
+// if uint(x-1) <= 7 && mask[x-1] != 0 { return table[x-1] }
+// // remaining switch for case 3 + default
+func tryLookupTable(sw *ir.SwitchStmt, cond ir.Node) {
+ const minCases = 4 // need enough cases to justify a table
+
+ if base.Flag.N != 0 {
+ return // optimizations disabled
+ }
+ if !cond.Type().IsInteger() {
+ return
+ }
+ if cond.Type().Size() > int64(types.PtrSize) {
+ return // 64-bit switches on 32-bit archs
+ }
+
+ // Bail out if any case uses fallthrough. Removing cases from the switch
+ // would break fallthrough chains between adjacent cases.
+ // TODO: we could still optimize cases that don't fall through, even if some cases do.
+ for _, ncase := range sw.Cases {
+ if fall, _ := endsInFallthrough(ncase.Body); fall {
+ return
+ }
+ }
+
+ fn := ir.CurFunc
+ if fn == nil || fn.Type().NumResults() != 1 {
+ return // only handle single return value
+ }
+ resultType := fn.Type().Results()[0].Type
+ if !resultType.IsInteger() {
+ // Only handle integer return types for now.
+ // TODO: generalize to other constant types, e.g. strings and bools.
+ return
+ }
+
+ // Classify each case as const-returning or not.
+ // TODO: support more complex bodies, like local variable assignments.
+ // For example:
+ //
+ // var n int
+ // switch x {
+ // case 1: n = 1
+ // case 2: n = 4
+ // case 3: n = 9
+ // case 4: n = 16
+ // }
+ // return n
+ //
+ // Could be optimized to:
+ //
+ // var table = [4]int{1, 4, 9, 16}
+ // var n int
+ // if uint(x-1) < 4 { n = table[x-1] }
+ // return n
+ constSet := make(map[int64]constant.Value) // case value → return constant
+ constCaseSet := make(map[int]bool) // indices of const-returning non-default cases
+ var defaultVal constant.Value
+ var hasConstDefault bool
+ excludeSet := make(map[int64]bool) // case values with non-const bodies
+ minVal, maxVal := int64(math.MaxInt64), int64(math.MinInt64)
+
+ for i, ncase := range sw.Cases {
+ if len(ncase.List) == 0 {
+ // Default case: check if it returns a constant (for gap filling).
+ if isConstIntReturn(ncase) {
+ hasConstDefault = true
+ defaultVal = ncase.Body[0].(*ir.ReturnStmt).Results[0].Val()
+ }
+ continue
+ }
+
+ vals, ok := constIntCaseVals(ncase)
+ if !ok {
+ // Case has a non-constant case expression (e.g. a variable).
+ // Bail out: we can't determine overlap with the table range.
+ // For example:
+ // case 1: return 1 // const → would go to table
+ // case c: return 9 // c is a variable, not a constant
+ // case 3: return 3 // const → would go to table
+ // At runtime, if c==3 then Go evaluates case c before case 3,
+ // returning 9. But if we put cases 1 and 3 in a table, n==3
+ // would return 3 from the table, skipping the case c check.
+ return
+ }
+
+ if !isConstIntReturn(ncase) {
+ // Non-const case body: exclude these values from the table
+ // so the mask redirects them to the normal switch, preserving
+ // Go's top-to-bottom case evaluation order. For example:
+ // case 3: sideEffect(); return 30 → exclude slot 3
+ for _, v := range vals {
+ excludeSet[v] = true
+ }
+ continue // will be handled by normal switch
+ }
+
+ retVal := ncase.Body[0].(*ir.ReturnStmt).Results[0].Val()
+ for _, v := range vals {
+ constSet[v] = retVal
+ minVal = min(minVal, v)
+ maxVal = max(maxVal, v)
+ }
+ constCaseSet[i] = true
+ }
+
+ if len(constSet) < minCases {
+ return
+ }
+
+ tableSize := maxVal - minVal + 1
+ if tableSize <= 0 || !isSwitchDense(int64(len(constSet)), tableSize) {
+ return // too sparse
+ }
+
+ // Build static lookup table and determine which slots are valid.
+ // Also build the bitmask inline if the table is small enough.
+ tabType := types.NewArray(resultType, tableSize)
+ tabName := readonlystaticname(tabType)
+ lsym := tabName.Linksym()
+ elemSize := int(resultType.Size())
+ maxBitmaskSize := int64(types.PtrSize * 8) // 32 or 64
+
+ needMask := false
+ var bitmask uint64
+ validSlots := make([]bool, tableSize)
+ for i := range tableSize {
+ caseVal := minVal + i
+ var v int64
+ switch {
+ case excludeSet[caseVal]:
+ // Non-const case in range: must fall through to normal switch.
+ needMask = true
+ case constSet[caseVal] != nil:
+ v = ir.IntVal(resultType, constSet[caseVal])
+ validSlots[i] = true
+ bitmask |= 1 << uint(i)
+ case hasConstDefault:
+ // Gap filled with default constant value.
+ v = ir.IntVal(resultType, defaultVal)
+ validSlots[i] = true
+ bitmask |= 1 << uint(i)
+ default:
+ // Gap with no const default: must fall through.
+ needMask = true
+ }
+ lsym.WriteInt(base.Ctxt, i*int64(elemSize), elemSize, v)
+ }
+
+ // Build mask if some slots must fall through to normal switch.
+ // When the table fits in a register-width bitmask (≤32 entries on 32-bit,
+ // ≤64 on 64-bit), use the bitmask computed above. For larger tables,
+ // fall back to a byte array.
+ var maskName *ir.Name
+ useBitmask := needMask && tableSize <= maxBitmaskSize
+ if needMask && !useBitmask {
+ maskType := types.NewArray(types.Types[types.TUINT8], tableSize)
+ maskName = readonlystaticname(maskType)
+ maskSym := maskName.Linksym()
+ for i := range tableSize {
+ var v uint8
+ if validSlots[i] {
+ v = 1
+ }
+ maskSym.WriteInt(base.Ctxt, i, 1, int64(v))
+ }
+ }
+
+ // Generate code:
+ // idx := uint(int(cond) - minVal)
+ // if idx <= uint(maxVal-minVal) [&& bitmask>>idx&1 != 0] { return table[idx] }
+ pos := sw.Pos()
+
+ // Widen cond to int to avoid overflow in small integer types.
+ intType := types.Types[types.TINT]
+ wideCond := typecheck.Conv(cond, intType)
+
+ // Compute idx = int(cond) - minVal.
+ var idx ir.Node
+ if minVal != 0 {
+ minLit := ir.NewBasicLit(pos, intType, constant.MakeInt64(minVal))
+ idx = typecheck.Expr(ir.NewBinaryExpr(pos, ir.OSUB, wideCond, minLit))
+ } else {
+ idx = wideCond
+ }
+
+ // Convert to uint for the one-sided bounds check and store in a temp
+ // so the index can be shared across the bounds check, table, and mask.
+ uintType := types.Types[types.TUINT]
+ uidx := typecheck.Conv(idx, uintType)
+ uidx = copyExpr(uidx, uintType, &sw.Compiled)
+
+ // Bounds check: uint(idx) <= uint(maxVal - minVal).
+ rangeLit := ir.NewBasicLit(pos, uintType, constant.MakeUint64(uint64(maxVal-minVal)))
+ boundsCheck := typecheck.Expr(ir.NewBinaryExpr(pos, ir.OLE, uidx, rangeLit))
+ boundsCheck = typecheck.DefaultLit(boundsCheck, nil)
+
+ // Table lookup: table[idx] with bounds elided (already checked).
+ lookup := ir.NewIndexExpr(pos, tabName, uidx)
+ lookup.SetBounded(true)
+ lookup = typecheck.Expr(lookup).(*ir.IndexExpr)
+
+ retStmt := ir.NewReturnStmt(pos, []ir.Node{lookup})
+
+ var ifBody []ir.Node
+ if needMask {
+ var maskCheck ir.Node
+ if useBitmask {
+ // Bitmask check: (bitmask >> idx) & 1 != 0.
+ // Use uintptr so the operation is register-width on all architectures.
+ bitmaskType := types.Types[types.TUINTPTR]
+ bitmaskLit := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(bitmask))
+ shifted := typecheck.Expr(ir.NewBinaryExpr(pos, ir.ORSH, bitmaskLit, uidx))
+ one := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(1))
+ masked := typecheck.Expr(ir.NewBinaryExpr(pos, ir.OAND, shifted, one))
+ zero := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(0))
+ maskCheck = typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, masked, zero))
+ } else {
+ // Mask array check: mask[idx] != 0.
+ maskLookup := ir.NewIndexExpr(pos, maskName, uidx)
+ maskLookup.SetBounded(true)
+ maskLookup = typecheck.Expr(maskLookup).(*ir.IndexExpr)
+ zero := ir.NewBasicLit(pos, types.Types[types.TUINT8], constant.MakeInt64(0))
+ maskCheck = typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, maskLookup, zero))
+ }
+ maskCheck = typecheck.DefaultLit(maskCheck, nil)
+
+ innerIf := ir.NewIfStmt(pos, maskCheck, []ir.Node{retStmt}, nil)
+ ifBody = []ir.Node{innerIf}
+ } else {
+ ifBody = []ir.Node{retStmt}
+ }
+
+ outerIf := ir.NewIfStmt(pos, boundsCheck, ifBody, nil)
+ sw.Compiled.Append(outerIf)
+
+ // Remove handled const cases from sw.Cases.
+ // Keep default and non-const cases for normal switch processing.
+ newCases := make([]*ir.CaseClause, 0, len(sw.Cases)-len(constCaseSet))
+ for i, ncase := range sw.Cases {
+ if !constCaseSet[i] {
+ newCases = append(newCases, ncase)
+ }
+ }
+ sw.Cases = newCases
+}
+
+// isSwitchDense reports whether a lookup table with tableSize entries
+// for numCases cases is dense enough to be worth building.
+// It requires at least 40% of table slots to be used, matching the
+// density threshold used by LLVM's SimplifyCFG.
+func isSwitchDense(numCases, tableSize int64) bool {
+ const minDensity = 40
+ if tableSize >= math.MaxInt64/100 {
+ return false // avoid multiplication overflow below
+ }
+ return numCases*100 >= tableSize*minDensity
+}
+
+// isConstIntReturn reports whether ncase has a body that is a single
+// return statement returning one integer constant.
+func isConstIntReturn(ncase *ir.CaseClause) bool {
+ if len(ncase.Body) != 1 {
+ return false
+ }
+ ret, ok := ncase.Body[0].(*ir.ReturnStmt)
+ if !ok || len(ret.Results) != 1 {
+ return false
+ }
+ r := ret.Results[0]
+ return r.Op() == ir.OLITERAL && r.Val().Kind() == constant.Int
+}
+
+// constIntCaseVals returns the int64 values of all case expressions in
+// ncase, if they are all integer constants. Returns ok=false if any
+// case expression is not a constant integer.
+func constIntCaseVals(ncase *ir.CaseClause) (vals []int64, ok bool) {
+ for _, n1 := range ncase.List {
+ if n1.Op() != ir.OLITERAL || n1.Val().Kind() != constant.Int {
+ return nil, false
+ }
+ v, fit := constant.Int64Val(n1.Val())
+ if !fit {
+ return nil, false
+ }
+ vals = append(vals, v)
+ }
+ return vals, true
+}
+
func (c *exprClause) test(exprname ir.Node) ir.Node {
// Integer range.
if c.hi != c.lo {