diff mbox series

[v2,8/16] middle-end: add Complex Multiply and Accumulate/Subtract and Multiply and Accumulate/Subtract with Conjucate detection

Message ID 20200925142931.GA21805@arm.com
State New
Headers show
Series middle-end Add support for SLP vectorization of complex number instructions. | expand

Commit Message

Tamar Christina Sept. 25, 2020, 2:29 p.m. UTC
Hi All,

This patch adds pattern detections for the following operation:

  Complex FMLA, Conjucate FMLA of the second parameter and FMLS.

    c += a * b, c += a * conj (b), c -= a * b and c -= a * conj (b)

  For the conjucate cases it supports under fast-math that the operands that is
  being conjucated be flipped by flipping the arguments to the optab.  This
  allows it to support c = conj (a) * b and c += conj (a) * b.

  where a, b and c are complex numbers.

Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.

Ok for master?

Thanks,
Tamar

gcc/ChangeLog:

	* doc/md.texi: Document optabs.
	* internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ, COMPLEX_FMS,
	COMPLEX_FMS_CONJ): New.
	* optabs.def (cmla_optab, cmla_conj_optab, cmls_optab, cmls_conj_optab):
	New.
	* tree-vect-slp-patterns.c (class ComplexFMAPattern): New.
	(slp_patterns): Add ComplexFMAPattern.

--

Comments

Tamar Christina Nov. 3, 2020, 3:06 p.m. UTC | #1
Hi All,

This is a respin of the patch using the new approach.

Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.

Ok for master?

Thanks,
Tamar

gcc/ChangeLog:

	* doc/md.texi: Document optabs.
	* internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ, COMPLEX_FMS,
	COMPLEX_FMS_CONJ): New.
	* optabs.def (cmla_optab, cmla_conj_optab, cmls_optab, cmls_conj_optab):
	New.
	* tree-vect-slp-patterns.c (class complex_fma_pattern,
	complex_fma_pattern::matches): New.
	(slp_patterns): Add complex_fma_pattern.

> -----Original Message-----
> From: Gcc-patches <gcc-patches-bounces@gcc.gnu.org> On Behalf Of Tamar
> Christina
> Sent: Friday, September 25, 2020 3:30 PM
> To: gcc-patches@gcc.gnu.org
> Cc: nd <nd@arm.com>; rguenther@suse.de; ook@ucw.cz
> Subject: [PATCH v2 8/16]middle-end: add Complex Multiply and
> Accumulate/Subtract and Multiply and Accumulate/Subtract with Conjucate
> detection
> 
> Hi All,
> 
> This patch adds pattern detections for the following operation:
> 
>   Complex FMLA, Conjucate FMLA of the second parameter and FMLS.
> 
>     c += a * b, c += a * conj (b), c -= a * b and c -= a * conj (b)
> 
>   For the conjucate cases it supports under fast-math that the operands that
> is
>   being conjucated be flipped by flipping the arguments to the optab.  This
>   allows it to support c = conj (a) * b and c += conj (a) * b.
> 
>   where a, b and c are complex numbers.
> 
> Bootstrapped Regtested on aarch64-none-linux-gnu and no issues.
> 
> Ok for master?
> 
> Thanks,
> Tamar
> 
> gcc/ChangeLog:
> 
> 	* doc/md.texi: Document optabs.
> 	* internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ,
> COMPLEX_FMS,
> 	COMPLEX_FMS_CONJ): New.
> 	* optabs.def (cmla_optab, cmla_conj_optab, cmls_optab,
> cmls_conj_optab):
> 	New.
> 	* tree-vect-slp-patterns.c (class ComplexFMAPattern): New.
> 	(slp_patterns): Add ComplexFMAPattern.
> 
> --
diff mbox series

Patch

diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
index ddaf1abaccbd44dae11ea902ec38b474aacfb8e1..d8142f745050d963e8d15c7793fae06d9ad02020 100644
--- a/gcc/doc/md.texi
+++ b/gcc/doc/md.texi
@@ -6143,6 +6143,50 @@  rotations @var{m} of 90 or 270.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmla@var{m}4} instruction pattern
+@item @samp{cmla@var{m}4}
+Perform a vector floating point multiply and accumulate of complex numbers
+in operand 0, operand 1 and operand 2.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmla_conj@var{m}4} instruction pattern
+@item @samp{cmla_conj@var{m}4}
+Perform a vector floating point multiply and accumulate of complex numbers
+in operand 0, operand 1 and the conjucate of operand 2.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmls@var{m}4} instruction pattern
+@item @samp{cmls@var{m}4}
+Perform a vector floating point multiply and subtract of complex numbers
+in operand 0, operand 1 and operand 2.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmls_conj@var{m}4} instruction pattern
+@item @samp{cmls_conj@var{m}4}
+Perform a vector floating point multiply and subtract of complex numbers
+in operand 0, operand 1 and the conjucate of operand 2.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{cmul@var{m}4} instruction pattern
 @item @samp{cmul@var{m}4}
 Perform a vector floating point multiplication of complex numbers in operand 0
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index 51bebf8701af262b22d66d19a29a8dafb74db1f0..cc0135cb2c1c14b593181edeaa5f896fa6c4c659 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -286,6 +286,10 @@  DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
 
 /* Ternary math functions.  */
 DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS, ECF_CONST, cmls, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMS_CONJ, ECF_CONST, cmls_conj, ternary)
 
 /* Unary integer ops.  */
 DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary)
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 9c267d422478d0011f288b1f5f62daabe3989ba7..19db9c00896cd08adfd20a01669990bbbebd79f1 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -294,6 +294,10 @@  OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
 OPTAB_D (cmul_optab, "cmul$a3")
 OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
+OPTAB_D (cmla_optab, "cmla$a4")
+OPTAB_D (cmla_conj_optab, "cmla_conj$a4")
+OPTAB_D (cmls_optab, "cmls$a4")
+OPTAB_D (cmls_conj_optab, "cmls_conj$a4")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index bef7cc73b21c020e4c0128df5d186a034809b103..d9554aaaf2cce14bb5b9c68e6141ea7f555a35de 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -916,6 +916,199 @@  class ComplexMulPattern : public ComplexMLAPattern
     }
 };
 
