diff mbox series

[v2,7/16] middle-end: Add Complex Multiplication and Multiplication with Conjucate detection

Message ID 20200925142914.GA19264@arm.com
State New
Headers show
Series middle-end Add support for SLP vectorization of complex number instructions. | expand

Commit Message

Tamar Christina Sept. 25, 2020, 2:29 p.m. UTC
Hi All,

This patch adds pattern detections for the following operation:

  Complex multiplication and Conjucate Complex multiplication of the second
     parameter.

    c = a * b and c = a * conj (b)

  For the conjucate cases it supports under fast-math that the operands that is
  being conjucated be flipped by flipping the arguments to the optab.  This
  allows it to support c = conj (a) * b and c += conj (a) * b.

  where a, b and c are complex numbers.

and provides a shared class for anything needing to recognize complex MLA
patterns.

Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.

Ok for master?

Thanks,
Tamar

gcc/ChangeLog:

	* doc/md.texi: Document optabs.
	* internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
	* optabs.def (cmul_optab, cmul_conj_optab): New,
	* tree-vect-slp-patterns.c (class ComplexMLAPattern,
	class ComplexMulPattern): New.
	(slp_patterns): Add ComplexMulPattern.

--

Comments

Tamar Christina Nov. 3, 2020, 3:06 p.m. UTC | #1
Hi All,

This is a respin of this patch using the new approach.

Thanks,
Tamar

gcc/ChangeLog:

	* doc/md.texi: Document optabs.
	* internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
	* optabs.def (cmul_optab, cmul_conj_optab): New,
	* tree-vect-slp-patterns.c (vect_build_perm_groups,
	(vect_can_combine_node_p, vect_slp_make_combine_linear,
	vect_match_call_complex_mla, vect_slp_matches_complex_mul,
	class complex_mul_pattern, complex_mul_pattern::matches,
	complex_mul_pattern::validate_p,
	complex_operations_pattern::matches): Add complex_mul_pattern.


> -----Original Message-----
> From: Gcc-patches <gcc-patches-bounces@gcc.gnu.org> On Behalf Of Tamar
> Christina
> Sent: Friday, September 25, 2020 3:29 PM
> To: gcc-patches@gcc.gnu.org
> Cc: nd <nd@arm.com>; rguenther@suse.de; ook@ucw.cz
> Subject: [PATCH v2 7/16]middle-end: Add Complex Multiplication and
> Multiplication with Conjucate detection
> 
> Hi All,
> 
> This patch adds pattern detections for the following operation:
> 
>   Complex multiplication and Conjucate Complex multiplication of the second
>      parameter.
> 
>     c = a * b and c = a * conj (b)
> 
>   For the conjucate cases it supports under fast-math that the operands that
> is
>   being conjucated be flipped by flipping the arguments to the optab.  This
>   allows it to support c = conj (a) * b and c += conj (a) * b.
> 
>   where a, b and c are complex numbers.
> 
> and provides a shared class for anything needing to recognize complex MLA
> patterns.
> 
> Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.
> 
> Ok for master?
> 
> Thanks,
> Tamar
> 
> gcc/ChangeLog:
> 
> 	* doc/md.texi: Document optabs.
> 	* internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
> 	* optabs.def (cmul_optab, cmul_conj_optab): New,
> 	* tree-vect-slp-patterns.c (class ComplexMLAPattern,
> 	class ComplexMulPattern): New.
> 	(slp_patterns): Add ComplexMulPattern.
> 
> --
diff mbox series

Patch

diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
index 71e226505b2619d10982b59a4ebbed73a70f29be..ddaf1abaccbd44dae11ea902ec38b474aacfb8e1 100644
--- a/gcc/doc/md.texi
+++ b/gcc/doc/md.texi
@@ -6143,6 +6143,28 @@  rotations @var{m} of 90 or 270.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmul@var{m}4} instruction pattern
+@item @samp{cmul@var{m}4}
+Perform a vector floating point multiplication of complex numbers in operand 0
+and operand 1.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmul_conj@var{m}4} instruction pattern
+@item @samp{cmul_conj@var{m}4}
+Perform a vector floating point multiplication of complex numbers in operand 0
+and the conjucate of operand 1.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{ffs@var{m}2} instruction pattern
 @item @samp{ffs@var{m}2}
 Store into operand 0 one plus the index of the least significant 1-bit
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index 956a65a338c157b51de7e78a3fb005b5af78ef31..51bebf8701af262b22d66d19a29a8dafb74db1f0 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -277,6 +277,9 @@  DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
 DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
 DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
