diff options
| author | David Chase <drchase@google.com> | 2025-09-20 16:52:07 -0400 |
|---|---|---|
| committer | David Chase <drchase@google.com> | 2025-09-26 13:11:24 -0700 |
| commit | ea3b2ecd2878a694f9f42011eccb1312feb82bca (patch) | |
| tree | 88b9d9075fadcb19dc676e070232b74d3e6b1d0e /src/cmd/compile | |
| parent | 25c36b95d1523f22d4c46ec237acc03e00540e0a (diff) | |
| download | go-ea3b2ecd2878a694f9f42011eccb1312feb82bca.tar.xz | |
[dev.simd] cmd/compile, simd: add 64-bit select-from-pair methods
These are in the same style as the 32-bit select-from-pair methods,
including the grouped variant. This does not quite capture
the full awesome power of VSHUFPD, which can select
differently in each group; that will be a separate,
more complex method.
Change-Id: I807ddd7c1256103b5b0d7c5d60bd70b185e3aaf0
Reviewed-on: https://go-review.googlesource.com/c/go/+/705695
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Diffstat (limited to 'src/cmd/compile')
| -rw-r--r-- | src/cmd/compile/internal/ssagen/intrinsics.go | 140 |
1 files changed, 102 insertions, 38 deletions
diff --git a/src/cmd/compile/internal/ssagen/intrinsics.go b/src/cmd/compile/internal/ssagen/intrinsics.go index 4c5cd9ef2c..6561cbe9a2 100644 --- a/src/cmd/compile/internal/ssagen/intrinsics.go +++ b/src/cmd/compile/internal/ssagen/intrinsics.go @@ -1632,12 +1632,12 @@ func initIntrinsics(cfg *intrinsicBuildConfig) { addF(simdPackage, "Uint32x8.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) addF(simdPackage, "Uint64x4.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) - sfp := func(method string, hwop ssa.Op, vectype *types.Type) { + sfp4 := func(method string, hwop ssa.Op, vectype *types.Type) { addF("simd", method, func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { x, a, b, c, d, y := args[0], args[1], args[2], args[3], args[4], args[5] if a.Op == ssa.OpConst8 && b.Op == ssa.OpConst8 && c.Op == ssa.OpConst8 && d.Op == ssa.OpConst8 { - return selectFromPair(x, a, b, c, d, y, s, hwop, vectype) + return select4FromPair(x, a, b, c, d, y, s, hwop, vectype) } else { return s.callResult(n, callNormal) } @@ -1645,25 +1645,64 @@ func initIntrinsics(cfg *intrinsicBuildConfig) { sys.AMD64) } - sfp("Int32x4.SelectFromPair", ssa.OpconcatSelectedConstantInt32x4, types.TypeVec128) - sfp("Uint32x4.SelectFromPair", ssa.OpconcatSelectedConstantUint32x4, types.TypeVec128) - sfp("Float32x4.SelectFromPair", ssa.OpconcatSelectedConstantFloat32x4, types.TypeVec128) + sfp4("Int32x4.SelectFromPair", ssa.OpconcatSelectedConstantInt32x4, types.TypeVec128) + sfp4("Uint32x4.SelectFromPair", ssa.OpconcatSelectedConstantUint32x4, types.TypeVec128) + sfp4("Float32x4.SelectFromPair", ssa.OpconcatSelectedConstantFloat32x4, types.TypeVec128) - sfp("Int32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt32x8, types.TypeVec256) - sfp("Uint32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint32x8, types.TypeVec256) - sfp("Float32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat32x8, 
types.TypeVec256) + sfp4("Int32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt32x8, types.TypeVec256) + sfp4("Uint32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint32x8, types.TypeVec256) + sfp4("Float32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat32x8, types.TypeVec256) - sfp("Int32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt32x16, types.TypeVec512) - sfp("Uint32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint32x16, types.TypeVec512) - sfp("Float32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat32x16, types.TypeVec512) + sfp4("Int32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt32x16, types.TypeVec512) + sfp4("Uint32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint32x16, types.TypeVec512) + sfp4("Float32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat32x16, types.TypeVec512) + + sfp2 := func(method string, hwop ssa.Op, vectype *types.Type, cscimm func(i, j uint8) int64) { + addF("simd", method, + func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { + x, a, b, y := args[0], args[1], args[2], args[3] + if a.Op == ssa.OpConst8 && b.Op == ssa.OpConst8 { + return select2FromPair(x, a, b, y, s, hwop, vectype, cscimm) + } else { + return s.callResult(n, callNormal) + } + }, + sys.AMD64) + } + + sfp2("Uint64x2.SelectFromPair", ssa.OpconcatSelectedConstantUint64x2, types.TypeVec128, cscimm2) + sfp2("Int64x2.SelectFromPair", ssa.OpconcatSelectedConstantInt64x2, types.TypeVec128, cscimm2) + sfp2("Float64x2.SelectFromPair", ssa.OpconcatSelectedConstantFloat64x2, types.TypeVec128, cscimm2) + + sfp2("Uint64x4.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint64x4, types.TypeVec256, cscimm2g2) + sfp2("Int64x4.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt64x4, types.TypeVec256, cscimm2g2) + sfp2("Float64x4.SelectFromPairGrouped", 
ssa.OpconcatSelectedConstantGroupedFloat64x4, types.TypeVec256, cscimm2g2) + + sfp2("Uint64x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint64x8, types.TypeVec512, cscimm2g4) + sfp2("Int64x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt64x8, types.TypeVec512, cscimm2g4) + sfp2("Float64x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat64x8, types.TypeVec512, cscimm2g4) } } -func cscimm(a, b, c, d uint8) int64 { +func cscimm4(a, b, c, d uint8) int64 { return se(a + b<<2 + c<<4 + d<<6) } +func cscimm2(a, b uint8) int64 { + return se(a + b<<1) +} + +func cscimm2g2(a, b uint8) int64 { + g := cscimm2(a, b) + return int64(int8(g + g<<2)) +} + +func cscimm2g4(a, b uint8) int64 { + g := cscimm2g2(a, b) + return int64(int8(g + g<<4)) +} + const ( _LLLL = iota _HLLL @@ -1683,7 +1722,32 @@ const ( _HHHH ) -func selectFromPair(x, _a, _b, _c, _d, y *ssa.Value, s *state, op ssa.Op, t *types.Type) *ssa.Value { +const ( + _LL = iota + _HL + _LH + _HH +) + +func select2FromPair(x, _a, _b, y *ssa.Value, s *state, op ssa.Op, t *types.Type, csc func(a, b uint8) int64) *ssa.Value { + a, b := uint8(_a.AuxInt8()), uint8(_b.AuxInt8()) + pattern := (a&2)>>1 + (b & 2) + a, b = a&1, b&1 + + switch pattern { + case _LL: + return s.newValue2I(op, t, csc(a, b), x, x) + case _HH: + return s.newValue2I(op, t, csc(a, b), y, y) + case _LH: + return s.newValue2I(op, t, csc(a, b), x, y) + case _HL: + return s.newValue2I(op, t, csc(a, b), y, x) + } + panic("The preceding switch should have been exhaustive") +} + +func select4FromPair(x, _a, _b, _c, _d, y *ssa.Value, s *state, op ssa.Op, t *types.Type) *ssa.Value { a, b, c, d := uint8(_a.AuxInt8()), uint8(_b.AuxInt8()), uint8(_c.AuxInt8()), uint8(_d.AuxInt8()) pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 @@ -1692,54 +1756,54 @@ func selectFromPair(x, _a, _b, _c, _d, y *ssa.Value, s *state, op ssa.Op, t *typ switch pattern { case _LLLL: // TODO DETECT 0,1,2,3, 0,0,0,0 - return s.newValue2I(op, t, 
cscimm(a, b, c, d), x, x) + return s.newValue2I(op, t, cscimm4(a, b, c, d), x, x) case _HHHH: // TODO DETECT 0,1,2,3, 0,0,0,0 - return s.newValue2I(op, t, cscimm(a, b, c, d), y, y) + return s.newValue2I(op, t, cscimm4(a, b, c, d), y, y) case _LLHH: - return s.newValue2I(op, t, cscimm(a, b, c, d), x, y) + return s.newValue2I(op, t, cscimm4(a, b, c, d), x, y) case _HHLL: - return s.newValue2I(op, t, cscimm(a, b, c, d), y, x) + return s.newValue2I(op, t, cscimm4(a, b, c, d), y, x) case _HLLL: - z := s.newValue2I(op, t, cscimm(a, a, b, b), y, x) - return s.newValue2I(op, t, cscimm(0, 2, c, d), z, x) + z := s.newValue2I(op, t, cscimm4(a, a, b, b), y, x) + return s.newValue2I(op, t, cscimm4(0, 2, c, d), z, x) case _LHLL: - z := s.newValue2I(op, t, cscimm(a, a, b, b), x, y) - return s.newValue2I(op, t, cscimm(0, 2, c, d), z, x) + z := s.newValue2I(op, t, cscimm4(a, a, b, b), x, y) + return s.newValue2I(op, t, cscimm4(0, 2, c, d), z, x) case _HLHH: - z := s.newValue2I(op, t, cscimm(a, a, b, b), y, x) - return s.newValue2I(op, t, cscimm(0, 2, c, d), z, y) + z := s.newValue2I(op, t, cscimm4(a, a, b, b), y, x) + return s.newValue2I(op, t, cscimm4(0, 2, c, d), z, y) case _LHHH: - z := s.newValue2I(op, t, cscimm(a, a, b, b), x, y) - return s.newValue2I(op, t, cscimm(0, 2, c, d), z, y) + z := s.newValue2I(op, t, cscimm4(a, a, b, b), x, y) + return s.newValue2I(op, t, cscimm4(0, 2, c, d), z, y) case _LLLH: - z := s.newValue2I(op, t, cscimm(c, c, d, d), x, y) - return s.newValue2I(op, t, cscimm(a, b, 0, 2), x, z) + z := s.newValue2I(op, t, cscimm4(c, c, d, d), x, y) + return s.newValue2I(op, t, cscimm4(a, b, 0, 2), x, z) case _LLHL: - z := s.newValue2I(op, t, cscimm(c, c, d, d), y, x) - return s.newValue2I(op, t, cscimm(a, b, 0, 2), x, z) + z := s.newValue2I(op, t, cscimm4(c, c, d, d), y, x) + return s.newValue2I(op, t, cscimm4(a, b, 0, 2), x, z) case _HHLH: - z := s.newValue2I(op, t, cscimm(c, c, d, d), x, y) - return s.newValue2I(op, t, cscimm(a, b, 0, 2), y, z) + z := 
s.newValue2I(op, t, cscimm4(c, c, d, d), x, y) + return s.newValue2I(op, t, cscimm4(a, b, 0, 2), y, z) case _HHHL: - z := s.newValue2I(op, t, cscimm(c, c, d, d), y, x) - return s.newValue2I(op, t, cscimm(a, b, 0, 2), y, z) + z := s.newValue2I(op, t, cscimm4(c, c, d, d), y, x) + return s.newValue2I(op, t, cscimm4(a, b, 0, 2), y, z) case _LHLH: - z := s.newValue2I(op, t, cscimm(a, c, b, d), x, y) + z := s.newValue2I(op, t, cscimm4(a, c, b, d), x, y) return s.newValue2I(op, t, se(0b11_01_10_00), z, z) case _HLHL: - z := s.newValue2I(op, t, cscimm(b, d, a, c), x, y) + z := s.newValue2I(op, t, cscimm4(b, d, a, c), x, y) return s.newValue2I(op, t, se(0b01_11_00_10), z, z) case _HLLH: - z := s.newValue2I(op, t, cscimm(b, c, a, d), x, y) + z := s.newValue2I(op, t, cscimm4(b, c, a, d), x, y) return s.newValue2I(op, t, se(0b11_01_00_10), z, z) case _LHHL: - z := s.newValue2I(op, t, cscimm(a, d, b, c), x, y) + z := s.newValue2I(op, t, cscimm4(a, d, b, c), x, y) return s.newValue2I(op, t, se(0b01_11_10_00), z, z) } panic("The preceding switch should have been exhaustive") @@ -1906,7 +1970,7 @@ func opLen2Imm8_II(op ssa.Op, t *types.Type, _ int) func(s *state, n *ir.CallExp return func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { if args[1].Op == ssa.OpConst8 && args[2].Op == ssa.OpConst8 && args[1].AuxInt & ^3 == 0 && args[2].AuxInt & ^3 == 0 { i1, i2 := args[1].AuxInt, args[2].AuxInt - return s.newValue2I(op, t, i1+i2<<4, args[0], args[3]) + return s.newValue2I(op, t, int64(int8(i1+i2<<4)), args[0], args[3]) } four := s.constInt64(types.Types[types.TUINT8], 4) shifted := s.newValue2(ssa.OpLsh8x8, types.Types[types.TUINT8], args[2], four) |
