diff mbox series

RISC-V: Fix using wrong mode to get reduction insn vlmax

Message ID 20230915111342.1895618-1-lehua.ding@rivai.ai
State New
Headers show
Series RISC-V: Fix using wrong mode to get reduction insn vlmax | expand

Commit Message

Lehua Ding Sept. 15, 2023, 11:13 a.m. UTC
This patch fixes the use of the wrong mode when emitting a vlmax reduction
insn. We should use the src operand instead of the dest operand (which is
always LMUL=m1) to get the vlmax length. This patch also removes dest_mode and
mask_mode from the insn_expander constructor, as they can be obtained via
insn_flags.

gcc/ChangeLog:

	* config/riscv/riscv-protos.h (enum insn_flags): Change name.
	(enum insn_type): Ditto.
	* config/riscv/riscv-v.cc (get_mask_mode_from_insn_flags): Removed.
	(emit_vlmax_insn): Adjust.
	(emit_nonvlmax_insn): Adjust.
	(emit_vlmax_insn_lra): Adjust.

gcc/testsuite/ChangeLog:

	* gcc.target/riscv/rvv/vsetvl/wredsum_vlmax.c: New test.

---
 gcc/config/riscv/riscv-protos.h               | 12 +--
 gcc/config/riscv/riscv-v.cc                   | 87 +++++++++----------
 .../riscv/rvv/vsetvl/wredsum_vlmax.c          | 15 ++++
 3 files changed, 60 insertions(+), 54 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/vsetvl/wredsum_vlmax.c

--
2.36.3

Comments

Lehua Ding Sept. 15, 2023, 1 p.m. UTC | #1
Committed, thanks Juzhe.

On 2023/9/15 19:18, juzhe.zhong wrote:
> lgtm
> ---- Replied Message ----
> From	Lehua Ding<lehua.ding@rivai.ai> <mailto:lehua.ding@rivai.ai>
> Date	09/15/2023 19:13
> To	gcc-patches@gcc.gnu.org<gcc-patches@gcc.gnu.org> 
> <mailto:gcc-patches@gcc.gnu.org>
> Cc	juzhe.zhong@rivai.ai<juzhe.zhong@rivai.ai> <mailto:juzhe.zhong@rivai.ai>,
> kito.cheng@gmail.com<kito.cheng@gmail.com> <mailto:kito.cheng@gmail.com>,
> rdapp.gcc@gmail.com<rdapp.gcc@gmail.com> <mailto:rdapp.gcc@gmail.com>,
> palmer@rivosinc.com<palmer@rivosinc.com> <mailto:palmer@rivosinc.com>,
> jeffreyalaw@gmail.com<jeffreyalaw@gmail.com> <mailto:jeffreyalaw@gmail.com>,
> lehua.ding@rivai.ai<lehua.ding@rivai.ai> <mailto:lehua.ding@rivai.ai>
> Subject	[PATCH] RISC-V: Fix using wrong mode to get reduction insn vlmax
>
diff mbox series

Patch

diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 44fa36c32ab..cf5ae6b4b70 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -244,8 +244,8 @@  enum insn_flags : unsigned int
   /* Means INSN need two operands to do the operation.  */
   TERNARY_OP_P = 1 << 13,

-  /* flags for get mask mode from the index number. default from dest operand.  */
-  MASK_MODE_FROM_OP1_P = 1 << 14,
+  /* flags for get vtype mode from the index number. default from dest operand.  */
+  VTYPE_MODE_FROM_OP1_P = 1 << 14,

   /* flags for the floating-point rounding mode.  */
   /* Means INSN has FRM operand and the value is FRM_DYN.  */
@@ -321,7 +321,7 @@  enum insn_type : unsigned int

   /* For vcpop.m, no merge operand, no tail and mask policy operands.  */
   CPOP_OP = HAS_DEST_P | HAS_MASK_P | USE_ALL_TRUES_MASK_P | UNARY_OP_P
-	    | MASK_MODE_FROM_OP1_P,
+	    | VTYPE_MODE_FROM_OP1_P,

   /* For mask instrunctions, no tail and mask policy operands.  */
   UNARY_MASK_OP = HAS_DEST_P | HAS_MASK_P | USE_ALL_TRUES_MASK_P | HAS_MERGE_P
@@ -336,10 +336,10 @@  enum insn_type : unsigned int
   = HAS_DEST_P | HAS_MERGE_P | TDEFAULT_POLICY_P | BINARY_OP_P,

   /* For vreduce, no mask policy operand. */
-  REDUCE_OP = __NORMAL_OP_TA | BINARY_OP_P | MASK_MODE_FROM_OP1_P,
-  REDUCE_OP_FRM_DYN = REDUCE_OP | FRM_DYN_P | MASK_MODE_FROM_OP1_P,
+  REDUCE_OP = __NORMAL_OP_TA | BINARY_OP_P | VTYPE_MODE_FROM_OP1_P,
+  REDUCE_OP_FRM_DYN = REDUCE_OP | FRM_DYN_P | VTYPE_MODE_FROM_OP1_P,
   REDUCE_OP_M_FRM_DYN