+
 
 /* FP scales.  */
 DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 2bb0bf857977035bf562a77f5f6848e80edf936d..9c267d422478d0011f288b1f5f62daabe3989ba7 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -292,6 +292,8 @@  OPTAB_D (copysign_optab, "copysign$F$a3")
 OPTAB_D (xorsign_optab, "xorsign$F$a3")
 OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
+OPTAB_D (cmul_optab, "cmul$a3")
+OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index b2b0ac62e9a69145470f41d2bac736dd970be735..bef7cc73b21c020e4c0128df5d186a034809b103 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -743,6 +743,179 @@  class ComplexAddPattern : public ComplexPattern
     }
 };
 
+class ComplexMLAPattern : public ComplexPattern
+{
+  protected:
+    ComplexMLAPattern (slp_tree node, vec_info *vinfo)
+      : ComplexPattern (node, vinfo)
+    { }
+
+  protected:
+    /* Helper function of vect_match_call_complex_mla that looks up the
+       definition of LHS_0 and LHS_1 by finding the statements starting in
+       position BASE + IDX in child ROOT of NODE and tries to match the
+       definition against pair ops.
+
+       If the match is successful then ARGS will contain the operands matched
+       and the complex_operation_t type is returned.  If match is not successful
+       then CMPLX_NONE is returned and ARGS is left unmodified.  */
+
+    complex_operation_t
+    vect_match_call_complex_mla_1 (slp_tree node, slp_tree *res, int root,
+				   int base, int idx, vec<stmt_vec_info> *args)
+    {
+      gcc_assert (base >= 0 && idx >= 0 && node != NULL);
+
+      if ((unsigned)root >= SLP_TREE_CHILDREN (node).length ())
+	return CMPLX_NONE;
+
+      slp_tree data = SLP_TREE_CHILDREN (node)[root];
+
+      /* If it's a VEC_PERM_EXPR we need to look one deeper.  */
+      if (node->code == VEC_PERM_EXPR)
+	data = SLP_TREE_CHILDREN (data)[root];
+
+      int lhs_0 = base + idx;
+      int lhs_1 = base + idx + 1;
+
+      vec<stmt_vec_info> stmts = SLP_TREE_SCALAR_STMTS (data);
+      if (stmts.length () < (unsigned)lhs_1)
+	return CMPLX_NONE;
+
+      gimple *stmt_0 = STMT_VINFO_STMT (stmts[lhs_0]);
+      gimple *stmt_1 = STMT_VINFO_STMT (stmts[lhs_1]);
+
+      if (gimple_expr_type (stmt_0) != gimple_expr_type (stmt_1))
+	return CMPLX_NONE;
+
+      if (res)
+	*res = data;
+
+      return vect_detect_pair_op (base, data, args);
+    }
+};
+
+class ComplexMulPattern : public ComplexMLAPattern
+{
+  protected:
+    ComplexMulPattern (slp_tree node, vec_info *vinfo)
+      : ComplexMLAPattern (node, vinfo)
+    {
+      this->m_arity = 2;
+      this->m_num_args = 2;
+      this->m_vects.create (0);
+      this->m_defs.create (0);
+    }
+
+  public:
+    ~ComplexMulPattern ()
+    {
+      this->m_vects.release ();
+      this->m_defs.release ();
+    }
+
+    static VectPattern* create (slp_tree node, vec_info *vinfo)
+    {
+       return new ComplexMulPattern (node, vinfo);
+    }
+
+    const char* get_name ()
+    {
+      return "Complex Multiplication";
+    }
+
+
+    /* Pattern matcher for trying to match complex multiply pattern in SLP tree
+       using N statements STMT_0 and STMT_0 as the root statements by finding
+       the statements starting in position IDX in NODE.  If the operation
+       matches then IFN is set to the operation it matched and the arguments to
+       the two replacement statements are put in VECTS.
+
+       If no match is found then IFN is set to IFN_LAST and VECTS is unchanged.
+
+       This function matches the patterns shaped as:
+
+	 double ax = (b[i+1] * a[i]);
+	 double bx = (a[i+1] * b[i]);
+
+	 c[i] = c[i] - ax;
+	 c[i+1] = c[i+1] + bx;
+
+       If a match occurred then TRUE is returned, else FALSE.  */
+
+    bool
+    matches (stmt_vec_info *stmts, int idx)
+    {
+      this->m_last_ifn = IFN_LAST;
+      this->m_vects.truncate (0);
+      this->m_vects.create (6);
+      int base = idx - (this->m_arity - 1);
+      this->m_last_idx = idx;
+      this->m_stmt_info = stmts[0];
+
+      complex_operation_t op1 = vect_detect_pair_op (base, this->m_node, NULL);
+
+      if (op1 != MINUS_PLUS)
+	return false;
+
+      slp_tree sub1a, sub1b, sub2;
+      /* Now operand1+3 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args0;
+      complex_operation_t op2
+	= vect_match_call_complex_mla_1 (this->m_node, &sub1a, 0, base, 0,
+					 &args0);
+
+      if (op2 != MULT_MULT)
+	return false;
+
+      /* Now operand2+4 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args1;
+      complex_operation_t op3
+	= vect_match_call_complex_mla_1 (this->m_node, &sub1b, 1, base, 0,
+					 &args1);
+
+      if (op3 != MULT_MULT)
+	return false;
+
+      /* Now operand2+4 may lead to another expression.  */
+      auto_vec<stmt_vec_info> args2;
+      complex_operation_t op4
+	= vect_match_call_complex_mla_1 (sub1b, &sub2, 1, base, 0, &args2);
+
+      if (op4 != CMPLX_NONE && op4 != NEG_NEG)
+	return false;
+
+      if (op4 == CMPLX_NONE)
+	{
+	  this->m_last_ifn = IFN_COMPLEX_MUL;
+	  /* Correct the arguments after matching.  */
+	  std::swap (args0[2], args1[0]);
+	}
+      else if (op4 == NEG_NEG)
+	{
+	  this->m_last_ifn = IFN_COMPLEX_MUL_CONJ;
+	  /* Check if the conjucate is on the first or second parameter.  */
+	  if (args1[1] == args1[3] && args0[1] == args0[3])
+	    {
+	      this->m_vects.quick_push (args0[3]);
+	      this->m_vects.quick_push (args0[0]);
+	      this->m_vects.quick_push (args2[0]);
+	      this->m_vects.quick_push (args0[2]);
+	    }
+	  else
+	    {
+	      /* Correct the arguments after matching.  */
+	      std::swap (args0[2], args2[0]);
+	    }
+        }
+
+      if (this->m_vects.length () == 0)
+	this->m_vects.splice (args0);
+
+      return this->m_last_ifn != IFN_LAST && store_results ();
+    }
+};
+
 #define SLP_PATTERN(x) &x::create
 VectPatternDecl slp_patterns[]
 {
@@ -750,6 +923,7 @@  VectPatternDecl slp_patterns[]
      order patterns from the largest to the smallest.  Especially if they
      overlap in what they can detect.  */
 
+  SLP_PATTERN (ComplexMulPattern),
   SLP_PATTERN (ComplexAddPattern),
 };
 #undef SLP_PATTERN