diff mbox series

[x86] Support sdot_prodv*qi with emulation of sdot_prodv*hi.

Message ID 20231129024007.493958-1-hongtao.liu@intel.com
State New
Headers show
Series [x86] Support sdot_prodv*qi with emulation of sdot_prodv*hi. | expand

Commit Message

Liu, Hongtao Nov. 29, 2023, 2:40 a.m. UTC
Currently sdot_prodv*qi is available under TARGET_AVXVNNIINT8, but it
can be emulated by

 vec_unpacks_lo_v32qi
 vec_unpacks_lo_v32qi
 vec_unpacks_hi_v32qi
 vec_unpacks_hi_v32qi
 sdot_prodv16hi
 sdot_prodv16hi
 add3v8si

which is faster than original

  vect_patt_39.11_48 = WIDEN_MULT_LO_EXPR <vect__3.7_44, vect__7.10_47>;
  vect_patt_39.11_49 = WIDEN_MULT_HI_EXPR <vect__3.7_44, vect__7.10_47>;
  vect_patt_38.14_54 = [vec_unpack_lo_expr] vect_patt_39.11_48;
  vect_patt_38.14_55 = [vec_unpack_hi_expr] vect_patt_39.11_48;
  vect_patt_38.14_56 = [vec_unpack_lo_expr] vect_patt_39.11_49;
  vect_patt_38.14_57 = [vec_unpack_hi_expr] vect_patt_39.11_49;
  vect_sum_15.15_59 = vect_patt_38.14_54 + vect_patt_38.14_55;
  vect_sum_15.15_60 = vect_patt_38.14_56 + vect_sum_15.15_59;
  vect_sum_15.15_61 = vect_patt_38.14_57 + vect_sum_15.15_60;

Bootstrapped and regtested on x86_64-pc-linux-gnu{-m32,}.
Ready push to trunk.

gcc/ChangeLog:

	* config/i386/sse.md (sdot_prodv64qi): New expander.
	(sseunpackmodelower): New mode attr.
	(sdot_prod<mode>): Emulate sdot_prodv*qi with sodt_prov*hi
	when TARGET_VNNIINT8 is not available.

gcc/testsuite/ChangeLog:

	* gcc.target/i386/sdotprodint8_emulate.c: New test.
---
 gcc/config/i386/sse.md                        | 87 ++++++++++++++++---
 .../gcc.target/i386/sdotprodint8_emulate.c    | 15 ++++
 2 files changed, 90 insertions(+), 12 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/i386/sdotprodint8_emulate.c
diff mbox series

Patch

diff --git a/gcc/config/i386/sse.md b/gcc/config/i386/sse.md
index f94a77d0b6d..e29311d83cc 100644
--- a/gcc/config/i386/sse.md
+++ b/gcc/config/i386/sse.md
@@ -1291,6 +1291,11 @@  (define_mode_attr sseunpackmode
    (V32QI "V16HI") (V16HI "V8SI") (V8SI "V4DI")
    (V32HI "V16SI") (V64QI "V32HI") (V16SI "V8DI")])
 
