aboutsummaryrefslogtreecommitdiff
path: root/src/cmd
diff options
context:
space:
mode:
authorMichael Munday <mndygolang+git@gmail.com>2025-08-26 21:17:36 +0100
committerGopher Robot <gobot@golang.org>2025-11-14 10:59:56 -0800
commit0a569528ea355099af864f7612c3fa1973df30e4 (patch)
tree189f828573801d88d072ebd988ae1d553d2a8afa /src/cmd
parent1e5e6663e958dcc9579fb38ffcd8a1999d75128d (diff)
downloadgo-0a569528ea355099af864f7612c3fa1973df30e4.tar.xz
cmd/compile: optimize comparisons with single bit difference
Optimize comparisons with constants that only differ by 1 bit (i.e.
a power of 2). For example:

    x == 4 || x == 6 -> x|2 == 6
    x != 1 && x != 5 -> x|4 != 5

Change-Id: Ic61719e5118446d21cf15652d9da22f7d95b2a15
Reviewed-on: https://go-review.googlesource.com/c/go/+/719420
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Auto-Submit: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
Diffstat (limited to 'src/cmd')
-rw-r--r--src/cmd/compile/internal/ssa/_gen/generic.rules6
-rw-r--r--src/cmd/compile/internal/ssa/fuse.go8
-rw-r--r--src/cmd/compile/internal/ssa/fuse_comparisons.go45
-rw-r--r--src/cmd/compile/internal/ssa/rewritegeneric.go352
4 files changed, 410 insertions, 1 deletions
diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules
index 7e3aba1e5e..6efead03ad 100644
--- a/src/cmd/compile/internal/ssa/_gen/generic.rules
+++ b/src/cmd/compile/internal/ssa/_gen/generic.rules
@@ -337,6 +337,12 @@
(OrB ((Less|Leq)16U (Const16 [c]) x) (Leq16U x (Const16 [d]))) && uint16(c) >= uint16(d+1) && uint16(d+1) > uint16(d) => ((Less|Leq)16U (Const16 <x.Type> [c-d-1]) (Sub16 <x.Type> x (Const16 <x.Type> [d+1])))
(OrB ((Less|Leq)8U (Const8 [c]) x) (Leq8U x (Const8 [d]))) && uint8(c) >= uint8(d+1) && uint8(d+1) > uint8(d) => ((Less|Leq)8U (Const8 <x.Type> [c-d-1]) (Sub8 <x.Type> x (Const8 <x.Type> [d+1])))
+// single bit difference: ( x != c && x != d ) -> ( x|(c^d) != c ), where c has the differing bit set (c|d == c)
+(AndB (Neq(64|32|16|8) x cv:(Const(64|32|16|8) [c])) (Neq(64|32|16|8) x (Const(64|32|16|8) [d]))) && c|d == c && oneBit(c^d) => (Neq(64|32|16|8) (Or(64|32|16|8) <x.Type> x (Const(64|32|16|8) <x.Type> [c^d])) cv)
+
+// single bit difference: ( x == c || x == d ) -> ( x|(c^d) == c ), where c has the differing bit set (c|d == c)
+(OrB (Eq(64|32|16|8) x cv:(Const(64|32|16|8) [c])) (Eq(64|32|16|8) x (Const(64|32|16|8) [d]))) && c|d == c && oneBit(c^d) => (Eq(64|32|16|8) (Or(64|32|16|8) <x.Type> x (Const(64|32|16|8) <x.Type> [c^d])) cv)
+
// NaN check: ( x != x || x (>|>=|<|<=) c ) -> ( !(c (>=|>|<=|<) x) )
(OrB (Neq64F x x) ((Less|Leq)64F x y:(Const64F [c]))) => (Not ((Leq|Less)64F y x))
(OrB (Neq64F x x) ((Less|Leq)64F y:(Const64F [c]) x)) => (Not ((Leq|Less)64F x y))
diff --git a/src/cmd/compile/internal/ssa/fuse.go b/src/cmd/compile/internal/ssa/fuse.go
index 0cee91b532..e95064c1df 100644
--- a/src/cmd/compile/internal/ssa/fuse.go
+++ b/src/cmd/compile/internal/ssa/fuse.go
@@ -10,7 +10,9 @@ import (
)
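-// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck).
+// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeSingleBitDifference|fuseTypeNanCheck).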
-func fuseEarly(f *Func) { fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck) }
+func fuseEarly(f *Func) {
+ fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeSingleBitDifference|fuseTypeNanCheck)
+}
// fuseLate runs fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect).
func fuseLate(f *Func) { fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect) }
@@ -21,6 +23,7 @@ const (
fuseTypePlain fuseType = 1 << iota
fuseTypeIf
fuseTypeIntInRange
+ fuseTypeSingleBitDifference
fuseTypeNanCheck
fuseTypeBranchRedirect
fuseTypeShortCircuit
@@ -41,6 +44,9 @@ func fuse(f *Func, typ fuseType) {
if typ&fuseTypeIntInRange != 0 {
changed = fuseIntInRange(b) || changed
}
+ if typ&fuseTypeSingleBitDifference != 0 {
+ changed = fuseSingleBitDifference(b) || changed
+ }
if typ&fuseTypeNanCheck != 0 {
changed = fuseNanCheck(b) || changed
}
diff --git a/src/cmd/compile/internal/ssa/fuse_comparisons.go b/src/cmd/compile/internal/ssa/fuse_comparisons.go
index b6eb8fcb90..898c034485 100644
--- a/src/cmd/compile/internal/ssa/fuse_comparisons.go
+++ b/src/cmd/compile/internal/ssa/fuse_comparisons.go
@@ -19,6 +19,14 @@ func fuseNanCheck(b *Block) bool {
return fuseComparisons(b, canOptNanCheck)
}
+// fuseSingleBitDifference replaces the short-circuit operator joining two equality checks of the
+// same value against constants that differ by a single bit. For example, it would convert
+// `if x == 4 || x == 6 { ... }` into `if (x == 4) | (x == 6) { ... }`, which rewrite rules can then
+// turn into `if x|2 == 6 { ... }`. It returns fuseComparisons' result, which fuse ORs into changed.
+func fuseSingleBitDifference(b *Block) bool {
+ return fuseComparisons(b, canOptSingleBitDifference)
+}
+
// fuseComparisons looks for control graphs that match this pattern:
//
// p - predecessor
@@ -229,3 +237,40 @@ func canOptNanCheck(x, y *Value, op Op) bool {
}
return false
}
+
+// canOptSingleBitDifference returns true if x op y matches either:
+//
+// v == c || v == d
+// v != c && v != d
+//
+// Where c and d are constant values that differ by a single bit.
+func canOptSingleBitDifference(x, y *Value, op Op) bool {
+ if x.Op != y.Op {
+ return false
+ }
+ switch x.Op {
+ case OpEq64, OpEq32, OpEq16, OpEq8:
+ if op != OpOrB {
+ return false
+ }
+ case OpNeq64, OpNeq32, OpNeq16, OpNeq8:
+ if op != OpAndB {
+ return false
+ }
+ default:
+ return false
+ }
+
+ xi := getConstIntArgIndex(x)
+ if xi < 0 {
+ return false
+ }
+ yi := getConstIntArgIndex(y)
+ if yi < 0 {
+ return false
+ }
+ if x.Args[xi^1] != y.Args[yi^1] {
+ return false
+ }
+ return oneBit(x.Args[xi].AuxInt ^ y.Args[yi].AuxInt)
+}
diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go
index fd5139c0bb..2428f17947 100644
--- a/src/cmd/compile/internal/ssa/rewritegeneric.go
+++ b/src/cmd/compile/internal/ssa/rewritegeneric.go
@@ -5332,6 +5332,182 @@ func rewriteValuegeneric_OpAndB(v *Value) bool {
}
break
}
+ // match: (AndB (Neq64 x cv:(Const64 [c])) (Neq64 x (Const64 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Neq64 (Or64 <x.Type> x (Const64 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpNeq64 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst64 {
+ continue
+ }
+ c := auxIntToInt64(cv.AuxInt)
+ if v_1.Op != OpNeq64 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst64 {
+ continue
+ }
+ d := auxIntToInt64(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpNeq64)
+ v0 := b.NewValue0(v.Pos, OpOr64, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst64, x.Type)
+ v1.AuxInt = int64ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
+ // match: (AndB (Neq32 x cv:(Const32 [c])) (Neq32 x (Const32 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Neq32 (Or32 <x.Type> x (Const32 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpNeq32 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst32 {
+ continue
+ }
+ c := auxIntToInt32(cv.AuxInt)
+ if v_1.Op != OpNeq32 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst32 {
+ continue
+ }
+ d := auxIntToInt32(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpNeq32)
+ v0 := b.NewValue0(v.Pos, OpOr32, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst32, x.Type)
+ v1.AuxInt = int32ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
+ // match: (AndB (Neq16 x cv:(Const16 [c])) (Neq16 x (Const16 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Neq16 (Or16 <x.Type> x (Const16 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpNeq16 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst16 {
+ continue
+ }
+ c := auxIntToInt16(cv.AuxInt)
+ if v_1.Op != OpNeq16 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst16 {
+ continue
+ }
+ d := auxIntToInt16(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpNeq16)
+ v0 := b.NewValue0(v.Pos, OpOr16, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst16, x.Type)
+ v1.AuxInt = int16ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
+ // match: (AndB (Neq8 x cv:(Const8 [c])) (Neq8 x (Const8 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Neq8 (Or8 <x.Type> x (Const8 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpNeq8 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst8 {
+ continue
+ }
+ c := auxIntToInt8(cv.AuxInt)
+ if v_1.Op != OpNeq8 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst8 {
+ continue
+ }
+ d := auxIntToInt8(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpNeq8)
+ v0 := b.NewValue0(v.Pos, OpOr8, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst8, x.Type)
+ v1.AuxInt = int8ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
return false
}
func rewriteValuegeneric_OpArraySelect(v *Value) bool {
@@ -23242,6 +23418,182 @@ func rewriteValuegeneric_OpOrB(v *Value) bool {
}
break
}
+ // match: (OrB (Eq64 x cv:(Const64 [c])) (Eq64 x (Const64 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Eq64 (Or64 <x.Type> x (Const64 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpEq64 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst64 {
+ continue
+ }
+ c := auxIntToInt64(cv.AuxInt)
+ if v_1.Op != OpEq64 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst64 {
+ continue
+ }
+ d := auxIntToInt64(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpEq64)
+ v0 := b.NewValue0(v.Pos, OpOr64, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst64, x.Type)
+ v1.AuxInt = int64ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
+ // match: (OrB (Eq32 x cv:(Const32 [c])) (Eq32 x (Const32 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Eq32 (Or32 <x.Type> x (Const32 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpEq32 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst32 {
+ continue
+ }
+ c := auxIntToInt32(cv.AuxInt)
+ if v_1.Op != OpEq32 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst32 {
+ continue
+ }
+ d := auxIntToInt32(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpEq32)
+ v0 := b.NewValue0(v.Pos, OpOr32, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst32, x.Type)
+ v1.AuxInt = int32ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
+ // match: (OrB (Eq16 x cv:(Const16 [c])) (Eq16 x (Const16 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Eq16 (Or16 <x.Type> x (Const16 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpEq16 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst16 {
+ continue
+ }
+ c := auxIntToInt16(cv.AuxInt)
+ if v_1.Op != OpEq16 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst16 {
+ continue
+ }
+ d := auxIntToInt16(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpEq16)
+ v0 := b.NewValue0(v.Pos, OpOr16, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst16, x.Type)
+ v1.AuxInt = int16ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
+ // match: (OrB (Eq8 x cv:(Const8 [c])) (Eq8 x (Const8 [d])))
+ // cond: c|d == c && oneBit(c^d)
+ // result: (Eq8 (Or8 <x.Type> x (Const8 <x.Type> [c^d])) cv)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ if v_0.Op != OpEq8 {
+ continue
+ }
+ _ = v_0.Args[1]
+ v_0_0 := v_0.Args[0]
+ v_0_1 := v_0.Args[1]
+ for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
+ x := v_0_0
+ cv := v_0_1
+ if cv.Op != OpConst8 {
+ continue
+ }
+ c := auxIntToInt8(cv.AuxInt)
+ if v_1.Op != OpEq8 {
+ continue
+ }
+ _ = v_1.Args[1]
+ v_1_0 := v_1.Args[0]
+ v_1_1 := v_1.Args[1]
+ for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
+ if x != v_1_0 || v_1_1.Op != OpConst8 {
+ continue
+ }
+ d := auxIntToInt8(v_1_1.AuxInt)
+ if !(c|d == c && oneBit(c^d)) {
+ continue
+ }
+ v.reset(OpEq8)
+ v0 := b.NewValue0(v.Pos, OpOr8, x.Type)
+ v1 := b.NewValue0(v.Pos, OpConst8, x.Type)
+ v1.AuxInt = int8ToAuxInt(c ^ d)
+ v0.AddArg2(x, v1)
+ v.AddArg2(v0, cv)
+ return true
+ }
+ }
+ }
+ break
+ }
// match: (OrB (Neq64F x x) (Less64F x y:(Const64F [c])))
// result: (Not (Leq64F y x))
for {