diff mbox series

[v2] target/i386: implement FMA instructions

Message ID 20221020134353.50272-1-pbonzini@redhat.com
State New
Headers show
Series [v2] target/i386: implement FMA instructions | expand

Commit Message

Paolo Bonzini Oct. 20, 2022, 1:43 p.m. UTC
The only issue with FMA instructions is that there are _a lot_ of them (30
opcodes, each of which comes in up to 4 versions depending on VEX.W and
VEX.L; a total of 96 possibilities).  However, they can be implement with
only 6 helpers, two for scalar operations and four for packed operations.
(Scalar versions do not do any merging; they only affect the bottom 32
or 64 bits of the output operand.  Therefore, there is no separate XMM
and YMM of the scalar helpers).

First, we can reduce the number of helpers to one third by passing four
operands (one output and three inputs); the reordering of which operands
go to the multiply and which go to the add is done in emit.c.

Second, the different instructions also dispatch to the same softfloat
function, so the flags for float32_muladd and float64_muladd are passed
in the helper as int arguments, with a little extra complication to
handle FMADDSUB and FMSUBADD.

Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
---
 target/i386/cpu.c                |  5 ++--
 target/i386/ops_sse.h            | 27 +++++++++++++++++
 target/i386/ops_sse_header.h     | 11 +++++++
 target/i386/tcg/decode-new.c.inc | 40 +++++++++++++++++++++++++
 target/i386/tcg/decode-new.h     |  1 +
 target/i386/tcg/emit.c.inc       | 51 ++++++++++++++++++++++++++++++++
 target/i386/tcg/translate.c      |  1 +
 tests/tcg/i386/test-avx.py       |  2 +-
 8 files changed, 135 insertions(+), 3 deletions(-)

Comments

Richard Henderson Oct. 20, 2022, 9:22 p.m. UTC | #1
On 10/20/22 23:43, Paolo Bonzini wrote:
> The only issue with FMA instructions is that there are_a lot_  of them (30
> opcodes, each of which comes in up to 4 versions depending on VEX.W and
> VEX.L; a total of 96 possibilities).  However, they can be implement with
> only 6 helpers, two for scalar operations and four for packed operations.
> (Scalar versions do not do any merging; they only affect the bottom 32
> or 64 bits of the output operand.  Therefore, there is no separate XMM
> and YMM of the scalar helpers).
> 
> First, we can reduce the number of helpers to one third by passing four
> operands (one output and three inputs); the reordering of which operands
> go to the multiply and which go to the add is done in emit.c.
> 
> Second, the different instructions also dispatch to the same softfloat
> function, so the flags for float32_muladd and float64_muladd are passed
> in the helper as int arguments, with a little extra complication to
> handle FMADDSUB and FMSUBADD.
> 
> Signed-off-by: Paolo Bonzini<pbonzini@redhat.com>
> ---
>   target/i386/cpu.c                |  5 ++--
>   target/i386/ops_sse.h            | 27 +++++++++++++++++
>   target/i386/ops_sse_header.h     | 11 +++++++
>   target/i386/tcg/decode-new.c.inc | 40 +++++++++++++++++++++++++
>   target/i386/tcg/decode-new.h     |  1 +
>   target/i386/tcg/emit.c.inc       | 51 ++++++++++++++++++++++++++++++++
>   target/i386/tcg/translate.c      |  1 +
>   tests/tcg/i386/test-avx.py       |  2 +-
>   8 files changed, 135 insertions(+), 3 deletions(-)

Reviewed-by: Richard Henderson <richard.henderson@linaro.org>

r~
diff mbox series

Patch