+class ComplexFMAPattern : public ComplexMLAPattern
+{
+  protected:
+    ComplexFMAPattern (slp_tree node, vec_info *vinfo)
+      : ComplexMLAPattern (node, vinfo)
+    {
+      this->m_arity = 2;
+      this->m_num_args = 3;
+      this->m_vects.create (0);
+      this->m_defs.create (0);
+    }
+
+  public:
+    ~ComplexFMAPattern ()
+    {
+      this->m_vects.release ();
+      this->m_defs.release ();
+    }
+
+    static VectPattern* create (slp_tree node, vec_info *vinfo)
+    {
+       return new ComplexFMAPattern (node, vinfo);
+    }
+
+    const char* get_name ()
+    {
+      return "Complex FM(A|S)";
+    }
+
+    /* Pattern matcher for trying to match complex multiply and accumulate
+       pattern in SLP tree using N statements STMT_0 and STMT_0 as the root
+       statements by finding the statements starting in position IDX in NODE.
+       If the operation matches then IFN is set to the operation it matched and
+       the arguments to the two replacement statements are put in VECTS.
+
+       If no match is found then IFN is set to IFN_LAST and VECTS is unchanged.
+
+       This function matches the patterns shaped as:
+
+	 double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
+	 double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
+
+	 c[i] = c[i] - ax;
+	 c[i+1] = c[i+1] + bx;
+
+       If a match occurred then TRUE is returned, else FALSE.  */
+    bool
+    matches (stmt_vec_info *stmts, int idx)
+    {
+      this->m_last_ifn = IFN_LAST;
+      this->m_vects.truncate (0);
+      this->m_vects.create (6);
+      int base = idx - (this->m_arity - 1);
+      this->m_last_idx = idx;
+      slp_tree node = this->m_node;
+      this->m_stmt_info = stmts[0];
+
+
+      /* Find the two components.  Rotation in the complex plane will modify
+	 the operations:
+
+	 * Rotation  0: + +
+	 * Rotation 90: - +
+	 * Rotation 180: - -
+	 * Rotation 270: + -.  */
+      auto_vec<stmt_vec_info> args0;
+      complex_operation_t op1 = vect_detect_pair_op (base, node, &args0);
+
+      if (op1 == CMPLX_NONE)
+	return false;
+
+      slp_tree sub1, sub2a, sub2b, sub3;
+
+      /* Now operand2+4 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args1;
+      complex_operation_t op2
+	= vect_match_call_complex_mla_1 (node, &sub1, 1, base, 0, &args1);
+
+      if (op2 != MINUS_PLUS && op2 != PLUS_MINUS)
+	return false;
+
+      /* Now operand1+3 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args2;
+      complex_operation_t op3
+	= vect_match_call_complex_mla_1 (sub1, &sub2a, 0, base, 0, &args2);
+
+      if (op3 != MULT_MULT)
+	return false;
+
+      /* Now operand2+4 must lead to another expression.  */
+      auto_vec<stmt_vec_info> args3;
+      complex_operation_t op4
+	= vect_match_call_complex_mla_1 (sub1, &sub2b, 1, base, 0, &args3);
+
+      if (op4 != MULT_MULT)
+	return false;
+
+      /* Now operand2+4 may lead to another expression.  */
+      auto_vec<stmt_vec_info> args4;
+      complex_operation_t op5
+	= vect_match_call_complex_mla_1 (sub2b, &sub3, 1, base, 0, &args4);
+
+      /* Or operand1+3 may lead to another expression.  */
+      auto_vec<stmt_vec_info> args5;
+      complex_operation_t op6
+	= vect_match_call_complex_mla_1 (sub2b, &sub3, 0, base, 0, &args5);
+
+      if (op1 == PLUS_MINUS && op2 == MINUS_PLUS)
+	{
+
+	  /* The FMS conjucate has a different layout so check that.  */
+	  if (op5 == CMPLX_NONE && op6 == CMPLX_NONE)
+	    {
+	       op6 = vect_match_call_complex_mla_1 (sub2a, &sub3, 0, base, 0,
+						    &args5);
+	       if (op6 == CMPLX_NONE)
+	         op6 = vect_match_call_complex_mla_1 (sub2a, &sub3, 1, base, 0,
+						      &args5);
+	    }
+	  if (op5 == CMPLX_NONE && op6 != NEG_NEG)
+	    this->m_last_ifn = IFN_COMPLEX_FMS;
+	  else if (op5 == NEG_NEG || op6 == NEG_NEG)
+	    this->m_last_ifn = IFN_COMPLEX_FMS_CONJ;
+	}
+      else if (op1 == PLUS_PLUS && op2 == MINUS_PLUS)
+	{
+	  if (op5 == CMPLX_NONE && op6 != NEG_NEG)
+	    this->m_last_ifn = IFN_COMPLEX_FMA;
+	  else if (op5 == NEG_NEG || op6 == NEG_NEG)
+	    this->m_last_ifn = IFN_COMPLEX_FMA_CONJ;
+	}
+
+      if (this->m_last_ifn == IFN_LAST)
+	return false;
+
+      if (this->m_last_ifn == IFN_COMPLEX_FMA_CONJ)
+	{
+	  /* Check if the conjucate is on the first or second parameter.  */
+	  if (op5 == NEG_NEG)
+	    {
+	      this->m_vects.quick_push (args0[0]);
+	      this->m_vects.quick_push (args2[2]);
+	      this->m_vects.quick_push (args3[2]);
+	      this->m_vects.quick_push (args0[2]);
+	      this->m_vects.quick_push (args4[0]);
+	      this->m_vects.quick_push (args2[3]);
+	    }
+	  else
+	    {
+	      this->m_vects.quick_push (args0[0]);
+	      this->m_vects.quick_push (args2[3]);
+	      this->m_vects.quick_push (args2[0]);
+	      this->m_vects.quick_push (args0[2]);
+	      this->m_vects.quick_push (args5[0]);
+	      this->m_vects.quick_push (args2[2]);
+	    }
+        }
+      else if (this->m_last_ifn == IFN_COMPLEX_FMS_CONJ)
+	{
+	  /* Check if the conjucate is on the first or second parameter.  */
+	  if (op6 == NEG_NEG)
+	    {
+	      this->m_vects.quick_push (args0[0]);
+	      this->m_vects.quick_push (args3[1]);
+	      this->m_vects.quick_push (args2[3]);
+	      this->m_vects.quick_push (args0[2]);
+	      this->m_vects.quick_push (args5[0]);
+	      this->m_vects.quick_push (args2[1]);
+	    }
+	  else
+	    {
+	      this->m_vects.quick_push (args0[0]);
+	      this->m_vects.quick_push (args2[2]);
+	      this->m_vects.quick_push (args3[2]);
+	      this->m_vects.quick_push (args0[2]);
+	      this->m_vects.quick_push (args2[0]);
+	      this->m_vects.quick_push (args5[0]);
+	    }
+        }
+      else
+        {
+          this->m_vects.quick_push (args0[0]);
+          this->m_vects.quick_push (args2[3]);
+          this->m_vects.quick_push (args3[2]);
+          this->m_vects.quick_push (args0[2]);
+          this->m_vects.quick_push (args3[3]);
+          this->m_vects.quick_push (args2[2]);
+        }
+
+      return store_results ();
+    }
+};
+
 #define SLP_PATTERN(x) &x::create
 VectPatternDecl slp_patterns[]
 {
@@ -923,6 +1116,7 @@  VectPatternDecl slp_patterns[]
      order patterns from the largest to the smallest.  Especially if they
      overlap in what they can detect.  */
 
+  SLP_PATTERN (ComplexFMAPattern),
   SLP_PATTERN (ComplexMulPattern),
   SLP_PATTERN (ComplexAddPattern),
 };