aboutsummaryrefslogtreecommitdiff
path: root/src/simd
diff options
context:
space:
mode:
author David Chase <drchase@google.com> 2025-10-09 15:12:47 -0400
committer David Chase <drchase@google.com> 2025-10-24 11:05:14 -0700
commit f6b47110952ea1c19cbdc040489c83f306c36e73 (patch)
tree 3908c2f5819808800047aa08e886857979f76ca2 /src/simd
parent cf7c1a4cbb917b6c5d80d1d9443a40cb7720db75 (diff)
download go-f6b47110952ea1c19cbdc040489c83f306c36e73.tar.xz
[dev.simd] cmd/compile, simd: add rewrite to convert logical expression trees into TERNLOG instructions
includes tests of both rewrite application and rewrite correctness

Change-Id: I7983ccf87a8408af95bb6c447cb22f01beda9f61
Reviewed-on: https://go-review.googlesource.com/c/go/+/710697
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Diffstat (limited to 'src/simd')
-rw-r--r-- src/simd/genfiles.go 155
-rw-r--r-- src/simd/internal/simd_test/simd_test.go 78
2 files changed, 233 insertions, 0 deletions
diff --git a/src/simd/genfiles.go b/src/simd/genfiles.go
index 80234ac9f8..be23b127c8 100644
--- a/src/simd/genfiles.go
+++ b/src/simd/genfiles.go
@@ -254,6 +254,15 @@ package simd
`, s)
}
+// ssaPrologue writes the header of a generated file to out: a
+// "Code generated by ...; DO NOT EDIT." line that names the generator
+// command s, a blank line, the clause "package ssa", and a trailing blank line.
+func ssaPrologue(s string, out io.Writer) {
+	const header = "// Code generated by '%s'; DO NOT EDIT.\n" +
+		"\n" +
+		"package ssa\n" +
+		"\n"
+	fmt.Fprintf(out, header, s)
+}
+
func unsafePrologue(s string, out io.Writer) {
fmt.Fprintf(out,
`// Code generated by '%s'; DO NOT EDIT.
@@ -806,6 +815,7 @@ func (x {{.VType}}) String() string {
`)
const TD = "internal/simd_test/"
+const SSA = "../cmd/compile/internal/ssa/"
func main() {
sl := flag.String("sl", "slice_gen_amd64.go", "file name for slice operations")
@@ -867,6 +877,115 @@ func main() {
if *cmh != "" {
one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
}
+
+ nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical)
+
+}
+
+// ternOpForLogical writes to out the source of the generated function
+// ternOpForLogical(op Op) Op. For each shape visited by
+// intShapes.forAllShapes, the generated switch maps the And, Or, Xor and
+// AndNot ops of both the Int and the Uint flavor to an Optern op; the
+// generated function returns any other op unchanged.
+func ternOpForLogical(out io.Writer) {
+	fmt.Fprintf(out, `
+func ternOpForLogical(op Op) Op {
+ switch op {
+`)
+
+	emitCases := func(seq int, t, upperT string, w, c int, out io.Writer) {
+		// Shapes whose elements are narrower than 32 bits are mapped to the
+		// tern op with 32-bit elements covering the same total number of bits.
+		ternW, ternC := w, c
+		if w < 32 {
+			ternW, ternC = 32, (w*c)/32
+		}
+		fmt.Fprintf(out, "case OpAndInt%[1]dx%[2]d, OpOrInt%[1]dx%[2]d, OpXorInt%[1]dx%[2]d,OpAndNotInt%[1]dx%[2]d: return OpternInt%dx%d\n", w, c, ternW, ternC)
+		fmt.Fprintf(out, "case OpAndUint%[1]dx%[2]d, OpOrUint%[1]dx%[2]d, OpXorUint%[1]dx%[2]d,OpAndNotUint%[1]dx%[2]d: return OpternUint%dx%d\n", w, c, ternW, ternC)
+	}
+	intShapes.forAllShapes(emitCases, out)
+
+	fmt.Fprintf(out, `
+ }
+ return op
+}
+`)
+}
+
+// classifyBooleanSIMD writes to out the source of two generated
+// declarations: the SIMDLogicalOP type with its slo* constants, and the
+// function classifyBooleanSIMD(v *Value) SIMDLogicalOP. The generated
+// function switches on v.Op and returns sloAnd, sloOr, sloAndNot, sloXor or
+// sloNot for the matching ops of every shape visited by
+// intShapes.forAllShapes, and sloNone for anything else.
+func classifyBooleanSIMD(out io.Writer) {
+	// caseList writes the comma-separated operands of one case clause,
+	// "Op<op><type><w>x<c>", one entry per shape. It assumes forAllShapes
+	// passes seq > 0 for every shape after the first, which is what decides
+	// where the separating commas go.
+	caseList := func(op string) {
+		intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, dst io.Writer) {
+			if seq <= 0 {
+				fmt.Fprintf(dst, "Op%s%s%dx%d", op, upperT, w, c)
+				return
+			}
+			fmt.Fprintf(dst, ",Op%s%s%dx%d", op, upperT, w, c)
+		}, out)
+	}
+
+	fmt.Fprintf(out, `
+type SIMDLogicalOP uint8
+const (
+ // boolean simd operations, for reducing expression to VPTERNLOG* instructions
+ // sloInterior is set for non-root nodes in logical-op expression trees.
+ sloInterior SIMDLogicalOP = 1
+ sloNone SIMDLogicalOP = 2 * iota
+ sloAnd
+ sloOr
+ sloAndNot
+ sloXor
+ sloNot
+)
+func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
+ switch v.Op {
+ case `)
+	caseList("And")
+
+	fmt.Fprintf(out, `:
+ return sloAnd
+
+ case `)
+	caseList("Or")
+
+	fmt.Fprintf(out, `:
+ return sloOr
+
+ case `)
+	caseList("AndNot")
+
+	fmt.Fprintf(out, `:
+ return sloAndNot
+`)
+
+	// There is no separate Not op to match: "Not" is written as
+	// x.Xor(x.Equal(x).AsInt8x16()), so xor.Args[0] is x and xor.Args[1] is
+	// the As... conversion, which is a pun/passthrough. Each Xor therefore
+	// gets its own case: it is reported as sloNot when Args[1] is an Equal
+	// of a value with itself, and as sloXor otherwise.
+	emitXorCase := func(seq int, t, upperT string, w, c int, dst io.Writer) {
+		fmt.Fprintf(dst, "case OpXor%s%dx%d: ", upperT, w, c)
+		fmt.Fprintf(dst, `
+ if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
+ y.Args[0] == y.Args[1] {
+ return sloNot
+ }
+ `, upperT, w, c)
+		fmt.Fprintf(dst, "return sloXor\n")
+	}
+	intShapes.forAllShapes(emitXorCase, out)
+
+	fmt.Fprintf(out, `
+ }
+ return sloNone
+}
+`)
 }
// numberLines takes a slice of bytes, and returns a string where each line
@@ -881,6 +1000,42 @@ func numberLines(data []byte) string {
return buf.String()
}
+// nonTemplateRewrites generates one Go source file without using templates.
+//
+// It writes prologue's output followed by the output of each function in
+// rewrites into a buffer, formats the buffer with format.Source, and writes
+// the result to filename. An empty filename means "generate nothing"; the
+// filename "-" sends the formatted source to standard output.
+//
+// Any failure is reported on standard error and terminates the program with
+// exit status 1. The code is generated and formatted before the destination
+// is opened, so a formatting failure never truncates an existing file.
+func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) {
+	if filename == "" {
+		return
+	}
+
+	out := new(bytes.Buffer)
+
+	prologue("go run genfiles.go", out)
+	for _, rewrite := range rewrites {
+		rewrite(out)
+	}
+
+	b, err := format.Source(out.Bytes())
+	if err != nil {
+		// The message is printed both before and after the numbered listing,
+		// presumably so that it stays visible however long the listing is.
+		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
+		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
+		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
+		os.Exit(1)
+	}
+
+	if filename == "-" {
+		// Standard output is written to but deliberately left open.
+		if _, err := os.Stdout.Write(b); err != nil {
+			fmt.Fprintf(os.Stderr, "Could not write the generated code to standard output, %v\n", err)
+			os.Exit(1)
+		}
+		return
+	}
+
+	// os.WriteFile creates or truncates the file and reports create, write
+	// and close errors alike.
+	if err := os.WriteFile(filename, b, 0666); err != nil {
+		fmt.Fprintf(os.Stderr, "Could not write the output file %s for the generated code, %v\n", filename, err)
+		os.Exit(1)
+	}
+}
+
+
func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) {
if filename == "" {
return
diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go
index 295f7bf9ce..c64ac0fcfd 100644
--- a/src/simd/internal/simd_test/simd_test.go
+++ b/src/simd/internal/simd_test/simd_test.go
@@ -1030,3 +1030,81 @@ func TestString(t *testing.T) {
t.Logf("y=%s", y)
t.Logf("z=%s", z)
}
+
+// a returns a new, zero-filled slice of 16 int32 values, one per lane of
+// an Int32x16 vector.
+func a() []int32 {
+	const lanes = 16
+	return make([]int32, lanes)
+}
+
+// applyTo3 stores the vectors x, y, and z into slices and returns a
+// 16-element slice whose element i is f applied to element i of each of
+// the three vectors.
+func applyTo3(x, y, z simd.Int32x16, f func(x, y, z int32) int32) []int32 {
+	xs, ys, zs := a(), a(), a()
+	x.StoreSlice(xs)
+	y.StoreSlice(ys)
+	z.StoreSlice(zs)
+
+	res := a()
+	for i := 0; i < len(res); i++ {
+		res[i] = f(xs[i], ys[i], zs[i])
+	}
+	return res
+}
+
+// applyTo4 returns a 16-element slice of the results of
+// applying f to the respective elements of vectors x, y, z, and w.
+func applyTo4(x, y, z, w simd.Int32x16, f func(x, y, z, w int32) int32) []int32 {
+	ax, ay, az, aw := a(), a(), a(), a()
+	x.StoreSlice(ax)
+	y.StoreSlice(ay)
+	z.StoreSlice(az)
+	w.StoreSlice(aw)
+
+	// Allocate the result the same way as the inputs, as applyTo3 does.
+	r := a()
+	for i := range r {
+		r[i] = f(ax[i], ay[i], az[i], aw[i])
+	}
+	return r
+}
+
+// TestSelectTernOptInt32x16 evaluates three logical expression trees
+// (Xor, And and Not combinations) on Int32x16 vectors and compares each
+// vector result, lane by lane, with the same expression computed on scalar
+// int32 values by applyTo3 or applyTo4. It is skipped when
+// simd.HasAVX512 reports false.
+func TestSelectTernOptInt32x16(t *testing.T) {
+	if !simd.HasAVX512() {
+		// Skip stops the test, so nothing below runs in this case.
+		t.Skip("Test requires HasAVX512, not available on this hardware")
+	}
+
+	ax := []int32{0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}
+	ay := []int32{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}
+	az := []int32{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}
+	aw := []int32{0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}
+	am := []int32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+
+	x := simd.LoadInt32x16Slice(ax)
+	y := simd.LoadInt32x16Slice(ay)
+	z := simd.LoadInt32x16Slice(az)
+	w := simd.LoadInt32x16Slice(aw)
+	m := simd.LoadInt32x16Slice(am)
+
+	// verify stores the vector result and compares it with the scalar
+	// reference values.
+	verify := func(got simd.Int32x16, want []int32) {
+		lanes := make([]int32, 16)
+		got.StoreSlice(lanes)
+		checkSlices[int32](t, lanes, want)
+	}
+
+	// Three-input xor.
+	xor3 := func(p, q, r int32) int32 {
+		return p ^ q ^ r
+	}
+	verify(w.Xor(y).Xor(z), applyTo3(w, y, z, xor3))
+
+	// Four inputs: a mask and-ed with an xor tree that contains a Not.
+	maskedXorNot := func(k, p, q, r int32) int32 {
+		return k & (p ^ q ^ ^r)
+	}
+	verify(m.And(w.Xor(y).Xor(z.Not())), applyTo4(m, w, y, z, maskedXorNot))
+
+	// Three inputs, each used in both operands of the And.
+	xorAndXorNot := func(p, q, r int32) int32 {
+		return (p ^ q ^ r) & (p ^ q ^ ^r)
+	}
+	verify(x.Xor(y).Xor(z).And(x.Xor(y).Xor(z.Not())), applyTo3(x, y, z, xorAndXorNot))
+}