diff --git a/target/i386/cpu.c b/target/i386/cpu.c
index 6292b7e12f..22b681ca37 100644
--- a/target/i386/cpu.c
+++ b/target/i386/cpu.c
@@ -625,10 +625,11 @@  void x86_cpu_vendor_words2str(char *dst, uint32_t vendor1,
           CPUID_EXT_SSE41 | CPUID_EXT_SSE42 | CPUID_EXT_POPCNT | \
           CPUID_EXT_XSAVE | /* CPUID_EXT_OSXSAVE is dynamic */   \
           CPUID_EXT_MOVBE | CPUID_EXT_AES | CPUID_EXT_HYPERVISOR | \
-          CPUID_EXT_RDRAND | CPUID_EXT_AVX | CPUID_EXT_F16C)
+          CPUID_EXT_RDRAND | CPUID_EXT_AVX | CPUID_EXT_F16C | \
+          CPUID_EXT_FMA)
           /* missing:
           CPUID_EXT_DTES64, CPUID_EXT_DSCPL, CPUID_EXT_VMX, CPUID_EXT_SMX,
-          CPUID_EXT_EST, CPUID_EXT_TM2, CPUID_EXT_CID, CPUID_EXT_FMA,
+          CPUID_EXT_EST, CPUID_EXT_TM2, CPUID_EXT_CID,
           CPUID_EXT_XTPR, CPUID_EXT_PDCM, CPUID_EXT_PCID, CPUID_EXT_DCA,
           CPUID_EXT_X2APIC, CPUID_EXT_TSC_DEADLINE_TIMER */
 
diff --git a/target/i386/ops_sse.h b/target/i386/ops_sse.h
index 33c61896ee..3cbc36a59d 100644
--- a/target/i386/ops_sse.h
+++ b/target/i386/ops_sse.h
@@ -2522,6 +2522,33 @@  void helper_vpermd_ymm(Reg *d, Reg *v, Reg *s)
 }
 #endif
 
+/* FMA3 op helpers */
+#if SHIFT == 1
+#define SSE_HELPER_FMAS(name, elem, F)                                         \
+    void name(CPUX86State *env, Reg *d, Reg *a, Reg *b, Reg *c, int flags)     \
+    {                                                                          \
+        d->elem(0) = F(a->elem(0), b->elem(0), c->elem(0), flags, &env->sse_status); \
+    }
+#define SSE_HELPER_FMAP(name, elem, num, F)                                    \
+    void glue(name, SUFFIX)(CPUX86State *env, Reg *d, Reg *a, Reg *b, Reg *c,  \
+                            int flags, int flip)                               \
+    {                                                                          \
+        int i;                                                                 \
+        for (i = 0; i < num; i++) {                                            \
+            d->elem(i) = F(a->elem(i), b->elem(i), c->elem(i), flags, &env->sse_status); \
+            flags ^= flip;                                                     \
+        }                                                                      \
+    }
+
+SSE_HELPER_FMAS(helper_fma4ss,  ZMM_S, float32_muladd)
+SSE_HELPER_FMAS(helper_fma4sd,  ZMM_D, float64_muladd)
+#endif
+
+#if SHIFT >= 1
+SSE_HELPER_FMAP(helper_fma4ps,  ZMM_S, 2 << SHIFT, float32_muladd)
+SSE_HELPER_FMAP(helper_fma4pd,  ZMM_D, 1 << SHIFT, float64_muladd)
+#endif
+
 #undef SSE_HELPER_S
 
 #undef LANE_WIDTH
diff --git a/target/i386/ops_sse_header.h b/target/i386/ops_sse_header.h
index c4c41976c0..8a7b2f4e2f 100644
--- a/target/i386/ops_sse_header.h
+++ b/target/i386/ops_sse_header.h
@@ -359,6 +359,17 @@  DEF_HELPER_3(glue(cvtph2ps, SUFFIX), void, env, Reg, Reg)
 DEF_HELPER_4(glue(cvtps2ph, SUFFIX), void, env, Reg, Reg, int)
 #endif
 
+/* FMA3 helpers */
+#if SHIFT == 1
+DEF_HELPER_6(fma4ss, void, env, Reg, Reg, Reg, Reg, int)
+DEF_HELPER_6(fma4sd, void, env, Reg, Reg, Reg, Reg, int)
+#endif
+
+#if SHIFT >= 1
+DEF_HELPER_7(glue(fma4ps, SUFFIX), void, env, Reg, Reg, Reg, Reg, int, int)
+DEF_HELPER_7(glue(fma4pd, SUFFIX), void, env, Reg, Reg, Reg, Reg, int, int)
+#endif
+
 /* AVX helpers */
 #if SHIFT >= 1
 DEF_HELPER_4(glue(vpermilpd, SUFFIX), void, env, Reg, Reg, Reg)
