diff mbox series

Optimize vec_extract for 256/512-bit vector when index exceeds the lower 128 bits.

Message ID 20210908100203.791504-1-hongtao.liu@intel.com
State New
Headers show
Series Optimize vec_extract for 256/512-bit vector when index exceeds the lower 128 bits. | expand

Commit Message

liuhongt Sept. 8, 2021, 10:02 a.m. UTC
Hi:
  As decribed in PR, valign{d,q} can be used for vector extract one element.
For elements located in the lower 128 bits, only one instruction is needed,
so this patch only optimizes elements located above 128 bits.

The optimization is like:

-	vextracti32x8	$0x1, %zmm0, %ymm0
-	vmovd	%xmm0, %eax
+	valignd	$8, %zmm0, %zmm0, %zmm1
+	vmovd	%xmm1, %eax

-	vextracti32x8	$0x1, %zmm0, %ymm0
-	vextracti128	$0x1, %ymm0, %xmm0
-	vpextrd	$3, %xmm0, %eax
+	valignd	$15, %zmm0, %zmm0, %zmm1
+	vmovd	%xmm1, %eax

-	vextractf64x2	$0x1, %ymm0, %xmm0
+	valignq	$2, %ymm0, %ymm0, %ymm0

-	vextractf64x4	$0x1, %zmm0, %ymm0
-	vextractf64x2	$0x1, %ymm0, %xmm0
-	vunpckhpd	%xmm0, %xmm0, %xmm0
+	valignq	$7, %zmm0, %zmm0, %zmm0

  Bootstrapped and regtested on x86_64-linux-gnu{-m32,}.

gcc/ChangeLog:

	PR target/91103
	* config/i386/sse.md (*vec_extract<mode><ssescalarmodelower>_valign):
	New define_insn.

gcc/testsuite/ChangeLog:

	PR target/91103
	* gcc.target/i386/pr91103-1.c: New test.
	* gcc.target/i386/pr91103-2.c: New test.
---
 gcc/config/i386/sse.md                    | 32 +++++++++
 gcc/testsuite/gcc.target/i386/pr91103-1.c | 37 +++++++++++
 gcc/testsuite/gcc.target/i386/pr91103-2.c | 81 +++++++++++++++++++++++
 3 files changed, 150 insertions(+)
 create mode 100644 gcc/testsuite/gcc.target/i386/pr91103-1.c
 create mode 100644 gcc/testsuite/gcc.target/i386/pr91103-2.c
diff mbox series

Patch

diff --git a/gcc/config/i386/sse.md b/gcc/config/i386/sse.md
index 5785e73241c..57c736ff44a 100644
--- a/gcc/config/i386/sse.md
+++ b/gcc/config/i386/sse.md
@@ -232,6 +232,12 @@  (define_mode_iterator V48_AVX512VL
    V16SF (V8SF "TARGET_AVX512VL") (V4SF "TARGET_AVX512VL")
    V8DF  (V4DF "TARGET_AVX512VL") (V2DF "TARGET_AVX512VL")])
 