-  = __MASK_OP_TA | BINARY_OP_P | FRM_DYN_P | MASK_MODE_FROM_OP1_P,
+  = __MASK_OP_TA | BINARY_OP_P | FRM_DYN_P | VTYPE_MODE_FROM_OP1_P,

   /* For vmv.s.x/vfmv.s.f.  */
   SCALAR_MOVE_OP = HAS_DEST_P | HAS_MASK_P | USE_ONE_TRUE_MASK_P | HAS_MERGE_P
diff --git a/gcc/config/riscv/riscv-v.cc b/gcc/config/riscv/riscv-v.cc
index 668594b65ed..631840dfafd 100644
--- a/gcc/config/riscv/riscv-v.cc
+++ b/gcc/config/riscv/riscv-v.cc
@@ -72,10 +72,9 @@  template <int MAX_OPERANDS> class insn_expander
 public:
   insn_expander () = delete;

-  insn_expander (unsigned insn_flags, bool vlmax_p, machine_mode dest_mode,
-		  machine_mode mask_mode)
+  insn_expander (unsigned insn_flags, bool vlmax_p)
     : m_insn_flags (insn_flags), m_opno (0), m_vlmax_p (vlmax_p),
-      m_dest_mode (dest_mode), m_mask_mode (mask_mode), m_vl_op (NULL_RTX)
+      m_vl_op (NULL_RTX)
   {
     check_insn_flags ();
   }
@@ -138,13 +137,17 @@  public:
     create_input_operand (&m_ops[m_opno++], x, mode);
     gcc_assert (m_opno <= MAX_OPERANDS);
   }
-  void add_all_one_mask_operand ()
+  void add_all_one_mask_operand (machine_mode mask_mode)
   {
-    add_input_operand (CONSTM1_RTX (m_mask_mode), m_mask_mode);
+    add_input_operand (CONSTM1_RTX (mask_mode), mask_mode);
   }
-  void add_vundef_operand ()
+  void add_first_one_true_mask_operand (machine_mode mask_mode)
   {
-    add_input_operand (RVV_VUNDEF (m_dest_mode), m_dest_mode);
+    add_input_operand (gen_scalar_move_mask (mask_mode), mask_mode);
+  }
+  void add_vundef_operand (machine_mode dest_mode)
+  {
+    add_input_operand (RVV_VUNDEF (dest_mode), dest_mode);
   }
   void add_policy_operand ()
   {
@@ -182,9 +185,17 @@  public:
     add_input_operand (frm_rtx, Pmode);
   }

-  void add_oprand (rtx *ops, int opno)
+  /* Return the vtype mode based on insn_flags.
+     vtype mode mean the mode vsetvl insn set. */
+  machine_mode
+  get_vtype_mode (rtx *ops)
   {
-
+    machine_mode vtype_mode;
+    if (m_insn_flags & VTYPE_MODE_FROM_OP1_P)
+      vtype_mode = GET_MODE (ops[1]);
+    else
+      vtype_mode = GET_MODE (ops[0]);
+    return vtype_mode;
   }

   void emit_insn (enum insn_code icode, rtx *ops)
@@ -194,18 +205,22 @@  public:
     /* It's true if any operand is memory operand.  */
     bool any_mem_p = false;

+    machine_mode vtype_mode = get_vtype_mode (ops);
+    machine_mode mask_mode = get_mask_mode (vtype_mode);
+
     /* Add dest operand.  */
     if (m_insn_flags & HAS_DEST_P)
       {
-	any_mem_p |= MEM_P (ops[opno]);
-	add_output_operand (ops[opno++], m_dest_mode);
+	rtx op = ops[opno++];
+	any_mem_p |= MEM_P (op);
+	add_output_operand (op, GET_MODE (op));
       }

     /* Add mask operand.  */
     if (m_insn_flags & USE_ONE_TRUE_MASK_P)
-      add_input_operand (gen_scalar_move_mask (m_mask_mode), m_mask_mode);
+      add_first_one_true_mask_operand (mask_mode);
     else if (m_insn_flags & USE_ALL_TRUES_MASK_P)
