diff mbox series

RISC-V: Add vlse/vsse C/C++ API intrinsics support

Message ID 20230120042541.109466-1-juzhe.zhong@rivai.ai
State New
Headers show
Series RISC-V: Add vlse/vsse C/C++ API intrinsics support | expand

Commit Message

juzhe.zhong@rivai.ai Jan. 20, 2023, 4:25 a.m. UTC
From: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>

gcc/ChangeLog:

        * config/riscv/predicates.md (pmode_reg_or_0_operand): New predicate.
        * config/riscv/riscv-vector-builtins-bases.cc (class loadstore): Add vlse/vsse intrinsic support.
        (BASE): Ditto.
        * config/riscv/riscv-vector-builtins-bases.h: Ditto.
        * config/riscv/riscv-vector-builtins-functions.def (vlse): Ditto.
        (vsse): Ditto.
        * config/riscv/riscv-vector-builtins.cc (function_expander::use_contiguous_load_insn): Ditto.
        * config/riscv/vector.md (@pred_strided_load<mode>): Ditto.
        (@pred_strided_store<mode>): Ditto.

---
 gcc/config/riscv/predicates.md                |  4 +
 .../riscv/riscv-vector-builtins-bases.cc      | 26 +++++-
 .../riscv/riscv-vector-builtins-bases.h       |  2 +
 .../riscv/riscv-vector-builtins-functions.def |  2 +
 gcc/config/riscv/riscv-vector-builtins.cc     | 33 ++++++-
 gcc/config/riscv/vector.md                    | 90 +++++++++++++++++--
 6 files changed, 143 insertions(+), 14 deletions(-)
diff mbox series

Patch

diff --git a/gcc/config/riscv/predicates.md b/gcc/config/riscv/predicates.md
index 5a5a49bf7c0..bae9cfa02dd 100644
--- a/gcc/config/riscv/predicates.md
+++ b/gcc/config/riscv/predicates.md
@@ -286,6 +286,10 @@ 
 	    (match_test "GET_CODE (op) == UNSPEC
 			 && (XINT (op, 1) == UNSPEC_VUNDEF)"))))
 
+(define_special_predicate "pmode_reg_or_0_operand"
+  (ior (match_operand 0 "const_0_operand")
+       (match_operand 0 "pmode_register_operand")))
+
 ;; The scalar operand can be directly broadcast by RVV instructions.
 (define_predicate "direct_broadcast_operand"
   (ior (match_operand 0 "register_operand")
diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.cc b/gcc/config/riscv/riscv-vector-builtins-bases.cc
index 0da4797d272..17a1294cf85 100644
--- a/gcc/config/riscv/riscv-vector-builtins-bases.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-bases.cc
@@ -84,8 +84,8 @@  public:
   }
 };
 
-/* Implements vle.v/vse.v/vlm.v/vsm.v codegen.  */
-template <bool STORE_P>
+/* Implements vle.v/vse.v/vlm.v/vsm.v/vlse.v/vsse.v codegen.  */
+template <bool STORE_P, bool STRIDED_P = false>
 class loadstore : public function_base
 {
   unsigned int call_properties (const function_instance &) const override
@@ -106,9 +106,23 @@  class loadstore : public function_base
   rtx expand (function_expander &e) const override
   {
     if (STORE_P)
-      return e.use_contiguous_store_insn (code_for_pred_store (e.vector_mode ()));
+      {
+	if (STRIDED_P)
+	  return e.use_contiguous_store_insn (
+	    code_for_pred_strided_store (e.vector_mode ()));
+	else
+	  return e.use_contiguous_store_insn (
+	    code_for_pred_store (e.vector_mode ()));
+      }
     else
-      return e.use_contiguous_load_insn (code_for_pred_mov (e.vector_mode ()));
+      {
+	if (STRIDED_P)
+	  return e.use_contiguous_load_insn (
+	    code_for_pred_strided_load (e.vector_mode ()));
+	else
+	  return e.use_contiguous_load_insn (
+	    code_for_pred_mov (e.vector_mode ()));
+      }
   }
 };
 
@@ -118,6 +132,8 @@  static CONSTEXPR const loadstore<false> vle_obj;
 static CONSTEXPR const loadstore<true> vse_obj;
 static CONSTEXPR const loadstore<false> vlm_obj;
 static CONSTEXPR const loadstore<true> vsm_obj;
+static CONSTEXPR const loadstore<false, true> vlse_obj;
+static CONSTEXPR const loadstore<true, true> vsse_obj;
 
 /* Declare the function base NAME, pointing it to an instance
    of class <NAME>_obj.  */
@@ -130,5 +146,7 @@  BASE (vle)
 BASE (vse)
 BASE (vlm)
 BASE (vsm)
+BASE (vlse)
+BASE (vsse)
 
 } // end namespace riscv_vector
diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.h b/gcc/config/riscv/riscv-vector-builtins-bases.h
index 28151a8d8d2..d8676e94b28 100644
--- a/gcc/config/riscv/riscv-vector-builtins-bases.h
+++ b/gcc/config/riscv/riscv-vector-builtins-bases.h
@@ -30,6 +30,8 @@  extern const function_base *const vle;
 extern const function_base *const vse;
 extern const function_base *const vlm;
 extern const function_base *const vsm;
+extern const function_base *const vlse;
+extern const function_base *const vsse;
 }
 
 } // end namespace riscv_vector
diff --git a/gcc/config/riscv/riscv-vector-builtins-functions.def b/gcc/config/riscv/riscv-vector-builtins-functions.def
index 63aa8fe32c8..348262928c8 100644
--- a/gcc/config/riscv/riscv-vector-builtins-functions.def
+++ b/gcc/config/riscv/riscv-vector-builtins-functions.def
@@ -44,5 +44,7 @@  DEF_RVV_FUNCTION (vle, loadstore, full_preds, all_v_scalar_const_ptr_ops)
 DEF_RVV_FUNCTION (vse, loadstore, none_m_preds, all_v_scalar_ptr_ops)
 DEF_RVV_FUNCTION (vlm, loadstore, none_preds, b_v_scalar_const_ptr_ops)
 DEF_RVV_FUNCTION (vsm, loadstore, none_preds, b_v_scalar_ptr_ops)
+DEF_RVV_FUNCTION (vlse, loadstore, full_preds, all_v_scalar_const_ptr_ptrdiff_ops)
+DEF_RVV_FUNCTION (vsse, loadstore, none_m_preds, all_v_scalar_ptr_ptrdiff_ops)
 
 #undef DEF_RVV_FUNCTION
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index f95fe0d58d5..b97a2c94550 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -167,6 +167,19 @@  static CONSTEXPR const rvv_arg_type_info scalar_ptr_args[]
   = {rvv_arg_type_info (RVV_BASE_scalar_ptr),
      rvv_arg_type_info (RVV_BASE_vector), rvv_arg_type_info_end};
 
+/* A list of args for vector_type func (const scalar_type *, ptrdiff_t)
+ * function.  */
+static CONSTEXPR const rvv_arg_type_info scalar_const_ptr_ptrdiff_args[]
+  = {rvv_arg_type_info (RVV_BASE_scalar_const_ptr),
+     rvv_arg_type_info (RVV_BASE_ptrdiff), rvv_arg_type_info_end};
+
+/* A list of args for void func (scalar_type *, ptrdiff_t, vector_type)
+ * function.  */
+static CONSTEXPR const rvv_arg_type_info scalar_ptr_ptrdiff_args[]
+  = {rvv_arg_type_info (RVV_BASE_scalar_ptr),
+     rvv_arg_type_info (RVV_BASE_ptrdiff), rvv_arg_type_info (RVV_BASE_vector),
+     rvv_arg_type_info_end};
+
 /* A list of none preds that will be registered for intrinsic functions.  */
 static CONSTEXPR const predication_type_index none_preds[]
   = {PRED_TYPE_none, NUM_PRED_TYPES};
@@ -227,6 +240,22 @@  static CONSTEXPR const rvv_op_info b_v_scalar_ptr_ops
      rvv_arg_type_info (RVV_BASE_void), /* Return type */
      scalar_ptr_args /* Args */};
 
