diff options
| author | Junyang Shao <shaojunyang@google.com> | 2025-11-05 19:25:00 +0000 |
|---|---|---|
| committer | Junyang Shao <shaojunyang@google.com> | 2025-11-11 13:34:39 -0800 |
| commit | 86b4fe31d9b7fe4b249a3a8007290305eaa4f16a (patch) | |
| tree | 23f7b2fed426e41c055c39d864835012e602119c /src/simd | |
| parent | 771a1dc216ff02dd23c78ada35a207a363690d11 (diff) | |
| download | go-86b4fe31d9b7fe4b249a3a8007290305eaa4f16a.tar.xz | |
[dev.simd] cmd/compile: add masked merging ops and optimizations
This CL generates optimizations for masked variant of AVX512
instructions for patterns:
x.Op(y).Merge(z, mask) => OpMasked(z, x, y, mask), where OpMasked is
resultInArg0.
Change-Id: Ife7ccc9ddbf76ae921a085bd6a42b965da9bc179
Reviewed-on: https://go-review.googlesource.com/c/go/+/718160
Reviewed-by: David Chase <drchase@google.com>
TryBot-Bypass: Junyang Shao <shaojunyang@google.com>
Diffstat (limited to 'src/simd')
| -rw-r--r-- | src/simd/_gen/simdgen/gen_simdMachineOps.go | 70 | ||||
| -rw-r--r-- | src/simd/_gen/simdgen/gen_simdTypes.go | 4 | ||||
| -rw-r--r-- | src/simd/_gen/simdgen/gen_simdrules.go | 37 | ||||
| -rw-r--r-- | src/simd/_gen/simdgen/gen_simdssa.go | 33 | ||||
| -rw-r--r-- | src/simd/_gen/simdgen/gen_utility.go | 14 | ||||
| -rw-r--r-- | src/simd/_gen/simdgen/ops/Moves/go.yaml | 15 | ||||
| -rw-r--r-- | src/simd/internal/simd_test/simd_test.go | 19 |
7 files changed, 159 insertions, 33 deletions
diff --git a/src/simd/_gen/simdgen/gen_simdMachineOps.go b/src/simd/_gen/simdgen/gen_simdMachineOps.go index 240227b27d..e8cf792d42 100644 --- a/src/simd/_gen/simdgen/gen_simdMachineOps.go +++ b/src/simd/_gen/simdgen/gen_simdMachineOps.go @@ -30,6 +30,12 @@ func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vf {{- range .OpsDataImmLoad}} {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}}, {{- end}} +{{- range .OpsDataMerging }} + {name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true}, +{{- end }} +{{- range .OpsDataImmMerging }} + {name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true}, +{{- end }} } } ` @@ -51,10 +57,12 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer { ResultInArg0 bool } type machineOpsData struct { - OpsData []opData - OpsDataImm []opData - OpsDataLoad []opData - OpsDataImmLoad []opData + OpsData []opData + OpsDataImm []opData + OpsDataLoad []opData + OpsDataImmLoad []opData + OpsDataMerging []opData + OpsDataImmMerging []opData } regInfoSet := map[string]bool{ @@ -66,6 +74,8 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer { opsDataImm := make([]opData, 0) opsDataLoad := make([]opData, 0) opsDataImmLoad := make([]opData, 0) + opsDataMerging := make([]opData, 0) + opsDataImmMerging := make([]opData, 0) // Determine the "best" version of an instruction to use best := make(map[string]Operation) @@ -98,7 +108,7 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer { regInfoMissing := make(map[string]bool, 0) for _, asm := range mOpOrder { op := best[asm] - shapeIn, shapeOut, _, _, gOp := op.shape() + shapeIn, shapeOut, maskType, _, gOp := op.shape() // 
TODO: all our masked operations are now zeroing, we need to generate machine ops with merging masks, maybe copy // one here with a name suffix "Merging". The rewrite rules will need them. @@ -147,11 +157,13 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer { resultInArg0 = true } var memOpData *opData + regInfoMerging := regInfo + hasMerging := false if op.MemFeatures != nil && *op.MemFeatures == "vbcst" { // Right now we only have vbcst case // Make a full vec memory variant. - op = rewriteLastVregToMem(op) - regInfo, err := makeRegInfo(op, VregMemIn) + opMem := rewriteLastVregToMem(op) + regInfo, err := makeRegInfo(opMem, VregMemIn) if err != nil { // Just skip it if it's non nill. // an error could be triggered by [checkVecAsScalar]. @@ -163,16 +175,51 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer { memOpData = &opData{asm + "load", gOp.Asm, len(gOp.In) + 1, regInfo, false, outType, resultInArg0} } } + hasMerging = gOp.hasMaskedMerging(maskType, shapeOut) + if hasMerging && !resultInArg0 { + // We have to copy the slice here becasue the sort will be visible from other + // aliases when no reslicing is happening. 
+ newIn := make([]Operand, len(op.In), len(op.In)+1) + copy(newIn, op.In) + op.In = newIn + op.In = append(op.In, op.Out[0]) + op.sortOperand() + regInfoMerging, err = makeRegInfo(op, NoMem) + if err != nil { + panic(err) + } + } + if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn { opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0}) if memOpData != nil { + if *op.MemFeatures != "vbcst" { + panic("simdgen only knows vbcst for mem ops for now") + } opsDataImmLoad = append(opsDataImmLoad, *memOpData) } + if hasMerging { + mergingLen := len(gOp.In) + if !resultInArg0 { + mergingLen++ + } + opsDataImmMerging = append(opsDataImmMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0}) + } } else { opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0}) if memOpData != nil { + if *op.MemFeatures != "vbcst" { + panic("simdgen only knows vbcst for mem ops for now") + } opsDataLoad = append(opsDataLoad, *memOpData) } + if hasMerging { + mergingLen := len(gOp.In) + if !resultInArg0 { + mergingLen++ + } + opsDataMerging = append(opsDataMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0}) + } } } if len(regInfoErrs) != 0 { @@ -193,7 +240,14 @@ func writeSIMDMachineOps(ops []Operation) *bytes.Buffer { sort.Slice(opsDataImmLoad, func(i, j int) bool { return compareNatural(opsDataImmLoad[i].OpName, opsDataImmLoad[j].OpName) < 0 }) - err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad}) + sort.Slice(opsDataMerging, func(i, j int) bool { + return compareNatural(opsDataMerging[i].OpName, opsDataMerging[j].OpName) < 0 + }) + sort.Slice(opsDataImmMerging, func(i, j int) bool { + return compareNatural(opsDataImmMerging[i].OpName, opsDataImmMerging[j].OpName) < 0 + }) + err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, 
opsDataImmLoad, + opsDataMerging, opsDataImmMerging}) if err != nil { panic(fmt.Errorf("failed to execute template: %w", err)) } diff --git a/src/simd/_gen/simdgen/gen_simdTypes.go b/src/simd/_gen/simdgen/gen_simdTypes.go index efa3ffabeb..c809fcd1de 100644 --- a/src/simd/_gen/simdgen/gen_simdTypes.go +++ b/src/simd/_gen/simdgen/gen_simdTypes.go @@ -585,8 +585,8 @@ func writeSIMDFeatures(ops []Operation) *bytes.Buffer { return buffer } -// writeSIMDStubs generates the simd vector intrinsic stubs and writes it to ops_amd64.go and ops_internal_amd64.go -// within the specified directory. +// writeSIMDStubs returns two bytes.Buffers containing the declarations for the public +// and internal-use vector intrinsics. func writeSIMDStubs(ops []Operation, typeMap simdTypeMap) (f, fI *bytes.Buffer) { t := templateOf(simdStubsTmpl, "simdStubs") f = new(bytes.Buffer) diff --git a/src/simd/_gen/simdgen/gen_simdrules.go b/src/simd/_gen/simdgen/gen_simdrules.go index 2103678ea9..8dd1707da9 100644 --- a/src/simd/_gen/simdgen/gen_simdrules.go +++ b/src/simd/_gen/simdgen/gen_simdrules.go @@ -126,6 +126,9 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer { buffer := new(bytes.Buffer) buffer.WriteString(generatedHeader + "\n") + // asm -> masked merging rules + maskedMergeOpts := make(map[string]string) + s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"} asmCheck := map[string]bool{} var allData []tplRuleData var optData []tplRuleData // for mask peephole optimizations, and other misc @@ -295,6 +298,33 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer { memOpData.tplName = "vregMem" } memOptData = append(memOptData, memOpData) + asmCheck[memOpData.Asm+"load"] = true + } + } + // Generate the masked merging optimization rules + if gOp.hasMaskedMerging(maskType, opOutShape) { + // TODO: handle customized operand order and special lower. 
+ maskElem := gOp.In[len(gOp.In)-1] + if maskElem.Bits == nil { + panic("mask has no bits") + } + if maskElem.ElemBits == nil { + panic("mask has no elemBits") + } + if maskElem.Lanes == nil { + panic("mask has no lanes") + } + switch *maskElem.Bits { + case 128, 256: + // VPBLENDVB cases. + noMaskName := machineOpName(NoMask, gOp) + maskedMergeOpts[noMaskName] = fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n", + *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes) + case 512: + // VPBLENDM[BWDQ] cases. + noMaskName := machineOpName(NoMask, gOp) + maskedMergeOpts[noMaskName] = fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n", + s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args) } } @@ -332,6 +362,13 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer { } } + for asm, rule := range maskedMergeOpts { + if !asmCheck[asm] { + continue + } + buffer.WriteString(rule) + } + for _, data := range memOptData { if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil { panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err)) diff --git a/src/simd/_gen/simdgen/gen_simdssa.go b/src/simd/_gen/simdgen/gen_simdssa.go index 20cfaabfb8..c9d8693aa1 100644 --- a/src/simd/_gen/simdgen/gen_simdssa.go +++ b/src/simd/_gen/simdgen/gen_simdssa.go @@ -99,6 +99,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer { "v21ResultInArg0", "v21ResultInArg0Imm8", "v31x0AtIn2ResultInArg0", + "v2kvResultInArg0", } regInfoSet := map[string][]string{} for _, key := range regInfoKeys { @@ -107,7 +108,8 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer { seen := map[string]struct{}{} allUnseen := make(map[string][]Operation) - classifyOp := func(op Operation, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error { 
+ allUnseenCaseStr := make(map[string][]string) + classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error { regShape, err := op.regShape(mem) if err != nil { return err @@ -127,8 +129,31 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer { } if _, ok := regInfoSet[regShape]; !ok { allUnseen[regShape] = append(allUnseen[regShape], op) + allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr) } regInfoSet[regShape] = append(regInfoSet[regShape], caseStr) + if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) { + regShapeMerging := regShape + if shapeOut != OneVregOutAtIn { + // We have to copy the slice here becasue the sort will be visible from other + // aliases when no reslicing is happening. + newIn := make([]Operand, len(op.In), len(op.In)+1) + copy(newIn, op.In) + op.In = newIn + op.In = append(op.In, op.Out[0]) + op.sortOperand() + regShapeMerging, err = op.regShape(mem) + regShapeMerging += "ResultInArg0" + } + if err != nil { + return err + } + if _, ok := regInfoSet[regShapeMerging]; !ok { + allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op) + allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging") + } + regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging") + } return nil } for _, op := range ops { @@ -146,7 +171,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer { isZeroMasking = true } } - if err := classifyOp(op, shapeIn, shapeOut, caseStr, NoMem); err != nil { + if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil { panic(err) } if op.MemFeatures != nil && *op.MemFeatures == "vbcst" { @@ -155,7 +180,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer { // Ignore the error // an error could be triggered by [checkVecAsScalar]. // TODO: make [checkVecAsScalar] aware of mem ops. 
- if err := classifyOp(op, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil { + if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil { if *Verbose { log.Printf("Seen error: %e", err) } @@ -169,7 +194,7 @@ func writeSIMDSSA(ops []Operation) *bytes.Buffer { for k := range allUnseen { allKeys = append(allKeys, k) } - panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v", allUnseen, allKeys)) + panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\n, cases: %v\n", allUnseen, allKeys, allUnseenCaseStr)) } buffer := new(bytes.Buffer) diff --git a/src/simd/_gen/simdgen/gen_utility.go b/src/simd/_gen/simdgen/gen_utility.go index 2fb05026c0..c0bc73d5dc 100644 --- a/src/simd/_gen/simdgen/gen_utility.go +++ b/src/simd/_gen/simdgen/gen_utility.go @@ -523,10 +523,6 @@ func checkVecAsScalar(op Operation) (idx int, err error) { } } if idx >= 0 { - if idx != 1 { - err = fmt.Errorf("simdgen only supports TreatLikeAScalarOfSize at the 2nd arg of the arg list: %s", op) - return - } if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 { err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op) return @@ -545,6 +541,10 @@ func rewriteVecAsScalarRegInfo(op Operation, regInfo string) (string, error) { regInfo = "vfpv" } else if regInfo == "v2kv" { regInfo = "vfpkv" + } else if regInfo == "v31" { + regInfo = "v2fpv" + } else if regInfo == "v3kv" { + regInfo = "v2fpkv" } else { return "", fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op) } @@ -807,6 +807,12 @@ func reportXEDInconsistency(ops []Operation) error { return nil } +func (o *Operation) hasMaskedMerging(maskType maskShape, outType outShape) bool { + // BLEND and VMOVDQU are not user-facing ops so we should filter them out. 
+ return o.OperandOrder == nil && o.SpecialLower == nil && maskType == OneMask && outType == OneVregOut && + len(o.InVariant) == 1 && !strings.Contains(o.Asm, "BLEND") && !strings.Contains(o.Asm, "VMOVDQU") +} + func getVbcstData(s string) (feat1Match, feat2Match string) { _, err := fmt.Sscanf(s, "feat1=%[^;];feat2=%s", &feat1Match, &feat2Match) if err != nil { diff --git a/src/simd/_gen/simdgen/ops/Moves/go.yaml b/src/simd/_gen/simdgen/ops/Moves/go.yaml index 08e857c8ea..a1aefd8406 100644 --- a/src/simd/_gen/simdgen/ops/Moves/go.yaml +++ b/src/simd/_gen/simdgen/ops/Moves/go.yaml @@ -299,21 +299,6 @@ out: - *v - # For AVX512 -- go: move - asm: VMOVUP[SD] - zeroing: true - in: - - &v - go: $t - class: vreg - base: float - inVariant: - - - class: mask - out: - - *v - - go: Expand asm: "VPEXPAND[BWDQ]|VEXPANDP[SD]" in: diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go index c64ac0fcfd..f3492170e9 100644 --- a/src/simd/internal/simd_test/simd_test.go +++ b/src/simd/internal/simd_test/simd_test.go @@ -1108,3 +1108,22 @@ func TestSelectTernOptInt32x16(t *testing.T) { } foo(t2, applyTo3(x, y, z, ft2)) } + +func TestMaskedMerge(t *testing.T) { + x := simd.LoadInt64x4Slice([]int64{1, 2, 3, 4}) + y := simd.LoadInt64x4Slice([]int64{5, 6, 1, 1}) + z := simd.LoadInt64x4Slice([]int64{-1, -2, -3, -4}) + res := make([]int64, 4) + expected := []int64{6, 8, -3, -4} + mask := x.Less(y) + if simd.HasAVX512() { + x.Add(y).Merge(z, mask).StoreSlice(res) + } else { + x.Add(y).Merge(z, mask).StoreSlice(res) + } + for i := range 4 { + if res[i] != expected[i] { + t.Errorf("got %d wanted %d", res[i], expected[i]) + } + } +} |