-      add_all_one_mask_operand ();
+      add_all_one_mask_operand (mask_mode);
     else if (m_insn_flags & HAS_MASK_P)
       {
 	machine_mode mode = insn_data[(int) icode].operand[m_opno].mode;
@@ -215,7 +230,8 @@  public:

     /* Add merge operand.  */
     if (m_insn_flags & USE_VUNDEF_MERGE_P)
-      add_vundef_operand ();
+      /* Same as dest operand.  */
+      add_vundef_operand (GET_MODE (ops[0]));
     else if (m_insn_flags & HAS_MERGE_P)
       {
 	machine_mode mode = insn_data[(int) icode].operand[m_opno].mode;
@@ -256,31 +272,30 @@  public:

     /* Add vl operand.  */
     rtx len = m_vl_op;
-    machine_mode mode = VECTOR_MODE_P (m_dest_mode) ? m_dest_mode : m_mask_mode;
     if (m_vlmax_p)
       {
-	if (riscv_v_ext_vls_mode_p (mode))
+	if (riscv_v_ext_vls_mode_p (vtype_mode))
 	  {
 	    /* VLS modes always set VSETVL by
 	       "vsetvl zero, rs1/imm".  */
-	    poly_uint64 nunits = GET_MODE_NUNITS (mode);
+	    poly_uint64 nunits = GET_MODE_NUNITS (vtype_mode);
 	    len = gen_int_mode (nunits, Pmode);
 	    if (!satisfies_constraint_K (len))
 	      len = force_reg (Pmode, len);
 	    m_vlmax_p = false;
 	  }
-	else if (const_vlmax_p (mode))
+	else if (const_vlmax_p (vtype_mode))
 	  {
 	    /* Optimize VLS-VLMAX code gen, we can use vsetivli instead of
 	       the vsetvli to obtain the value of vlmax.  */
-	    poly_uint64 nunits = GET_MODE_NUNITS (mode);
+	    poly_uint64 nunits = GET_MODE_NUNITS (vtype_mode);
 	    len = gen_int_mode (nunits, Pmode);
 	    m_vlmax_p = false;
 	  }
 	else if (can_create_pseudo_p ())
 	  {
 	    len = gen_reg_rtx (Pmode);
-	    emit_vlmax_vsetvl (mode, len);
+	    emit_vlmax_vsetvl (vtype_mode, len);
 	  }
       }

@@ -313,38 +328,20 @@  public:
   }

 private:
-  int m_insn_flags;
+  unsigned m_insn_flags;
   int m_opno;
   bool m_vlmax_p;
-  machine_mode m_dest_mode;
-  machine_mode m_mask_mode;
   rtx m_vl_op;
   expand_operand m_ops[MAX_OPERANDS];
 };

-/* Return the mask mode based on insn_flags */
-static machine_mode
-get_mask_mode_from_insn_flags (unsigned insn_flags, rtx *ops)
-{
-  machine_mode mask_mode;
-  if (insn_flags & MASK_MODE_FROM_OP1_P)
-    mask_mode = get_mask_mode (GET_MODE (ops[1]));
-  else
-    mask_mode = get_mask_mode (GET_MODE (ops[0]));
-  return mask_mode;
-}
-
 /* Emit RVV insn which vl is VLMAX.
    This function can only be used before LRA pass or
    for VLS_AVL_IMM modes.  */
 void
 emit_vlmax_insn (unsigned icode, unsigned insn_flags, rtx *ops)
 {
-  machine_mode dest_mode = GET_MODE (ops[0]);
-  machine_mode mask_mode = get_mask_mode_from_insn_flags (insn_flags, ops);
-
-  insn_expander<RVV_INSN_OPERANDS_MAX> e (insn_flags, true, dest_mode,
-					   mask_mode);
+  insn_expander<RVV_INSN_OPERANDS_MAX> e (insn_flags, true);
   e.emit_insn ((enum insn_code) icode, ops);
 }

@@ -352,10 +349,7 @@  emit_vlmax_insn (unsigned icode, unsigned insn_flags, rtx *ops)
 void
 emit_nonvlmax_insn (unsigned icode, unsigned insn_flags, rtx *ops, rtx vl)
 {
-  machine_mode dest_mode = GET_MODE (ops[0]);
-  machine_mode mask_mode = get_mask_mode_from_insn_flags (insn_flags, ops);
-  insn_expander<RVV_INSN_OPERANDS_MAX> e (insn_flags, false, dest_mode,
-					   mask_mode);
+  insn_expander<RVV_INSN_OPERANDS_MAX> e (insn_flags, false);
   e.set_vl (vl);
   e.emit_insn ((enum insn_code) icode, ops);
 }
@@ -367,10 +361,7 @@  emit_vlmax_insn_lra (unsigned icode, unsigned insn_flags, rtx *ops, rtx vl)
 {
   gcc_assert (!can_create_pseudo_p ());

-  machine_mode dest_mode = GET_MODE (ops[0]);
-  machine_mode mask_mode = get_mask_mode_from_insn_flags (insn_flags, ops);
-  insn_expander<RVV_INSN_OPERANDS_MAX> e (insn_flags, true, dest_mode,
-					   mask_mode);
+  insn_expander<RVV_INSN_OPERANDS_MAX> e (insn_flags, true);
   e.set_vl (vl);
   e.emit_insn ((enum insn_code) icode, ops);
 }
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/vsetvl/wredsum_vlmax.c b/gcc/testsuite/gcc.target/riscv/rvv/vsetvl/wredsum_vlmax.c
new file mode 100644
index 00000000000..6b7c77326ae
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/vsetvl/wredsum_vlmax.c
@@ -0,0 +1,15 @@ 
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvl256b --param=riscv-autovec-preference=fixed-vlmax -O3" } */
+
+
+#include <stdint.h>
+
+int16_t foo (int8_t *restrict a)
+{
+    int16_t sum = 0;
+    for (int i = 0; i < 8; i += 1)
+      sum += a[i];
+    return sum;
+}
+
+/* { dg-final { scan-assembler-not {\tvsetivli\tzero,16} } } */