+/* A static operand information for vector_type func (const scalar_type *,
+ * ptrdiff_t) function registration. */
+static CONSTEXPR const rvv_op_info all_v_scalar_const_ptr_ptrdiff_ops
+  = {all_ops,				  /* Types */
+     OP_TYPE_v,				  /* Suffix */
+     rvv_arg_type_info (RVV_BASE_vector), /* Return type */
+     scalar_const_ptr_ptrdiff_args /* Args */};
+
+/* A static operand information for void func (scalar_type *, ptrdiff_t,
+ * vector_type) function registration. */
+static CONSTEXPR const rvv_op_info all_v_scalar_ptr_ptrdiff_ops
+  = {all_ops,				/* Types */
+     OP_TYPE_v,				/* Suffix */
+     rvv_arg_type_info (RVV_BASE_void), /* Return type */
+     scalar_ptr_ptrdiff_args /* Args */};
+
 /* A list of all RVV intrinsic functions.  */
 static function_group_info function_groups[] = {
 #define DEF_RVV_FUNCTION(NAME, SHAPE, PREDS, OPS_INFO)                         \
@@ -920,7 +949,9 @@  function_expander::use_contiguous_load_insn (insn_code icode)
       add_input_operand (Pmode, get_tail_policy_for_pred (pred));
       add_input_operand (Pmode, get_mask_policy_for_pred (pred));
     }
-  add_input_operand (Pmode, get_avl_type_rtx (avl_type::NONVLMAX));
+
+  if (opno != insn_data[icode].n_generator_args)
+    add_input_operand (Pmode, get_avl_type_rtx (avl_type::NONVLMAX));
 
   return generate_insn (icode);
 }
diff --git a/gcc/config/riscv/vector.md b/gcc/config/riscv/vector.md
index e1173f2d5a6..498cf21905b 100644
--- a/gcc/config/riscv/vector.md
+++ b/gcc/config/riscv/vector.md
@@ -33,6 +33,7 @@ 
   UNSPEC_VUNDEF
   UNSPEC_VPREDICATE
   UNSPEC_VLMAX
