aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMeng Zhuo <mzh@golangcn.org>2023-06-28 16:45:07 +0800
committerM Zhuo <mzh@golangcn.org>2023-08-22 12:05:36 +0000
commit63ab68ddc5f1307e552cf27ae7a6f0dfda2bb962 (patch)
tree348c81d74ae80a5da5f118c15cb211b561759b94 /src
parent05f951158278da91a67a2f6380ffbf0c9172f565 (diff)
downloadgo-63ab68ddc5f1307e552cf27ae7a6f0dfda2bb962.tar.xz
cmd/compile: add single-precision FMA code generation for riscv64
This CL adds FMADDS,FMSUBS,FNMADDS,FNMSUBS SSA support for riscv Change-Id: I1e7dd322b46b9e0f4923dbba256303d69ed12066 Reviewed-on: https://go-review.googlesource.com/c/go/+/506616 Reviewed-by: Joel Sing <joel@sing.id.au> Reviewed-by: David Chase <drchase@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Keith Randall <khr@google.com> Run-TryBot: M Zhuo <mzh@golangcn.org>
Diffstat (limited to 'src')
-rw-r--r--src/cmd/compile/internal/riscv64/ssa.go3
-rw-r--r--src/cmd/compile/internal/ssa/_gen/RISCV64.rules9
-rw-r--r--src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go4
-rw-r--r--src/cmd/compile/internal/ssa/opGen.go68
-rw-r--r--src/cmd/compile/internal/ssa/rewriteRISCV64.go256
5 files changed, 336 insertions, 4 deletions
diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go
index 143e7c525a..f8cf786920 100644
--- a/src/cmd/compile/internal/riscv64/ssa.go
+++ b/src/cmd/compile/internal/riscv64/ssa.go
@@ -332,7 +332,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
p2.From.Reg = v.Reg1()
p2.To.Type = obj.TYPE_REG
p2.To.Reg = v.Reg1()
- case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD:
+ case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD,
+ ssa.OpRISCV64FMADDS, ssa.OpRISCV64FMSUBS, ssa.OpRISCV64FNMADDS, ssa.OpRISCV64FNMSUBS:
r := v.Reg()
r1 := v.Args[0].Reg()
r2 := v.Args[1].Reg()
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index ac68dfed76..e0bf00d45d 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -780,9 +780,10 @@
(Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y)
(Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y)
-(FADDD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FMADDD x y a)
-(FSUBD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FNMSUBD x y a)
-(FSUBD (FMULD x y) a) && a.Block.Func.useFMA(v) => (FMSUBD x y a)
+(FADD(S|D) a (FMUL(S|D) x y)) && a.Block.Func.useFMA(v) => (FMADD(S|D) x y a)
+(FSUB(S|D) a (FMUL(S|D) x y)) && a.Block.Func.useFMA(v) => (FNMSUB(S|D) x y a)
+(FSUB(S|D) (FMUL(S|D) x y) a) && a.Block.Func.useFMA(v) => (FMSUB(S|D) x y a)
+
// Merge negation into fused multiply-add and multiply-subtract.
//
// Key:
@@ -793,5 +794,7 @@
// D B
//
// Note: multiplication commutativity handled by rule generator.
+(F(MADD|NMADD|MSUB|NMSUB)S neg:(FNEGS x) y z) && neg.Uses == 1 => (F(NMSUB|MSUB|NMADD|MADD)S x y z)
+(F(MADD|NMADD|MSUB|NMSUB)S x y neg:(FNEGS z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)S x y z)
(F(MADD|NMADD|MSUB|NMSUB)D neg:(FNEGD x) y z) && neg.Uses == 1 => (F(NMSUB|MSUB|NMADD|MADD)D x y z)
(F(MADD|NMADD|MSUB|NMSUB)D x y neg:(FNEGD z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)D x y z)
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
index 69f2950a88..317e9150c9 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
@@ -411,6 +411,10 @@ func init() {
{name: "FSUBS", argLength: 2, reg: fp21, asm: "FSUBS", commutative: false, typ: "Float32"}, // arg0 - arg1
{name: "FMULS", argLength: 2, reg: fp21, asm: "FMULS", commutative: true, typ: "Float32"}, // arg0 * arg1
{name: "FDIVS", argLength: 2, reg: fp21, asm: "FDIVS", commutative: false, typ: "Float32"}, // arg0 / arg1
+ {name: "FMADDS", argLength: 3, reg: fp31, asm: "FMADDS", commutative: true, typ: "Float32"}, // (arg0 * arg1) + arg2
+ {name: "FMSUBS", argLength: 3, reg: fp31, asm: "FMSUBS", commutative: true, typ: "Float32"}, // (arg0 * arg1) - arg2
+ {name: "FNMADDS", argLength: 3, reg: fp31, asm: "FNMADDS", commutative: true, typ: "Float32"}, // -(arg0 * arg1) + arg2
+ {name: "FNMSUBS", argLength: 3, reg: fp31, asm: "FNMSUBS", commutative: true, typ: "Float32"}, // -(arg0 * arg1) - arg2
{name: "FSQRTS", argLength: 1, reg: fp11, asm: "FSQRTS", typ: "Float32"}, // sqrt(arg0)
{name: "FNEGS", argLength: 1, reg: fp11, asm: "FNEGS", typ: "Float32"}, // -arg0
{name: "FMVSX", argLength: 1, reg: gpfp, asm: "FMVSX", typ: "Float32"}, // reinterpret arg0 as float
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index 12d8214ae1..11a6138357 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2436,6 +2436,10 @@ const (
OpRISCV64FSUBS
OpRISCV64FMULS
OpRISCV64FDIVS
+ OpRISCV64FMADDS
+ OpRISCV64FMSUBS
+ OpRISCV64FNMADDS
+ OpRISCV64FNMSUBS
OpRISCV64FSQRTS
OpRISCV64FNEGS
OpRISCV64FMVSX
@@ -32674,6 +32678,70 @@ var opcodeTable = [...]opInfo{
},
},
{
+ name: "FMADDS",
+ argLen: 3,
+ commutative: true,
+ asm: riscv.AFMADDS,
+ reg: regInfo{
+ inputs: []inputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ outputs: []outputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ },
+ },
+ {
+ name: "FMSUBS",
+ argLen: 3,
+ commutative: true,
+ asm: riscv.AFMSUBS,
+ reg: regInfo{
+ inputs: []inputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ outputs: []outputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ },
+ },
+ {
+ name: "FNMADDS",
+ argLen: 3,
+ commutative: true,
+ asm: riscv.AFNMADDS,
+ reg: regInfo{
+ inputs: []inputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ outputs: []outputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ },
+ },
+ {
+ name: "FNMSUBS",
+ argLen: 3,
+ commutative: true,
+ asm: riscv.AFNMSUBS,
+ reg: regInfo{
+ inputs: []inputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ outputs: []outputInfo{
+ {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+ },
+ },
+ },
+ {
name: "FSQRTS",
argLen: 1,
asm: riscv.AFSQRTS,
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 17af023db3..0ad6433bf4 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -442,16 +442,28 @@ func rewriteValueRISCV64(v *Value) bool {
return rewriteValueRISCV64_OpRISCV64ANDI(v)
case OpRISCV64FADDD:
return rewriteValueRISCV64_OpRISCV64FADDD(v)
+ case OpRISCV64FADDS:
+ return rewriteValueRISCV64_OpRISCV64FADDS(v)
case OpRISCV64FMADDD:
return rewriteValueRISCV64_OpRISCV64FMADDD(v)
+ case OpRISCV64FMADDS:
+ return rewriteValueRISCV64_OpRISCV64FMADDS(v)
case OpRISCV64FMSUBD:
return rewriteValueRISCV64_OpRISCV64FMSUBD(v)
+ case OpRISCV64FMSUBS:
+ return rewriteValueRISCV64_OpRISCV64FMSUBS(v)
case OpRISCV64FNMADDD:
return rewriteValueRISCV64_OpRISCV64FNMADDD(v)
+ case OpRISCV64FNMADDS:
+ return rewriteValueRISCV64_OpRISCV64FNMADDS(v)
case OpRISCV64FNMSUBD:
return rewriteValueRISCV64_OpRISCV64FNMSUBD(v)
+ case OpRISCV64FNMSUBS:
+ return rewriteValueRISCV64_OpRISCV64FNMSUBS(v)
case OpRISCV64FSUBD:
return rewriteValueRISCV64_OpRISCV64FSUBD(v)
+ case OpRISCV64FSUBS:
+ return rewriteValueRISCV64_OpRISCV64FSUBS(v)
case OpRISCV64MOVBUload:
return rewriteValueRISCV64_OpRISCV64MOVBUload(v)
case OpRISCV64MOVBUreg:
@@ -3364,6 +3376,31 @@ func rewriteValueRISCV64_OpRISCV64FADDD(v *Value) bool {
}
return false
}
+func rewriteValueRISCV64_OpRISCV64FADDS(v *Value) bool {
+ v_1 := v.Args[1]
+ v_0 := v.Args[0]
+ // match: (FADDS a (FMULS x y))
+ // cond: a.Block.Func.useFMA(v)
+ // result: (FMADDS x y a)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ a := v_0
+ if v_1.Op != OpRISCV64FMULS {
+ continue
+ }
+ y := v_1.Args[1]
+ x := v_1.Args[0]
+ if !(a.Block.Func.useFMA(v)) {
+ continue
+ }
+ v.reset(OpRISCV64FMADDS)
+ v.AddArg3(x, y, a)
+ return true
+ }
+ break
+ }
+ return false
+}
func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
v_2 := v.Args[2]
v_1 := v.Args[1]
@@ -3409,6 +3446,51 @@ func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
}
return false
}
+func rewriteValueRISCV64_OpRISCV64FMADDS(v *Value) bool {
+ v_2 := v.Args[2]
+ v_1 := v.Args[1]
+ v_0 := v.Args[0]
+ // match: (FMADDS neg:(FNEGS x) y z)
+ // cond: neg.Uses == 1
+ // result: (FNMSUBS x y z)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ neg := v_0
+ if neg.Op != OpRISCV64FNEGS {
+ continue
+ }
+ x := neg.Args[0]
+ y := v_1
+ z := v_2
+ if !(neg.Uses == 1) {
+ continue
+ }
+ v.reset(OpRISCV64FNMSUBS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ break
+ }
+ // match: (FMADDS x y neg:(FNEGS z))
+ // cond: neg.Uses == 1
+ // result: (FMSUBS x y z)
+ for {
+ x := v_0
+ y := v_1
+ neg := v_2
+ if neg.Op != OpRISCV64FNEGS {
+ break
+ }
+ z := neg.Args[0]
+ if !(neg.Uses == 1) {
+ break
+ }
+ v.reset(OpRISCV64FMSUBS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ return false
+}
func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool {
v_2 := v.Args[2]
v_1 := v.Args[1]
@@ -3454,6 +3536,51 @@ func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool {
}
return false
}
+func rewriteValueRISCV64_OpRISCV64FMSUBS(v *Value) bool {
+ v_2 := v.Args[2]
+ v_1 := v.Args[1]
+ v_0 := v.Args[0]
+ // match: (FMSUBS neg:(FNEGS x) y z)
+ // cond: neg.Uses == 1
+ // result: (FNMADDS x y z)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ neg := v_0
+ if neg.Op != OpRISCV64FNEGS {
+ continue
+ }
+ x := neg.Args[0]
+ y := v_1
+ z := v_2
+ if !(neg.Uses == 1) {
+ continue
+ }
+ v.reset(OpRISCV64FNMADDS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ break
+ }
+ // match: (FMSUBS x y neg:(FNEGS z))
+ // cond: neg.Uses == 1
+ // result: (FMADDS x y z)
+ for {
+ x := v_0
+ y := v_1
+ neg := v_2
+ if neg.Op != OpRISCV64FNEGS {
+ break
+ }
+ z := neg.Args[0]
+ if !(neg.Uses == 1) {
+ break
+ }
+ v.reset(OpRISCV64FMADDS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ return false
+}
func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool {
v_2 := v.Args[2]
v_1 := v.Args[1]
@@ -3499,6 +3626,51 @@ func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool {
}
return false
}
+func rewriteValueRISCV64_OpRISCV64FNMADDS(v *Value) bool {
+ v_2 := v.Args[2]
+ v_1 := v.Args[1]
+ v_0 := v.Args[0]
+ // match: (FNMADDS neg:(FNEGS x) y z)
+ // cond: neg.Uses == 1
+ // result: (FMSUBS x y z)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ neg := v_0
+ if neg.Op != OpRISCV64FNEGS {
+ continue
+ }
+ x := neg.Args[0]
+ y := v_1
+ z := v_2
+ if !(neg.Uses == 1) {
+ continue
+ }
+ v.reset(OpRISCV64FMSUBS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ break
+ }
+ // match: (FNMADDS x y neg:(FNEGS z))
+ // cond: neg.Uses == 1
+ // result: (FNMSUBS x y z)
+ for {
+ x := v_0
+ y := v_1
+ neg := v_2
+ if neg.Op != OpRISCV64FNEGS {
+ break
+ }
+ z := neg.Args[0]
+ if !(neg.Uses == 1) {
+ break
+ }
+ v.reset(OpRISCV64FNMSUBS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ return false
+}
func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool {
v_2 := v.Args[2]
v_1 := v.Args[1]
@@ -3544,6 +3716,51 @@ func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool {
}
return false
}
+func rewriteValueRISCV64_OpRISCV64FNMSUBS(v *Value) bool {
+ v_2 := v.Args[2]
+ v_1 := v.Args[1]
+ v_0 := v.Args[0]
+ // match: (FNMSUBS neg:(FNEGS x) y z)
+ // cond: neg.Uses == 1
+ // result: (FMADDS x y z)
+ for {
+ for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+ neg := v_0
+ if neg.Op != OpRISCV64FNEGS {
+ continue
+ }
+ x := neg.Args[0]
+ y := v_1
+ z := v_2
+ if !(neg.Uses == 1) {
+ continue
+ }
+ v.reset(OpRISCV64FMADDS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ break
+ }
+ // match: (FNMSUBS x y neg:(FNEGS z))
+ // cond: neg.Uses == 1
+ // result: (FNMADDS x y z)
+ for {
+ x := v_0
+ y := v_1
+ neg := v_2
+ if neg.Op != OpRISCV64FNEGS {
+ break
+ }
+ z := neg.Args[0]
+ if !(neg.Uses == 1) {
+ break
+ }
+ v.reset(OpRISCV64FNMADDS)
+ v.AddArg3(x, y, z)
+ return true
+ }
+ return false
+}
func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]
@@ -3583,6 +3800,45 @@ func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool {
}
return false
}
+func rewriteValueRISCV64_OpRISCV64FSUBS(v *Value) bool {
+ v_1 := v.Args[1]
+ v_0 := v.Args[0]
+ // match: (FSUBS a (FMULS x y))
+ // cond: a.Block.Func.useFMA(v)
+ // result: (FNMSUBS x y a)
+ for {
+ a := v_0
+ if v_1.Op != OpRISCV64FMULS {
+ break
+ }
+ y := v_1.Args[1]
+ x := v_1.Args[0]
+ if !(a.Block.Func.useFMA(v)) {
+ break
+ }
+ v.reset(OpRISCV64FNMSUBS)
+ v.AddArg3(x, y, a)
+ return true
+ }
+ // match: (FSUBS (FMULS x y) a)
+ // cond: a.Block.Func.useFMA(v)
+ // result: (FMSUBS x y a)
+ for {
+ if v_0.Op != OpRISCV64FMULS {
+ break
+ }
+ y := v_0.Args[1]
+ x := v_0.Args[0]
+ a := v_1
+ if !(a.Block.Func.useFMA(v)) {
+ break
+ }
+ v.reset(OpRISCV64FMSUBS)
+ v.AddArg3(x, y, a)
+ return true
+ }
+ return false
+}
func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]