diff --git a/target/i386/tcg/decode-new.c.inc b/target/i386/tcg/decode-new.c.inc
index 8baee9018a..e4878b967f 100644
--- a/target/i386/tcg/decode-new.c.inc
+++ b/target/i386/tcg/decode-new.c.inc
@@ -376,6 +376,16 @@  static const X86OpEntry opcodes_0F38_00toEF[240] = {
     [0x92] = X86_OP_ENTRY3(VPGATHERD, V,x,  H,x,  M,d,  vex12 cpuid(AVX2) p_66), /* vgatherdps/d */
     [0x93] = X86_OP_ENTRY3(VPGATHERQ, V,x,  H,x,  M,q,  vex12 cpuid(AVX2) p_66), /* vgatherqps/d */
 
+    /* Should be exception type 2 but they do not have legacy SSE equivalents? */
+    [0x96] = X86_OP_ENTRY3(VFMADDSUB132Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x97] = X86_OP_ENTRY3(VFMSUBADD132Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+
+    [0xa6] = X86_OP_ENTRY3(VFMADDSUB213Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xa7] = X86_OP_ENTRY3(VFMSUBADD213Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+
+    [0xb6] = X86_OP_ENTRY3(VFMADDSUB231Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xb7] = X86_OP_ENTRY3(VFMSUBADD231Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+
     [0x08] = X86_OP_ENTRY3(PSIGNB,    V,x,        H,x,  W,x,  vex4 cpuid(SSSE3) mmx avx2_256 p_00_66),
     [0x09] = X86_OP_ENTRY3(PSIGNW,    V,x,        H,x,  W,x,  vex4 cpuid(SSSE3) mmx avx2_256 p_00_66),
     [0x0a] = X86_OP_ENTRY3(PSIGND,    V,x,        H,x,  W,x,  vex4 cpuid(SSSE3) mmx avx2_256 p_00_66),
@@ -421,6 +431,34 @@  static const X86OpEntry opcodes_0F38_00toEF[240] = {
     [0x8c] = X86_OP_ENTRY3(VPMASKMOV,    V,x,  H,x, WM,x, vex6 cpuid(AVX2) p_66),
     [0x8e] = X86_OP_ENTRY3(VPMASKMOV_st, M,x,  V,x, H,x,  vex6 cpuid(AVX2) p_66),
 
+    /* Should be exception type 2 or 3 but they do not have legacy SSE equivalents? */
+    [0x98] = X86_OP_ENTRY3(VFMADD132Px,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x99] = X86_OP_ENTRY3(VFMADD132Sx,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x9a] = X86_OP_ENTRY3(VFMSUB132Px,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x9b] = X86_OP_ENTRY3(VFMSUB132Sx,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x9c] = X86_OP_ENTRY3(VFNMADD132Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x9d] = X86_OP_ENTRY3(VFNMADD132Sx, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x9e] = X86_OP_ENTRY3(VFNMSUB132Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0x9f] = X86_OP_ENTRY3(VFNMSUB132Sx, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+
+    [0xa8] = X86_OP_ENTRY3(VFMADD213Px,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xa9] = X86_OP_ENTRY3(VFMADD213Sx,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xaa] = X86_OP_ENTRY3(VFMSUB213Px,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xab] = X86_OP_ENTRY3(VFMSUB213Sx,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xac] = X86_OP_ENTRY3(VFNMADD213Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xad] = X86_OP_ENTRY3(VFNMADD213Sx, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xae] = X86_OP_ENTRY3(VFNMSUB213Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xaf] = X86_OP_ENTRY3(VFNMSUB213Sx, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+
+    [0xb8] = X86_OP_ENTRY3(VFMADD231Px,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xb9] = X86_OP_ENTRY3(VFMADD231Sx,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xba] = X86_OP_ENTRY3(VFMSUB231Px,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xbb] = X86_OP_ENTRY3(VFMSUB231Sx,  V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xbc] = X86_OP_ENTRY3(VFNMADD231Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xbd] = X86_OP_ENTRY3(VFNMADD231Sx, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xbe] = X86_OP_ENTRY3(VFNMSUB231Px, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+    [0xbf] = X86_OP_ENTRY3(VFNMSUB231Sx, V,x,  H,x, W,x,  vex6 cpuid(FMA) p_66),
+
     [0xdb] = X86_OP_ENTRY3(VAESIMC,     V,dq, None,None, W,dq, vex4 cpuid(AES) p_66),
     [0xdc] = X86_OP_ENTRY3(VAESENC,     V,x,  H,x,       W,x,  vex4 cpuid(AES) p_66),
     [0xdd] = X86_OP_ENTRY3(VAESENCLAST, V,x,  H,x,       W,x,  vex4 cpuid(AES) p_66),
@@ -1350,6 +1388,8 @@  static bool has_cpuid_feature(DisasContext *s, X86CPUIDFeature cpuid)
         return true;
     case X86_FEAT_F16C:
         return (s->cpuid_ext_features & CPUID_EXT_F16C);
+    case X86_FEAT_FMA:
+        return (s->cpuid_ext_features & CPUID_EXT_FMA);
     case X86_FEAT_MOVBE:
         return (s->cpuid_ext_features & CPUID_EXT_MOVBE);
     case X86_FEAT_PCLMULQDQ:
diff --git a/target/i386/tcg/decode-new.h b/target/i386/tcg/decode-new.h
index 0ef54628ee..cb6b8bcf67 100644
--- a/target/i386/tcg/decode-new.h
+++ b/target/i386/tcg/decode-new.h
@@ -105,6 +105,7 @@  typedef enum X86CPUIDFeature {
     X86_FEAT_BMI1,
     X86_FEAT_BMI2,
     X86_FEAT_F16C,
+    X86_FEAT_FMA,
     X86_FEAT_MOVBE,
     X86_FEAT_PCLMULQDQ,
     X86_FEAT_SSE,
diff --git a/target/i386/tcg/emit.c.inc b/target/i386/tcg/emit.c.inc
index 9334f0939d..7037ff91c6 100644
--- a/target/i386/tcg/emit.c.inc
+++ b/target/i386/tcg/emit.c.inc
@@ -39,6 +39,11 @@  typedef void (*SSEFunc_0_eppt)(TCGv_ptr env, TCGv_ptr reg_a, TCGv_ptr reg_b,
                                TCGv val);
 typedef void (*SSEFunc_0_epppti)(TCGv_ptr env, TCGv_ptr reg_a, TCGv_ptr reg_b,
                                  TCGv_ptr reg_c, TCGv a0, TCGv_i32 scale);
+typedef void (*SSEFunc_0_eppppi)(TCGv_ptr env, TCGv_ptr reg_a, TCGv_ptr reg_b,
+                                  TCGv_ptr reg_c, TCGv_ptr reg_d, TCGv_i32 flags);
+typedef void (*SSEFunc_0_eppppii)(TCGv_ptr env, TCGv_ptr reg_a, TCGv_ptr reg_b,
+                                  TCGv_ptr reg_c, TCGv_ptr reg_d, TCGv_i32 even,
+                                  TCGv_i32 odd);
 
 static inline TCGv_i32 tcg_constant8u_i32(uint8_t val)
 {
@@ -491,6 +496,52 @@  FP_SSE(VMIN, min)
 FP_SSE(VDIV, div)
 FP_SSE(VMAX, max)
 
+#define FMA_SSE_PACKED(uname, ptr0, ptr1, ptr2, even, odd)                         \
+static void gen_##uname##Px(DisasContext *s, CPUX86State *env, X86DecodedInsn *decode) \
+{                                                                                  \
+    SSEFunc_0_eppppii xmm = s->vex_w ? gen_helper_fma4pd_xmm : gen_helper_fma4ps_xmm; \
+    SSEFunc_0_eppppii ymm = s->vex_w ? gen_helper_fma4pd_ymm : gen_helper_fma4ps_ymm; \
+    SSEFunc_0_eppppii fn = s->vex_l ? ymm : xmm;                                   \
+                                                                                   \
+    fn(cpu_env, OP_PTR0, ptr0, ptr1, ptr2,                                         \
+       tcg_constant_i32(even),                                                     \
+       tcg_constant_i32((even) ^ (odd)));                                          \
+}
+
+#define FMA_SSE(uname, ptr0, ptr1, ptr2, flags)                                    \
+FMA_SSE_PACKED(uname, ptr0, ptr1, ptr2, flags, flags)                              \
+static void gen_##uname##Sx(DisasContext *s, CPUX86State *env, X86DecodedInsn *decode) \
+{                                                                                  \
+    SSEFunc_0_eppppi fn = s->vex_w ? gen_helper_fma4sd : gen_helper_fma4ss;        \
+                                                                                   \
+    fn(cpu_env, OP_PTR0, ptr0, ptr1, ptr2,                                         \
+       tcg_constant_i32(flags));                                                   \
+}                                                                                  \
+
+FMA_SSE(VFMADD231,  OP_PTR1, OP_PTR2, OP_PTR0, 0)
+FMA_SSE(VFMADD213,  OP_PTR1, OP_PTR0, OP_PTR2, 0)
+FMA_SSE(VFMADD132,  OP_PTR0, OP_PTR2, OP_PTR1, 0)
+
+FMA_SSE(VFNMADD231, OP_PTR1, OP_PTR2, OP_PTR0, float_muladd_negate_product)
+FMA_SSE(VFNMADD213, OP_PTR1, OP_PTR0, OP_PTR2, float_muladd_negate_product)
+FMA_SSE(VFNMADD132, OP_PTR0, OP_PTR2, OP_PTR1, float_muladd_negate_product)
+
+FMA_SSE(VFMSUB231,  OP_PTR1, OP_PTR2, OP_PTR0, float_muladd_negate_c)
+FMA_SSE(VFMSUB213,  OP_PTR1, OP_PTR0, OP_PTR2, float_muladd_negate_c)
+FMA_SSE(VFMSUB132,  OP_PTR0, OP_PTR2, OP_PTR1, float_muladd_negate_c)
+
+FMA_SSE(VFNMSUB231, OP_PTR1, OP_PTR2, OP_PTR0, float_muladd_negate_c|float_muladd_negate_product)
+FMA_SSE(VFNMSUB213, OP_PTR1, OP_PTR0, OP_PTR2, float_muladd_negate_c|float_muladd_negate_product)
+FMA_SSE(VFNMSUB132, OP_PTR0, OP_PTR2, OP_PTR1, float_muladd_negate_c|float_muladd_negate_product)
+
+FMA_SSE_PACKED(VFMADDSUB231, OP_PTR1, OP_PTR2, OP_PTR0, float_muladd_negate_c, 0)
+FMA_SSE_PACKED(VFMADDSUB213, OP_PTR1, OP_PTR0, OP_PTR2, float_muladd_negate_c, 0)
+FMA_SSE_PACKED(VFMADDSUB132, OP_PTR0, OP_PTR2, OP_PTR1, float_muladd_negate_c, 0)
+
+FMA_SSE_PACKED(VFMSUBADD231, OP_PTR1, OP_PTR2, OP_PTR0, 0, float_muladd_negate_c)
+FMA_SSE_PACKED(VFMSUBADD213, OP_PTR1, OP_PTR0, OP_PTR2, 0, float_muladd_negate_c)
+FMA_SSE_PACKED(VFMSUBADD132, OP_PTR0, OP_PTR2, OP_PTR1, 0, float_muladd_negate_c)
+
 #define FP_UNPACK_SSE(uname, lname)                                                \
 static void gen_##uname(DisasContext *s, CPUX86State *env, X86DecodedInsn *decode) \
 {                                                                                  \
diff --git a/target/i386/tcg/translate.c b/target/i386/tcg/translate.c
index e19d5c1c64..85be2e58c2 100644
--- a/target/i386/tcg/translate.c
+++ b/target/i386/tcg/translate.c
@@ -26,6 +26,7 @@ 
 #include "tcg/tcg-op-gvec.h"
 #include "exec/cpu_ldst.h"
 #include "exec/translator.h"
+#include "fpu/softfloat.h"
 
 #include "exec/helper-proto.h"
 #include "exec/helper-gen.h"
diff --git a/tests/tcg/i386/test-avx.py b/tests/tcg/i386/test-avx.py
index ebb1d99c5e..d9ca00a49e 100755
--- a/tests/tcg/i386/test-avx.py
+++ b/tests/tcg/i386/test-avx.py
@@ -9,7 +9,7 @@ 
 archs = [
     "SSE", "SSE2", "SSE3", "SSSE3", "SSE4_1", "SSE4_2",
     "AES", "AVX", "AVX2", "AES+AVX", "VAES+AVX",
-    "F16C",
+    "F16C", "FMA",
 ]
 
 ignore = set(["FISTTP",