+(define_mode_attr sseunpackmodelower
+  [(V16QI "v8hi") (V8HI "v4si") (V4SI "v2di")
+   (V32QI "v16hi") (V16HI "v8si") (V8SI "v4di")
+   (V32HI "v16si") (V64QI "v32hi") (V16SI "v8di")])
+
 (define_mode_attr ssepackmode
   [(V8HI "V16QI") (V4SI "V8HI") (V2DI "V4SI")
    (V16HI "V32QI") (V8SI "V16HI") (V4DI "V8SI")
@@ -30742,20 +30747,78 @@  (define_int_attr vpdotprodtype
 
 (define_expand "sdot_prod<mode>"
   [(match_operand:<ssedvecmode> 0 "register_operand")
-   (match_operand:VI1 1 "register_operand")
-   (match_operand:VI1 2 "register_operand")
+   (match_operand:VI1_AVX2 1 "register_operand")
+   (match_operand:VI1_AVX2 2 "register_operand")
    (match_operand:<ssedvecmode> 3 "register_operand")]
-  "TARGET_AVXVNNIINT8"
+  "TARGET_SSE2"
 {
-  operands[1] = lowpart_subreg (<ssedvecmode>mode,
-                                force_reg (<MODE>mode, operands[1]),
-                                <MODE>mode);
-  operands[2] = lowpart_subreg (<ssedvecmode>mode,
-                                force_reg (<MODE>mode, operands[2]),
-                                <MODE>mode);
-  emit_insn (gen_rtx_SET (operands[0], operands[3]));
-  emit_insn (gen_vpdpbssd_<ssedvecmodelower> (operands[0], operands[3],
-				   operands[1], operands[2]));
+  if (TARGET_AVXVNNIINT8)
+    {
+      operands[1] = lowpart_subreg (<ssedvecmode>mode,
+				    force_reg (<MODE>mode, operands[1]),
+				    <MODE>mode);
+      operands[2] = lowpart_subreg (<ssedvecmode>mode,
+				    force_reg (<MODE>mode, operands[2]),
+				    <MODE>mode);
+      emit_insn (gen_rtx_SET (operands[0], operands[3]));
+      emit_insn (gen_vpdpbssd_<ssedvecmodelower> (operands[0], operands[3],
+						  operands[1], operands[2]));
+    }
+  else
+    {
+      /* Emulate with vpdpwssd.  */
+      rtx op1_lo = gen_reg_rtx (<sseunpackmode>mode);
+      rtx op1_hi = gen_reg_rtx (<sseunpackmode>mode);
+      rtx op2_lo = gen_reg_rtx (<sseunpackmode>mode);
+      rtx op2_hi = gen_reg_rtx (<sseunpackmode>mode);
+
+      emit_insn (gen_vec_unpacks_lo_<mode> (op1_lo, operands[1]));
+      emit_insn (gen_vec_unpacks_lo_<mode> (op2_lo, operands[2]));
+      emit_insn (gen_vec_unpacks_hi_<mode> (op1_hi, operands[1]));
+      emit_insn (gen_vec_unpacks_hi_<mode> (op2_hi, operands[2]));
+
+      rtx res1 = gen_reg_rtx (<ssedvecmode>mode);
+      rtx res2 = gen_reg_rtx (<ssedvecmode>mode);
+      rtx sum = gen_reg_rtx (<ssedvecmode>mode);
+
+      emit_move_insn (sum, CONST0_RTX (<ssedvecmode>mode));
+      emit_insn (gen_sdot_prod<sseunpackmodelower> (res1, op1_lo,
+						    op2_lo, sum));
+      emit_insn (gen_sdot_prod<sseunpackmodelower> (res2, op1_hi,
+						    op2_hi, operands[3]));
+      emit_insn (gen_add<ssedvecmodelower>3 (operands[0], res1, res2));
+    }
+
+  DONE;
+})
+
+(define_expand "sdot_prodv64qi"
+  [(match_operand:V16SI 0 "register_operand")
+   (match_operand:V64QI 1 "register_operand")
+   (match_operand:V64QI 2 "register_operand")
+   (match_operand:V16SI 3 "register_operand")]
+  "(TARGET_AVX512VNNI || TARGET_AVX512BW) && TARGET_EVEX512"
+{
+  /* Emulate with vpdpwssd.  */
+  rtx op1_lo = gen_reg_rtx (V32HImode);
+  rtx op1_hi = gen_reg_rtx (V32HImode);
+  rtx op2_lo = gen_reg_rtx (V32HImode);
+  rtx op2_hi = gen_reg_rtx (V32HImode);
+
+  emit_insn (gen_vec_unpacks_lo_v64qi (op1_lo, operands[1]));
+  emit_insn (gen_vec_unpacks_lo_v64qi (op2_lo, operands[2]));
+  emit_insn (gen_vec_unpacks_hi_v64qi (op1_hi, operands[1]));
+  emit_insn (gen_vec_unpacks_hi_v64qi (op2_hi, operands[2]));
+
+  rtx res1 = gen_reg_rtx (V16SImode);
+  rtx res2 = gen_reg_rtx (V16SImode);
+  rtx sum = gen_reg_rtx (V16SImode);
+
+  emit_move_insn (sum, CONST0_RTX (V16SImode));
+  emit_insn (gen_sdot_prodv32hi (res1, op1_lo, op2_lo, sum));
+  emit_insn (gen_sdot_prodv32hi (res2, op1_hi, op2_hi, operands[3]));
+
+  emit_insn (gen_addv16si3 (operands[0], res1, res2));
   DONE;
 })
 
diff --git a/gcc/testsuite/gcc.target/i386/sdotprodint8_emulate.c b/gcc/testsuite/gcc.target/i386/sdotprodint8_emulate.c
new file mode 100644
index 00000000000..ed584606820
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/sdotprodint8_emulate.c
@@ -0,0 +1,15 @@ 
+/* { dg-do compile } */
+/* { dg-options "-mavxvnni -O2 -fdump-tree-optimized" } */
+/* { dg-final { scan-tree-dump-times "DOT_PROD_EXPR" 1 "optimized" } } */
+/* { dg-final { scan-assembler-times "vpdpwssd" 2 } } */
+
+int
+foo (char* a, char* b)
+{
+  int sum = 0;
+  for (int i = 0; i != 16; i++)
+    {
+      sum += a[i] * b[i];
+    }
+  return sum;
+}