author    Junyang Shao <shaojunyang@google.com>  2025-11-05 19:25:00 +0000
committer Junyang Shao <shaojunyang@google.com>  2025-11-11 13:34:39 -0800
commit    86b4fe31d9b7fe4b249a3a8007290305eaa4f16a
tree      23f7b2fed426e41c055c39d864835012e602119c /src/simd
parent    771a1dc216ff02dd23c78ada35a207a363690d11
download  go-86b4fe31d9b7fe4b249a3a8007290305eaa4f16a.tar.xz
[dev.simd] cmd/compile: add masked merging ops and optimizations
This CL generates optimizations for the masked variants of AVX512 instructions,
matching the pattern x.Op(y).Merge(z, mask) => OpMasked(z, x, y, mask), where
OpMasked is resultInArg0.

Change-Id: Ife7ccc9ddbf76ae921a085bd6a42b965da9bc179
Reviewed-on: https://go-review.googlesource.com/c/go/+/718160
Reviewed-by: David Chase <drchase@google.com>
TryBot-Bypass: Junyang Shao <shaojunyang@google.com>
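For illustration (not part of the CL), the user-level pattern the new rules
target looks like this; the fused machine-op name below is a hypothetical
example, not a name taken from the generated code:

	sum := x.Add(y)           // plain AVX512 op, e.g. VPADDQ
	res := sum.Merge(z, mask) // lane select: sum where mask is set, z elsewhere

After these rules fire, the blend is folded into a single merging-masked
instruction (e.g. a VPADDQMasked256Merging machine op) whose destination
register carries z, hence resultInArg0.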
Diffstat (limited to 'src/simd')
-rw-r--r--  src/simd/_gen/simdgen/gen_simdMachineOps.go  70
-rw-r--r--  src/simd/_gen/simdgen/gen_simdTypes.go        4
-rw-r--r--  src/simd/_gen/simdgen/gen_simdrules.go       37
-rw-r--r--  src/simd/_gen/simdgen/gen_simdssa.go         33
-rw-r--r--  src/simd/_gen/simdgen/gen_utility.go         14
-rw-r--r--  src/simd/_gen/simdgen/ops/Moves/go.yaml      15
-rw-r--r--  src/simd/internal/simd_test/simd_test.go     19
7 files changed, 159 insertions, 33 deletions
diff --git a/src/simd/_gen/simdgen/gen_simdMachineOps.go b/src/simd/_gen/simdgen/gen_simdMachineOps.go
index 240227b27d..e8cf792d42 100644
--- a/src/simd/_gen/simdgen/gen_simdMachineOps.go
+++ b/src/simd/_gen/simdgen/gen_simdMachineOps.go
@@ -30,6 +30,12 @@ func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vf
{{- range .OpsDataImmLoad}}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
+{{- range .OpsDataMerging }}
+ {name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true},
+{{- end }}
+{{- range .OpsDataImmMerging }}
+ {name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true},
+{{- end }}
}
}
`
@@ -51,10 +57,12 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
ResultInArg0 bool
}
type machineOpsData struct {
- OpsData []opData
- OpsDataImm []opData
- OpsDataLoad []opData
- OpsDataImmLoad []opData
+ OpsData []opData
+ OpsDataImm []opData
+ OpsDataLoad []opData
+ OpsDataImmLoad []opData
+ OpsDataMerging []opData
+ OpsDataImmMerging []opData
}
regInfoSet := map[string]bool{
@@ -66,6 +74,8 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
opsDataImm := make([]opData, 0)
opsDataLoad := make([]opData, 0)
opsDataImmLoad := make([]opData, 0)
+ opsDataMerging := make([]opData, 0)
+ opsDataImmMerging := make([]opData, 0)
// Determine the "best" version of an instruction to use
best := make(map[string]Operation)
@@ -98,7 +108,7 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
regInfoMissing := make(map[string]bool, 0)
for _, asm := range mOpOrder {
op := best[asm]
- shapeIn, shapeOut, _, _, gOp := op.shape()
+ shapeIn, shapeOut, maskType, _, gOp := op.shape()
// TODO: all our masked operations are now zeroing, we need to generate machine ops with merging masks, maybe copy
// one here with a name suffix "Merging". The rewrite rules will need them.
@@ -147,11 +157,13 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
resultInArg0 = true
}
var memOpData *opData
+ regInfoMerging := regInfo
+ hasMerging := false
if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
// Right now we only have the vbcst case.
// Make a full vec memory variant.
- op = rewriteLastVregToMem(op)
- regInfo, err := makeRegInfo(op, VregMemIn)
+ opMem := rewriteLastVregToMem(op)
+ regInfo, err := makeRegInfo(opMem, VregMemIn)
if err != nil {
// Just skip it if it's non-nil.
// An error could be triggered by [checkVecAsScalar].
@@ -163,16 +175,51 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
memOpData = &opData{asm + "load", gOp.Asm, len(gOp.In) + 1, regInfo, false, outType, resultInArg0}
}
}
+ hasMerging = gOp.hasMaskedMerging(maskType, shapeOut)
+ if hasMerging && !resultInArg0 {
+ // We have to copy the slice here because the in-place sort would otherwise
+ // be visible through other aliases of the backing array.
+ newIn := make([]Operand, len(op.In), len(op.In)+1)
+ copy(newIn, op.In)
+ op.In = newIn
+ op.In = append(op.In, op.Out[0])
+ op.sortOperand()
+ regInfoMerging, err = makeRegInfo(op, NoMem)
+ if err != nil {
+ panic(err)
+ }
+ }
+
if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
if memOpData != nil {
+ if *op.MemFeatures != "vbcst" {
+ panic("simdgen only knows vbcst for mem ops for now")
+ }
opsDataImmLoad = append(opsDataImmLoad, *memOpData)
}
+ if hasMerging {
+ mergingLen := len(gOp.In)
+ if !resultInArg0 {
+ mergingLen++
+ }
+ opsDataImmMerging = append(opsDataImmMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
+ }
} else {
opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
if memOpData != nil {
+ if *op.MemFeatures != "vbcst" {
+ panic("simdgen only knows vbcst for mem ops for now")
+ }
opsDataLoad = append(opsDataLoad, *memOpData)
}
+ if hasMerging {
+ mergingLen := len(gOp.In)
+ if !resultInArg0 {
+ mergingLen++
+ }
+ opsDataMerging = append(opsDataMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
+ }
}
}
if len(regInfoErrs) != 0 {
@@ -193,7 +240,14 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
sort.Slice(opsDataImmLoad, func(i, j int) bool {
return compareNatural(opsDataImmLoad[i].OpName, opsDataImmLoad[j].OpName) < 0
})
- err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad})
+ sort.Slice(opsDataMerging, func(i, j int) bool {
+ return compareNatural(opsDataMerging[i].OpName, opsDataMerging[j].OpName) < 0
+ })
+ sort.Slice(opsDataImmMerging, func(i, j int) bool {
+ return compareNatural(opsDataImmMerging[i].OpName, opsDataImmMerging[j].OpName) < 0
+ })
+ err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad,
+ opsDataMerging, opsDataImmMerging})
if err != nil {
panic(fmt.Errorf("failed to execute template: %w", err))
}
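
To make the new template ranges concrete (an illustrative sketch, not actual
generated output; the op name and register-info shape are assumptions), an
entry emitted by the OpsDataMerging range would look roughly like:

	{name: "VPADDQMasked256Merging", argLength: 4, reg: v3kv, asm: "VPADDQ", commutative: false, typ: "Vec256", resultInArg0: true},

The extra argument relative to the zeroing variant is the destination vector,
which must share a register with the result (resultInArg0: true).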
diff --git a/src/simd/_gen/simdgen/gen_simdTypes.go b/src/simd/_gen/simdgen/gen_simdTypes.go
index efa3ffabeb..c809fcd1de 100644
--- a/src/simd/_gen/simdgen/gen_simdTypes.go
+++ b/src/simd/_gen/simdgen/gen_simdTypes.go
@@ -585,8 +585,8 @@ func writeSIMDFeatures(ops []Operation) *bytes.Buffer {
return buffer
}
-// writeSIMDStubs generates the simd vector intrinsic stubs and writes it to ops_amd64.go and ops_internal_amd64.go
-// within the specified directory.
+// writeSIMDStubs returns two bytes.Buffers containing the declarations for the public
+// and internal-use vector intrinsics.
func writeSIMDStubs(ops []Operation, typeMap simdTypeMap) (f, fI *bytes.Buffer) {
t := templateOf(simdStubsTmpl, "simdStubs")
f = new(bytes.Buffer)
diff --git a/src/simd/_gen/simdgen/gen_simdrules.go b/src/simd/_gen/simdgen/gen_simdrules.go
index 2103678ea9..8dd1707da9 100644
--- a/src/simd/_gen/simdgen/gen_simdrules.go
+++ b/src/simd/_gen/simdgen/gen_simdrules.go
@@ -126,6 +126,9 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
buffer := new(bytes.Buffer)
buffer.WriteString(generatedHeader + "\n")
+ // machine op name -> masked merging rewrite rule
+ maskedMergeOpts := make(map[string]string)
+ s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"}
asmCheck := map[string]bool{}
var allData []tplRuleData
var optData []tplRuleData // for mask peephole optimizations, and other misc
@@ -295,6 +298,33 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
memOpData.tplName = "vregMem"
}
memOptData = append(memOptData, memOpData)
+ asmCheck[memOpData.Asm+"load"] = true
+ }
+ }
+ // Generate the masked merging optimization rules
+ if gOp.hasMaskedMerging(maskType, opOutShape) {
+ // TODO: handle customized operand order and special lower.
+ maskElem := gOp.In[len(gOp.In)-1]
+ if maskElem.Bits == nil {
+ panic("mask has no bits")
+ }
+ if maskElem.ElemBits == nil {
+ panic("mask has no elemBits")
+ }
+ if maskElem.Lanes == nil {
+ panic("mask has no lanes")
+ }
+ switch *maskElem.Bits {
+ case 128, 256:
+ // VPBLENDVB cases.
+ noMaskName := machineOpName(NoMask, gOp)
+ maskedMergeOpts[noMaskName] = fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n",
+ *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes)
+ case 512:
+ // VPBLENDM[BWDQ] cases.
+ noMaskName := machineOpName(NoMask, gOp)
+ maskedMergeOpts[noMaskName] = fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n",
+ s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args)
}
}
@@ -332,6 +362,13 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
}
}
+ for asm, rule := range maskedMergeOpts {
+ if !asmCheck[asm] {
+ continue
+ }
+ buffer.WriteString(rule)
+ }
+
for _, data := range memOptData {
if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err))
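
As a concrete sketch of what the two format strings above emit (the op names
are assumptions, not actual generated rules): for a 512-bit vector the
VPBLENDM[BWDQ] form produces a rule shaped like

	(VPBLENDMQMasked512 dst (VPADDQ512 x y) mask) => (VPADDQMasked512Merging dst x y mask)

while the 128/256-bit VPBLENDVB form guards on AVX512 and converts the vector
mask to a K register:

	(VPBLENDVB256 dst (VPADDQ256 x y) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (VPADDQMasked256Merging dst x y (VPMOVVec64x4ToM <types.TypeMask> mask))

In both cases the blend's destination becomes the merging op's first argument,
matching resultInArg0.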
diff --git a/src/simd/_gen/simdgen/gen_simdssa.go b/src/simd/_gen/simdgen/gen_simdssa.go
index 20cfaabfb8..c9d8693aa1 100644
--- a/src/simd/_gen/simdgen/gen_simdssa.go
+++ b/src/simd/_gen/simdgen/gen_simdssa.go
@@ -99,6 +99,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer {
"v21ResultInArg0",
"v21ResultInArg0Imm8",
"v31x0AtIn2ResultInArg0",
+ "v2kvResultInArg0",
}
regInfoSet := map[string][]string{}
for _, key := range regInfoKeys {
@@ -107,7 +108,8 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer {
seen := map[string]struct{}{}
allUnseen := make(map[string][]Operation)
- classifyOp := func(op Operation, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
+ allUnseenCaseStr := make(map[string][]string)
+ classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
regShape, err := op.regShape(mem)
if err != nil {
return err
@@ -127,8 +129,31 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer {
}
if _, ok := regInfoSet[regShape]; !ok {
allUnseen[regShape] = append(allUnseen[regShape], op)
+ allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
}
regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
+ if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
+ regShapeMerging := regShape
+ if shapeOut != OneVregOutAtIn {
+ // We have to copy the slice here because the in-place sort would otherwise
+ // be visible through other aliases of the backing array.
+ newIn := make([]Operand, len(op.In), len(op.In)+1)
+ copy(newIn, op.In)
+ op.In = newIn
+ op.In = append(op.In, op.Out[0])
+ op.sortOperand()
+ regShapeMerging, err = op.regShape(mem)
+ regShapeMerging += "ResultInArg0"
+ }
+ if err != nil {
+ return err
+ }
+ if _, ok := regInfoSet[regShapeMerging]; !ok {
+ allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
+ allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
+ }
+ regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
+ }
return nil
}
for _, op := range ops {
@@ -146,7 +171,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer {
isZeroMasking = true
}
}
- if err := classifyOp(op, shapeIn, shapeOut, caseStr, NoMem); err != nil {
+ if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
panic(err)
}
if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
@@ -155,7 +180,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer {
// Ignore the error.
// An error could be triggered by [checkVecAsScalar].
// TODO: make [checkVecAsScalar] aware of mem ops.
- if err := classifyOp(op, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
+ if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
if *Verbose {
log.Printf("Seen error: %e", err)
}
@@ -169,7 +194,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer {
for k := range allUnseen {
allKeys = append(allKeys, k)
}
- panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v", allUnseen, allKeys))
+ panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\nCases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
}
buffer := new(bytes.Buffer)
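
For orientation (the case string and register shape here are assumptions): a
masked binary op already lowered under one key now also registers a merging
case under the new ResultInArg0 key, roughly

	regInfoSet["v2kv"]             -> case "VPADDQMasked256"
	regInfoSet["v2kvResultInArg0"] -> case "VPADDQMasked256Merging"

Appending the output operand to the inputs models the destination register
that the merging instruction both reads and writes.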
diff --git a/src/simd/_gen/simdgen/gen_utility.go b/src/simd/_gen/simdgen/gen_utility.go
index 2fb05026c0..c0bc73d5dc 100644
--- a/src/simd/_gen/simdgen/gen_utility.go
+++ b/src/simd/_gen/simdgen/gen_utility.go
@@ -523,10 +523,6 @@ func checkVecAsScalar(op Operation) (idx int, err error) {
}
}
if idx >= 0 {
- if idx != 1 {
- err = fmt.Errorf("simdgen only supports TreatLikeAScalarOfSize at the 2nd arg of the arg list: %s", op)
- return
- }
if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
return
@@ -545,6 +541,10 @@ func rewriteVecAsScalarRegInfo(op Operation, regInfo string) (string, error) {
regInfo = "vfpv"
} else if regInfo == "v2kv" {
regInfo = "vfpkv"
+ } else if regInfo == "v31" {
+ regInfo = "v2fpv"
+ } else if regInfo == "v3kv" {
+ regInfo = "v2fpkv"
} else {
return "", fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op)
}
@@ -807,6 +807,12 @@ func reportXEDInconsistency(ops []Operation) error {
return nil
}
+func (o *Operation) hasMaskedMerging(maskType maskShape, outType outShape) bool {
+ // BLEND and VMOVDQU are not user-facing ops, so we should filter them out.
+ return o.OperandOrder == nil && o.SpecialLower == nil && maskType == OneMask && outType == OneVregOut &&
+ len(o.InVariant) == 1 && !strings.Contains(o.Asm, "BLEND") && !strings.Contains(o.Asm, "VMOVDQU")
+}
+
func getVbcstData(s string) (feat1Match, feat2Match string) {
_, err := fmt.Sscanf(s, "feat1=%[^;];feat2=%s", &feat1Match, &feat2Match)
if err != nil {
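
For intuition, a simplified, hypothetical sketch of the filtering idea behind
hasMaskedMerging (the real predicate also inspects OperandOrder, SpecialLower,
the mask and output shapes, and the in-variant count; userFacing is an invented
helper name):

	package main

	import (
		"fmt"
		"strings"
	)

	// Blends and masked moves implement the Merge API itself, so they
	// must not get merging variants of their own.
	func userFacing(asm string) bool {
		return !strings.Contains(asm, "BLEND") && !strings.Contains(asm, "VMOVDQU")
	}

	func main() {
		fmt.Println(userFacing("VPADDQ"))    // true: gets a Merging variant
		fmt.Println(userFacing("VPBLENDMQ")) // false: filtered out
	}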
diff --git a/src/simd/_gen/simdgen/ops/Moves/go.yaml b/src/simd/_gen/simdgen/ops/Moves/go.yaml
index 08e857c8ea..a1aefd8406 100644
--- a/src/simd/_gen/simdgen/ops/Moves/go.yaml
+++ b/src/simd/_gen/simdgen/ops/Moves/go.yaml
@@ -299,21 +299,6 @@
out:
- *v
- # For AVX512
-- go: move
- asm: VMOVUP[SD]
- zeroing: true
- in:
- - &v
- go: $t
- class: vreg
- base: float
- inVariant:
- -
- class: mask
- out:
- - *v
-
- go: Expand
asm: "VPEXPAND[BWDQ]|VEXPANDP[SD]"
in:
diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go
index c64ac0fcfd..f3492170e9 100644
--- a/src/simd/internal/simd_test/simd_test.go
+++ b/src/simd/internal/simd_test/simd_test.go
@@ -1108,3 +1108,22 @@ func TestSelectTernOptInt32x16(t *testing.T) {
}
foo(t2, applyTo3(x, y, z, ft2))
}
+
+func TestMaskedMerge(t *testing.T) {
+ x := simd.LoadInt64x4Slice([]int64{1, 2, 3, 4})
+ y := simd.LoadInt64x4Slice([]int64{5, 6, 1, 1})
+ z := simd.LoadInt64x4Slice([]int64{-1, -2, -3, -4})
+ res := make([]int64, 4)
+ expected := []int64{6, 8, -3, -4}
+ mask := x.Less(y)
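+	// The same expression appears on both branches: with AVX512 the compiler
+	// is expected to fuse the Add and Merge into one merging-masked op, while
+	// the fallback compiles to a plain blend; either way the results match.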
+ if simd.HasAVX512() {
+ x.Add(y).Merge(z, mask).StoreSlice(res)
+ } else {
+ x.Add(y).Merge(z, mask).StoreSlice(res)
+ }
+ for i := range 4 {
+ if res[i] != expected[i] {
+ t.Errorf("got %d wanted %d", res[i], expected[i])
+ }
+ }
+}
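
For reference, a scalar model of the vector expression above (a sketch of
Merge's lane-select semantics as exercised by the test, not simd package code):

	package main

	import "fmt"

	func main() {
		xs := []int64{1, 2, 3, 4}
		ys := []int64{5, 6, 1, 1}
		zs := []int64{-1, -2, -3, -4}
		want := make([]int64, 4)
		for i := range want {
			if xs[i] < ys[i] { // mask lane set: keep the Add result
				want[i] = xs[i] + ys[i]
			} else { // mask lane clear: take the Merge operand
				want[i] = zs[i]
			}
		}
		fmt.Println(want) // [6 8 -3 -4]
	}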