From f6b47110952ea1c19cbdc040489c83f306c36e73 Mon Sep 17 00:00:00 2001 From: David Chase Date: Thu, 9 Oct 2025 15:12:47 -0400 Subject: [dev.simd] cmd/compile, simd: add rewrite to convert logical expression trees into TERNLOG instructions includes tests of both rewrite application and rewrite correctness Change-Id: I7983ccf87a8408af95bb6c447cb22f01beda9f61 Reviewed-on: https://go-review.googlesource.com/c/go/+/710697 LUCI-TryBot-Result: Go LUCI Reviewed-by: Junyang Shao --- src/simd/genfiles.go | 155 +++++++++++++++++++++++++++++++ src/simd/internal/simd_test/simd_test.go | 78 ++++++++++++++++ 2 files changed, 233 insertions(+) (limited to 'src/simd') diff --git a/src/simd/genfiles.go b/src/simd/genfiles.go index 80234ac9f8..be23b127c8 100644 --- a/src/simd/genfiles.go +++ b/src/simd/genfiles.go @@ -254,6 +254,15 @@ package simd `, s) } +func ssaPrologue(s string, out io.Writer) { + fmt.Fprintf(out, + `// Code generated by '%s'; DO NOT EDIT. + +package ssa + +`, s) +} + func unsafePrologue(s string, out io.Writer) { fmt.Fprintf(out, `// Code generated by '%s'; DO NOT EDIT. @@ -806,6 +815,7 @@ func (x {{.VType}}) String() string { `) const TD = "internal/simd_test/" +const SSA = "../cmd/compile/internal/ssa/" func main() { sl := flag.String("sl", "slice_gen_amd64.go", "file name for slice operations") @@ -867,6 +877,115 @@ func main() { if *cmh != "" { one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate) } + + nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical) + +} + +func ternOpForLogical(out io.Writer) { + fmt.Fprintf(out, ` +func ternOpForLogical(op Op) Op { + switch op { +`) + + intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) { + wt, ct := w, c + if wt < 32 { + wt = 32 + ct = (w * c) / wt + } + fmt.Fprintf(out, "case OpAndInt%[1]dx%[2]d, OpOrInt%[1]dx%[2]d, OpXorInt%[1]dx%[2]d,OpAndNotInt%[1]dx%[2]d: return OpternInt%dx%d\n", w, c, wt, ct) + fmt.Fprintf(out, "case OpAndUint%[1]dx%[2]d, OpOrUint%[1]dx%[2]d, OpXorUint%[1]dx%[2]d,OpAndNotUint%[1]dx%[2]d: return OpternUint%dx%d\n", w, c, wt, ct) + }, out) + + fmt.Fprintf(out, ` + } + return op +} +`) + +} + +func classifyBooleanSIMD(out io.Writer) { + fmt.Fprintf(out, ` +type SIMDLogicalOP uint8 +const ( + // boolean simd operations, for reducing expression to VPTERNLOG* instructions + // sloInterior is set for non-root nodes in logical-op expression trees. + sloInterior SIMDLogicalOP = 1 + sloNone SIMDLogicalOP = 2 * iota + sloAnd + sloOr + sloAndNot + sloXor + sloNot +) +func classifyBooleanSIMD(v *Value) SIMDLogicalOP { + switch v.Op { + case `) + intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) { + op := "And" + if seq > 0 { + fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c) + } else { + fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c) + } + seq++ + }, out) + + fmt.Fprintf(out, `: + return sloAnd + + case `) + intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) { + op := "Or" + if seq > 0 { + fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c) + } else { + fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c) + } + seq++ + }, out) + + fmt.Fprintf(out, `: + return sloOr + + case `) + intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) { + op := "AndNot" + if seq > 0 { + fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c) + } else { + fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c) + } + seq++ + }, out) + + fmt.Fprintf(out, `: + return sloAndNot +`) + + // "Not" is encoded as x.Xor(x.Equal(x).AsInt8x16()) + // i.e. xor.Args[0] == x, xor.Args[1].Op == As... + // but AsInt8x16 is a pun/passthrough. + + intShapes.forAllShapes( + func(seq int, t, upperT string, w, c int, out io.Writer) { + fmt.Fprintf(out, "case OpXor%s%dx%d: ", upperT, w, c) + fmt.Fprintf(out, ` + if y := v.Args[1]; y.Op == OpEqual%s%dx%d && + y.Args[0] == y.Args[1] { + return sloNot + } + `, upperT, w, c) + fmt.Fprintf(out, "return sloXor\n") + }, out) + + fmt.Fprintf(out, ` + } + return sloNone +} +`) } // numberLines takes a slice of bytes, and returns a string where each line @@ -881,6 +1000,42 @@ func numberLines(data []byte) string { return buf.String() } +func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) { + if filename == "" { + return + } + + ofile := os.Stdout + + if filename != "-" { + var err error + ofile, err = os.Create(filename) + if err != nil { + fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err) + os.Exit(1) + } + } + + out := new(bytes.Buffer) + + prologue("go run genfiles.go", out) + for _, rewrite := range rewrites { + rewrite(out) + } + + b, err := format.Source(out.Bytes()) + if err != nil { + fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err) + fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes())) + fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err) + os.Exit(1) + } else { + ofile.Write(b) + ofile.Close() + } + +} + func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) { if filename == "" { return diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go index 295f7bf9ce..c64ac0fcfd 100644 --- a/src/simd/internal/simd_test/simd_test.go +++ b/src/simd/internal/simd_test/simd_test.go @@ -1030,3 +1030,81 @@ func TestString(t *testing.T) { t.Logf("y=%s", y) t.Logf("z=%s", z) } + +// a returns an slice of 16 int32 +func a() []int32 { + return make([]int32, 16, 16) +} + +// applyTo3 returns a 16-element slice of the results of +// applying f to the respective elements of vectors x, y, and z. +func applyTo3(x, y, z simd.Int32x16, f func(x, y, z int32) int32) []int32 { + ax, ay, az := a(), a(), a() + x.StoreSlice(ax) + y.StoreSlice(ay) + z.StoreSlice(az) + + r := a() + for i := range r { + r[i] = f(ax[i], ay[i], az[i]) + } + return r +} + +// applyTo3 returns a 16-element slice of the results of +// applying f to the respective elements of vectors x, y, z, and w. +func applyTo4(x, y, z, w simd.Int32x16, f func(x, y, z, w int32) int32) []int32 { + ax, ay, az, aw := a(), a(), a(), a() + x.StoreSlice(ax) + y.StoreSlice(ay) + z.StoreSlice(az) + w.StoreSlice(aw) + + r := make([]int32, len(ax), len(ax)) + for i := range r { + r[i] = f(ax[i], ay[i], az[i], aw[i]) + } + return r +} + +func TestSelectTernOptInt32x16(t *testing.T) { + if !simd.HasAVX512() { + t.Skip("Test requires HasAVX512, not available on this hardware") + return + } + ax := []int32{0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1} + ay := []int32{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1} + az := []int32{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1} + aw := []int32{0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1} + am := []int32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + + x := simd.LoadInt32x16Slice(ax) + y := simd.LoadInt32x16Slice(ay) + z := simd.LoadInt32x16Slice(az) + w := simd.LoadInt32x16Slice(aw) + m := simd.LoadInt32x16Slice(am) + + foo := func(v simd.Int32x16, s []int32) { + r := make([]int32, 16, 16) + v.StoreSlice(r) + checkSlices[int32](t, r, s) + } + + t0 := w.Xor(y).Xor(z) + ft0 := func(w, y, z int32) int32 { + return w ^ y ^ z + } + foo(t0, applyTo3(w, y, z, ft0)) + + t1 := m.And(w.Xor(y).Xor(z.Not())) + ft1 := func(m, w, y, z int32) int32 { + return m & (w ^ y ^ ^z) + } + foo(t1, applyTo4(m, w, y, z, ft1)) + + t2 := x.Xor(y).Xor(z).And(x.Xor(y).Xor(z.Not())) + ft2 := func(x, y, z int32) int32 { + return (x ^ y ^ z) & (x ^ y ^ ^z) + } + foo(t2, applyTo3(x, y, z, ft2)) +} -- cgit v1.3