+(define_mode_iterator V48_256_512_AVX512VL
+  [V16SI (V8SI "TARGET_AVX512VL")
+   V8DI  (V4DI "TARGET_AVX512VL")
+   V16SF (V8SF "TARGET_AVX512VL")
+   V8DF  (V4DF "TARGET_AVX512VL")])
+
 ;; 1,2 byte AVX-512{BW,VL} vector modes. Supposed TARGET_AVX512BW baseline.
 (define_mode_iterator VI12_AVX512VL
   [V64QI (V16QI "TARGET_AVX512VL") (V32QI "TARGET_AVX512VL")
@@ -786,6 +792,15 @@  (define_mode_attr sseinsnmode
    (V4SF "V4SF") (V2DF "V2DF")
    (TI "TI")])
 
+(define_mode_attr sseintvecinsnmode
+  [(V64QI "XI") (V32HI "XI") (V16SI "XI") (V8DI "XI") (V4TI "XI")
+   (V32QI "OI") (V16HI "OI") (V8SI "OI") (V4DI "OI") (V2TI "OI")
+   (V16QI "TI") (V8HI "TI") (V4SI "TI") (V2DI "TI") (V1TI "TI")
+   (V16SF "XI") (V8DF "XI")
+   (V8SF "OI") (V4DF "OI")
+   (V4SF "TI") (V2DF "TI")
+   (TI "TI")])
+
 ;; SSE constant -1 constraint
 (define_mode_attr sseconstm1
   [(V64QI "BC") (V32HI "BC") (V16SI "BC") (V8DI "BC") (V4TI "BC")
@@ -10326,6 +10341,23 @@  (define_insn "<mask_codefor><avx512>_align<mode><mask_name>"
   [(set_attr "prefix" "evex")
    (set_attr "mode" "<sseinsnmode>")])
 
+(define_mode_attr vec_extract_imm_predicate
+  [(V16SF "const_0_to_15_operand") (V8SF "const_0_to_7_operand")
+   (V16SI "const_0_to_15_operand") (V8SI "const_0_to_7_operand")
+   (V8DF "const_0_to_7_operand") (V4DF "const_0_to_3_operand")
+   (V8DI "const_0_to_7_operand") (V4DI "const_0_to_3_operand")])
+
+(define_insn "*vec_extract<mode><ssescalarmodelower>_valign"
+  [(set (match_operand:<ssescalarmode> 0 "register_operand" "=v")
+	(vec_select:<ssescalarmode>
+	  (match_operand:V48_256_512_AVX512VL 1 "register_operand" "v")
+	  (parallel [(match_operand 2 "<vec_extract_imm_predicate>")])))]
+  "TARGET_AVX512F
+   && INTVAL(operands[2]) >= 16 / GET_MODE_SIZE (<ssescalarmode>mode)"
+  "valign<ternlogsuffix>\t{%2, %1, %1, %<xtg_mode>0|%<xtg_mode>0, %1, %1, %2}";
+  [(set_attr "prefix" "evex")
+   (set_attr "mode" "<sseintvecinsnmode>")])
+
 (define_expand "avx512f_shufps512_mask"
   [(match_operand:V16SF 0 "register_operand")
    (match_operand:V16SF 1 "register_operand")
diff --git a/gcc/testsuite/gcc.target/i386/pr91103-1.c b/gcc/testsuite/gcc.target/i386/pr91103-1.c
new file mode 100644
index 00000000000..11caaa8bd1b
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/pr91103-1.c
@@ -0,0 +1,37 @@ 
+/* { dg-do compile } */
+/* { dg-options "-mavx512vl -O2" } */
+/* { dg-final { scan-assembler-times "valign\[dq\]" 16 } } */
+
+typedef float v8sf __attribute__((vector_size(32)));
+typedef float v16sf __attribute__((vector_size(64)));
+typedef int v8si __attribute__((vector_size(32)));
+typedef int v16si __attribute__((vector_size(64)));
+typedef double v4df __attribute__((vector_size(32)));
+typedef double v8df __attribute__((vector_size(64)));
+typedef long long v4di __attribute__((vector_size(32)));
+typedef long long v8di __attribute__((vector_size(64)));
+
+#define EXTRACT(V,S,IDX)			\
+  S						\
+  __attribute__((noipa))			\
+  foo_##V##_##IDX (V v)				\
+  {						\
+    return v[IDX];				\
+  }						\
+
+EXTRACT (v8sf, float, 4);
+EXTRACT (v8sf, float, 7);
+EXTRACT (v8si, int, 4);
+EXTRACT (v8si, int, 7);
+EXTRACT (v16sf, float, 8);
+EXTRACT (v16sf, float, 15);
+EXTRACT (v16si, int, 8);
+EXTRACT (v16si, int, 15);
+EXTRACT (v4df, double, 2);
+EXTRACT (v4df, double, 3);
+EXTRACT (v4di, long long, 2);
+EXTRACT (v4di, long long, 3);
+EXTRACT (v8df, double, 4);
+EXTRACT (v8df, double, 7);
+EXTRACT (v8di, long long, 4);
+EXTRACT (v8di, long long, 7);
diff --git a/gcc/testsuite/gcc.target/i386/pr91103-2.c b/gcc/testsuite/gcc.target/i386/pr91103-2.c
new file mode 100644
index 00000000000..010e4775723
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/pr91103-2.c
@@ -0,0 +1,81 @@ 
+/* { dg-do run } */
+/* { dg-options "-O2 -mavx512vl" } */
+/* { dg-require-effective-target avx512vl } */
+
+#define AVX512VL
+
+#ifndef CHECK
+#define CHECK "avx512f-helper.h"
+#endif
+
+#include CHECK
+#include "pr91103-1.c"
+
+#define RUNCHECK(U,V,S,IDX)			\
+  do						\
+    {						\
+      S tmp = foo_##V##_##IDX ((V)U.x);		\
+      if (tmp != U.a[IDX])			\
+	abort();				\
+    }						\
+  while (0)
+
+void
+test_256 (void)
+{
+  union512i_d di1;
+  union256i_d di2;
+  union512i_q q1;
+  union256i_q q2;
+  union512 f1;
+  union256 f2;
+  union512d d1;
+  union256d d2;
+  int sign = 1;
+
+  int i = 0;
+  for (i = 0; i < 16; i++)
+    {
+      di1.a[i] = 30 * (i - 30) * sign;
+      f1.a[i] = 56.78 * (i - 30) * sign;
+      sign = -sign;
+    }
+
+  for (i = 0; i != 8; i++)
+    {
+      di2.a[i] = 15 * (i + 40) * sign;
+      f2.a[i] = 90.12 * (i + 40) * sign;
+      q1.a[i] = 15 * (i + 40) * sign;
+      d1.a[i] = 90.12 * (i + 40) * sign;
+      sign = -sign;
+    }
+
+  for (i = 0; i != 4; i++)
+    {
+      q2.a[i] = 15 * (i + 40) * sign;
+      d2.a[i] = 90.12 * (i + 40) * sign;
+      sign = -sign;
+    }
+
+RUNCHECK (f2, v8sf, float, 4);
+RUNCHECK (f2, v8sf, float, 7);
+RUNCHECK (di2, v8si, int, 4);
+RUNCHECK (di2, v8si, int, 7);
+RUNCHECK (f1, v16sf, float, 8);
+RUNCHECK (f1, v16sf, float, 15);
+RUNCHECK (di1, v16si, int, 8);
+RUNCHECK (di1, v16si, int, 15);
+RUNCHECK (d2, v4df, double, 2);
+RUNCHECK (d2, v4df, double, 3);
+RUNCHECK (q2, v4di, long long, 2);
+RUNCHECK (q2, v4di, long long, 3);
+RUNCHECK (d1, v8df, double, 4);
+RUNCHECK (d1, v8df, double, 7);
+RUNCHECK (q1, v8di, long long, 4);
+RUNCHECK (q1, v8di, long long, 7);
+}
+
+void
+test_128()
+{
+}