+  UNSPEC_STRIDED
 ])
 
 (define_constants [
@@ -204,28 +205,56 @@ 
 
 ;; The index of operand[] to get the avl op.
 (define_attr "vl_op_idx" ""
-	(cond [(eq_attr "type" "vlde,vste,vimov,vfmov,vldm,vstm,vlds,vmalu")
-	 (const_int 4)]
-	(const_int INVALID_ATTRIBUTE)))
+  (cond [(eq_attr "type" "vlde,vste,vimov,vfmov,vldm,vstm,vmalu,vsts")
+	   (const_int 4)
+
+	 ;; If operands[3] of "vlds" is not vector mode, it is pred_broadcast.
+	 ;; wheras it is pred_strided_load if operands[3] is vector mode.
+         (eq_attr "type" "vlds")
+	   (if_then_else (match_test "VECTOR_MODE_P (GET_MODE (operands[3]))")
+             (const_int 5)
+             (const_int 4))]
+  (const_int INVALID_ATTRIBUTE)))
 
 ;; The tail policy op value.
 (define_attr "ta" ""
-  (cond [(eq_attr "type" "vlde,vimov,vfmov,vlds")
-	   (symbol_ref "riscv_vector::get_ta(operands[5])")]
+  (cond [(eq_attr "type" "vlde,vimov,vfmov")
+	   (symbol_ref "riscv_vector::get_ta(operands[5])")
+
+	 ;; If operands[3] of "vlds" is not vector mode, it is pred_broadcast.
+	 ;; wheras it is pred_strided_load if operands[3] is vector mode.
+	 (eq_attr "type" "vlds")
+	   (if_then_else (match_test "VECTOR_MODE_P (GET_MODE (operands[3]))")
+	     (symbol_ref "riscv_vector::get_ta(operands[6])")
+	     (symbol_ref "riscv_vector::get_ta(operands[5])"))]
 	(const_int INVALID_ATTRIBUTE)))
 
 ;; The mask policy op value.
 (define_attr "ma" ""
-  (cond [(eq_attr "type" "vlde,vlds")
-	   (symbol_ref "riscv_vector::get_ma(operands[6])")]
+  (cond [(eq_attr "type" "vlde")
+	   (symbol_ref "riscv_vector::get_ma(operands[6])")
+
+	 ;; If operands[3] of "vlds" is not vector mode, it is pred_broadcast.
+	 ;; wheras it is pred_strided_load if operands[3] is vector mode.
+	 (eq_attr "type" "vlds")
+	   (if_then_else (match_test "VECTOR_MODE_P (GET_MODE (operands[3]))")
+	     (symbol_ref "riscv_vector::get_ma(operands[7])")
+	     (symbol_ref "riscv_vector::get_ma(operands[6])"))]
 	(const_int INVALID_ATTRIBUTE)))
 
 ;; The avl type value.
 (define_attr "avl_type" ""
-  (cond [(eq_attr "type" "vlde,vlde,vste,vimov,vimov,vimov,vfmov,vlds,vlds")
+  (cond [(eq_attr "type" "vlde,vlde,vste,vimov,vimov,vimov,vfmov")
 	   (symbol_ref "INTVAL (operands[7])")
 	 (eq_attr "type" "vldm,vstm,vimov,vmalu,vmalu")
-	   (symbol_ref "INTVAL (operands[5])")]
+	   (symbol_ref "INTVAL (operands[5])")
+
+	 ;; If operands[3] of "vlds" is not vector mode, it is pred_broadcast.
+	 ;; wheras it is pred_strided_load if operands[3] is vector mode.
+	 (eq_attr "type" "vlds")
+	   (if_then_else (match_test "VECTOR_MODE_P (GET_MODE (operands[3]))")
+	     (const_int INVALID_ATTRIBUTE)
+	     (symbol_ref "INTVAL (operands[7])"))]
 	(const_int INVALID_ATTRIBUTE)))
 
 ;; -----------------------------------------------------------------
@@ -760,3 +789,46 @@ 
    vlse<sew>.v\t%0,%3,zero"
   [(set_attr "type" "vimov,vfmov,vlds,vlds")
    (set_attr "mode" "<MODE>")])
+
+;; -------------------------------------------------------------------------------
+;; ---- Predicated Strided loads/stores
+;; -------------------------------------------------------------------------------
+;; Includes:
+;; - 7.5. Vector Strided Instructions
+;; -------------------------------------------------------------------------------
+
+(define_insn "@pred_strided_load<mode>"
+  [(set (match_operand:V 0 "register_operand"              "=vr,    vr,    vd")
+	(if_then_else:V
+	  (unspec:<VM>
+	    [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1,   Wc1,    vm")
+	     (match_operand 5 "vector_length_operand"    "   rK,    rK,    rK")
+	     (match_operand 6 "const_int_operand"        "    i,     i,     i")
+	     (match_operand 7 "const_int_operand"        "    i,     i,     i")
+	     (reg:SI VL_REGNUM)
+	     (reg:SI VTYPE_REGNUM)] UNSPEC_VPREDICATE)
+	  (unspec:V
+	    [(match_operand:V 3 "memory_operand"         "    m,     m,     m")
+	     (match_operand 4 "pmode_reg_or_0_operand"   "   rJ,    rJ,    rJ")] UNSPEC_STRIDED)
+	  (match_operand:V 2 "vector_merge_operand"      "    0,    vu,    vu")))]
+  "TARGET_VECTOR"
+  "vlse<sew>.v\t%0,%3,%z4%p1"
+  [(set_attr "type" "vlds")
+   (set_attr "mode" "<MODE>")])
+
+(define_insn "@pred_strided_store<mode>"
+  [(set (match_operand:V 0 "memory_operand"                 "+m")
+	(if_then_else:V
+	  (unspec:<VM>
+	    [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1")
+	     (match_operand 4 "vector_length_operand"    "   rK")
+	     (reg:SI VL_REGNUM)
+	     (reg:SI VTYPE_REGNUM)] UNSPEC_VPREDICATE)
+	  (unspec:V
+	    [(match_operand 2 "pmode_reg_or_0_operand"   "   rJ")
+	     (match_operand:V 3 "register_operand"       "   vr")] UNSPEC_STRIDED)
+	  (match_dup 0)))]
+  "TARGET_VECTOR"
+  "vsse<sew>.v\t%3,%0,%z2%p1"
+  [(set_attr "type" "vsts")
+   (set_attr "mode" "<MODE>")])