diff mbox series

Handle IFN_COND_MUL in tree-ssa-math-opts.c

Message ID mpt7e7zj0xs.fsf@arm.com
State New
Headers show
Series Handle IFN_COND_MUL in tree-ssa-math-opts.c | expand

Commit Message

Richard Sandiford July 30, 2019, 10:01 a.m. UTC
This patch extends the FMA handling in tree-ssa-math-opts.c so
that it can cope with conditional multiplications as well as
unconditional multiplications.  The addition or subtraction must then
have the same condition as the multiplication (at least for now).

E.g. we can currently fold:

  (IFN_COND_ADD cond (mul x y) z fallback)
    -> (IFN_COND_FMA cond x y z fallback)

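This patch also allows:

  (IFN_COND_ADD cond (IFN_COND_MUL cond x y <whatever>) z fallback)
    -> (IFN_COND_FMA cond x y z fallback)

The else-value of the IFN_COND_MUL (<whatever>) does not matter here:
because both operations use the same condition, the product is only
consumed when cond is true, and when cond is false the result comes
from the IFN_COND_ADD's own fallback.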

Tested on aarch64-linux-gnu, aarch64_be-elf and x86_64-linux-gnu.
OK to install?

Richard


2019-07-30  Richard Sandiford  <richard.sandiford@arm.com>

gcc/
	* tree-ssa-math-opts.c (convert_mult_to_fma): Add a mul_cond
	parameter.  When nonnull, make sure that the addition or subtraction
	has the same condition.
	(math_opts_dom_walker::after_dom_children): Try convert_mult_to_fma
	for CFN_COND_MUL too.

gcc/testsuite/
	* gcc.dg/vect/vect-cond-arith-7.c: New test.

Comments

Richard Biener July 30, 2019, 10:27 a.m. UTC | #1
On Tue, Jul 30, 2019 at 12:01 PM Richard Sandiford
<richard.sandiford@arm.com> wrote:
>
> This patch extends the FMA handling in tree-ssa-math-opts.c so
> that it can cope with conditional multiplications as well as
> unconditional multiplications.  The addition or subtraction must then
> have the same condition as the multiplication (at least for now).
>
> E.g. we can currently fold:
>
>   (IFN_COND_ADD cond (mul x y) z fallback)
>     -> (IFN_COND_FMA cond x y z fallback)
>
> This patch also allows:
>
>   (IFN_COND_ADD cond (IFN_COND_MUL cond x y <whatever>) z fallback)
>     -> (IFN_COND_FMA cond x y z fallback)
>
> Tested on aarch64-linux-gnu, aarch64_be-elf and x86_64-linux-gnu.
> OK to install?

OK.

> Richard
>
>
> 2019-07-30  Richard Sandiford  <richard.sandiford@arm.com>
>
> gcc/
>         * tree-ssa-math-opts.c (convert_mult_to_fma): Add a mul_cond
>         parameter.  When nonnull, make sure that the addition or subtraction
>         has the same condition.
>         (math_opts_dom_walker::after_dom_children): Try convert_mult_to_fma
>         for CFN_COND_MUL too.
>
> gcc/testsuite/
>         * gcc.dg/vect/vect-cond-arith-7.c: New test.
>
> Index: gcc/tree-ssa-math-opts.c
> ===================================================================
> --- gcc/tree-ssa-math-opts.c    2019-07-30 10:51:22.000000000 +0100
> +++ gcc/tree-ssa-math-opts.c    2019-07-30 10:51:51.827405171 +0100
> @@ -3044,6 +3044,8 @@ last_fma_candidate_feeds_initial_phi (fm
>  /* Combine the multiplication at MUL_STMT with operands MULOP1 and MULOP2
>     with uses in additions and subtractions to form fused multiply-add
>     operations.  Returns true if successful and MUL_STMT should be removed.
> +   If MUL_COND is nonnull, the multiplication in MUL_STMT is conditional
> +   on MUL_COND, otherwise it is unconditional.
>
>     If STATE indicates that we are deferring FMA transformation, that means
>     that we do not produce FMAs for basic blocks which look like:
> @@ -3060,7 +3062,7 @@ last_fma_candidate_feeds_initial_phi (fm
>
>  static bool
>  convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
> -                    fma_deferring_state *state)
> +                    fma_deferring_state *state, tree mul_cond = NULL_TREE)
>  {
>    tree mul_result = gimple_get_lhs (mul_stmt);
>    tree type = TREE_TYPE (mul_result);
> @@ -3174,6 +3176,9 @@ convert_mult_to_fma (gimple *mul_stmt, t
>           return false;
>         }
>
> +      if (mul_cond && cond != mul_cond)
> +       return false;
> +
>        if (cond)
>         {
>           if (cond == result || else_value == result)
> @@ -3785,38 +3790,48 @@ math_opts_dom_walker::after_dom_children
>         }
>        else if (is_gimple_call (stmt))
>         {
> -         tree fndecl = gimple_call_fndecl (stmt);
> -         if (fndecl && gimple_call_builtin_p (stmt, BUILT_IN_NORMAL))
> +         switch (gimple_call_combined_fn (stmt))
>             {
> -             switch (DECL_FUNCTION_CODE (fndecl))
> +           CASE_CFN_POW:
> +             if (gimple_call_lhs (stmt)
> +                 && TREE_CODE (gimple_call_arg (stmt, 1)) == REAL_CST
> +                 && real_equal (&TREE_REAL_CST (gimple_call_arg (stmt, 1)),
> +                                &dconst2)
> +                 && convert_mult_to_fma (stmt,
> +                                         gimple_call_arg (stmt, 0),
> +                                         gimple_call_arg (stmt, 0),
> +                                         &fma_state))
>                 {
> -               case BUILT_IN_POWF:
> -               case BUILT_IN_POW:
> -               case BUILT_IN_POWL:
> -                 if (gimple_call_lhs (stmt)
> -                     && TREE_CODE (gimple_call_arg (stmt, 1)) == REAL_CST
> -                     && real_equal
> -                     (&TREE_REAL_CST (gimple_call_arg (stmt, 1)),
> -                      &dconst2)
> -                     && convert_mult_to_fma (stmt,
> -                                             gimple_call_arg (stmt, 0),
> -                                             gimple_call_arg (stmt, 0),
> -                                             &fma_state))
> -                   {
> -                     unlink_stmt_vdef (stmt);
> -                     if (gsi_remove (&gsi, true)
> -                         && gimple_purge_dead_eh_edges (bb))
> -                       *m_cfg_changed_p = true;
> -                     release_defs (stmt);
> -                     continue;
> -                   }
> -                 break;
> +                 unlink_stmt_vdef (stmt);
> +                 if (gsi_remove (&gsi, true)
> +                     && gimple_purge_dead_eh_edges (bb))
> +                   *m_cfg_changed_p = true;
> +                 release_defs (stmt);
> +                 continue;
> +               }
> +             break;
>
> -               default:;
> +           case CFN_COND_MUL:
> +             if (convert_mult_to_fma (stmt,
> +                                      gimple_call_arg (stmt, 1),
> +                                      gimple_call_arg (stmt, 2),
> +                                      &fma_state,
> +                                      gimple_call_arg (stmt, 0)))
> +
> +               {
> +                 gsi_remove (&gsi, true);
> +                 release_defs (stmt);
> +                 continue;
>                 }
> +             break;
> +
> +           case CFN_LAST:
> +             cancel_fma_deferring (&fma_state);
> +             break;
> +
> +           default:
> +             break;
>             }
> -         else
> -           cancel_fma_deferring (&fma_state);
>         }
>        gsi_next (&gsi);
>      }
> Index: gcc/testsuite/gcc.dg/vect/vect-cond-arith-7.c
> ===================================================================
> --- /dev/null   2019-07-30 08:53:31.317691683 +0100
> +++ gcc/testsuite/gcc.dg/vect/vect-cond-arith-7.c       2019-07-30 10:51:51.823405201 +0100
> @@ -0,0 +1,60 @@
> +/* { dg-require-effective-target scalar_all_fma } */
> +/* { dg-additional-options "-fdump-tree-optimized -ffp-contract=fast" } */
> +
> +#include "tree-vect.h"
> +
> +#define N (VECTOR_BITS * 11 / 64 + 3)
> +
> +#define DEF(INV)                                       \
> +  void __attribute__ ((noipa))                         \
> +  f_##INV (double *restrict a, double *restrict b,     \
> +          double *restrict c, double *restrict d)      \
> +  {                                                    \
> +    for (int i = 0; i < N; ++i)                                \
> +      {                                                        \
> +       double mb = (INV & 1 ? -b[i] : b[i]);           \
> +       double mc = c[i];                               \
> +       double md = (INV & 2 ? -d[i] : d[i]);           \
> +       a[i] = b[i] < 10 ? mb * mc + md : 10.0;         \
> +      }                                                        \
> +  }
> +
> +#define TEST(INV)                                      \
> +  {                                                    \
> +    f_##INV (a, b, c, d);                              \
> +    for (int i = 0; i < N; ++i)                                \
> +      {                                                        \
> +       double mb = (INV & 1 ? -b[i] : b[i]);           \
> +       double mc = c[i];                               \
> +       double md = (INV & 2 ? -d[i] : d[i]);           \
> +       double fma = __builtin_fma (mb, mc, md);        \
> +       if (a[i] != (i % 17 < 10 ? fma : 10.0))         \
> +         __builtin_abort ();                           \
> +       asm volatile ("" ::: "memory");                 \
> +      }                                                        \
> +  }
> +
> +#define FOR_EACH_INV(T) \
> +  T (0) T (1) T (2) T (3)
> +
> +FOR_EACH_INV (DEF)
> +
> +int
> +main (void)
> +{
> +  double a[N], b[N], c[N], d[N];
> +  for (int i = 0; i < N; ++i)
> +    {
> +      b[i] = i % 17;
> +      c[i] = i % 9 + 11;
> +      d[i] = i % 13 + 14;
> +      asm volatile ("" ::: "memory");
> +    }
> +  FOR_EACH_INV (TEST)
> +  return 0;
> +}
> +
> +/* { dg-final { scan-tree-dump-times { = \.COND_FMA } 1 "optimized" { target vect_double_cond_arith } } } */
> +/* { dg-final { scan-tree-dump-times { = \.COND_FMS } 1 "optimized" { target vect_double_cond_arith } } } */
> +/* { dg-final { scan-tree-dump-times { = \.COND_FNMA } 1 "optimized" { target vect_double_cond_arith } } } */
> +/* { dg-final { scan-tree-dump-times { = \.COND_FNMS } 1 "optimized" { target vect_double_cond_arith } } } */
diff mbox series

Patch

Index: gcc/tree-ssa-math-opts.c
===================================================================
--- gcc/tree-ssa-math-opts.c	2019-07-30 10:51:22.000000000 +0100
+++ gcc/tree-ssa-math-opts.c	2019-07-30 10:51:51.827405171 +0100
@@ -3044,6 +3044,8 @@  last_fma_candidate_feeds_initial_phi (fm
 /* Combine the multiplication at MUL_STMT with operands MULOP1 and MULOP2
    with uses in additions and subtractions to form fused multiply-add
    operations.  Returns true if successful and MUL_STMT should be removed.
+   If MUL_COND is nonnull, the multiplication in MUL_STMT is conditional
+   on MUL_COND, otherwise it is unconditional.
 
    If STATE indicates that we are deferring FMA transformation, that means
    that we do not produce FMAs for basic blocks which look like:
@@ -3060,7 +3062,7 @@  last_fma_candidate_feeds_initial_phi (fm
 
 static bool
 convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
-		     fma_deferring_state *state)
+		     fma_deferring_state *state, tree mul_cond = NULL_TREE)
 {
   tree mul_result = gimple_get_lhs (mul_stmt);
   tree type = TREE_TYPE (mul_result);
@@ -3174,6 +3176,9 @@  convert_mult_to_fma (gimple *mul_stmt, t
 	  return false;
 	}
 
+      if (mul_cond && cond != mul_cond)
+	return false;
+
       if (cond)
 	{
 	  if (cond == result || else_value == result)
@@ -3785,38 +3790,48 @@  math_opts_dom_walker::after_dom_children
 	}
       else if (is_gimple_call (stmt))
 	{
-	  tree fndecl = gimple_call_fndecl (stmt);
-	  if (fndecl && gimple_call_builtin_p (stmt, BUILT_IN_NORMAL))
+	  switch (gimple_call_combined_fn (stmt))
 	    {
-	      switch (DECL_FUNCTION_CODE (fndecl))
+	    CASE_CFN_POW:
+	      if (gimple_call_lhs (stmt)
+		  && TREE_CODE (gimple_call_arg (stmt, 1)) == REAL_CST
+		  && real_equal (&TREE_REAL_CST (gimple_call_arg (stmt, 1)),
+				 &dconst2)
+		  && convert_mult_to_fma (stmt,
+					  gimple_call_arg (stmt, 0),
+					  gimple_call_arg (stmt, 0),
+					  &fma_state))
 		{
-		case BUILT_IN_POWF:
-		case BUILT_IN_POW:
-		case BUILT_IN_POWL:
-		  if (gimple_call_lhs (stmt)
-		      && TREE_CODE (gimple_call_arg (stmt, 1)) == REAL_CST
-		      && real_equal
-		      (&TREE_REAL_CST (gimple_call_arg (stmt, 1)),
-		       &dconst2)
-		      && convert_mult_to_fma (stmt,
-					      gimple_call_arg (stmt, 0),
-					      gimple_call_arg (stmt, 0),
-					      &fma_state))
-		    {
-		      unlink_stmt_vdef (stmt);
-		      if (gsi_remove (&gsi, true)
-			  && gimple_purge_dead_eh_edges (bb))
-			*m_cfg_changed_p = true;
-		      release_defs (stmt);
-		      continue;
-		    }
-		  break;
+		  unlink_stmt_vdef (stmt);
+		  if (gsi_remove (&gsi, true)
+		      && gimple_purge_dead_eh_edges (bb))
+		    *m_cfg_changed_p = true;
+		  release_defs (stmt);
+		  continue;
+		}
+	      break;
 
-		default:;
+	    case CFN_COND_MUL:
+	      if (convert_mult_to_fma (stmt,
+				       gimple_call_arg (stmt, 1),
+				       gimple_call_arg (stmt, 2),
+				       &fma_state,
+				       gimple_call_arg (stmt, 0)))
+
+		{
+		  gsi_remove (&gsi, true);
+		  release_defs (stmt);
+		  continue;
 		}
+	      break;
+
+	    case CFN_LAST:
+	      cancel_fma_deferring (&fma_state);
+	      break;
+
+	    default:
+	      break;
 	    }
-	  else
-	    cancel_fma_deferring (&fma_state);
 	}
       gsi_next (&gsi);
     }
Index: gcc/testsuite/gcc.dg/vect/vect-cond-arith-7.c
===================================================================
--- /dev/null	2019-07-30 08:53:31.317691683 +0100
+++ gcc/testsuite/gcc.dg/vect/vect-cond-arith-7.c	2019-07-30 10:51:51.823405201 +0100
@@ -0,0 +1,60 @@ 
+/* { dg-require-effective-target scalar_all_fma } */
+/* { dg-additional-options "-fdump-tree-optimized -ffp-contract=fast" } */
+
+#include "tree-vect.h"
+
+#define N (VECTOR_BITS * 11 / 64 + 3)
+
+#define DEF(INV)					\
+  void __attribute__ ((noipa))				\
+  f_##INV (double *restrict a, double *restrict b,	\
+	   double *restrict c, double *restrict d)	\
+  {							\
+    for (int i = 0; i < N; ++i)				\
+      {							\
+	double mb = (INV & 1 ? -b[i] : b[i]);		\
+	double mc = c[i];				\
+	double md = (INV & 2 ? -d[i] : d[i]);		\
+	a[i] = b[i] < 10 ? mb * mc + md : 10.0;		\
+      }							\
+  }
+
+#define TEST(INV)					\
+  {							\
+    f_##INV (a, b, c, d);				\
+    for (int i = 0; i < N; ++i)				\
+      {							\
+	double mb = (INV & 1 ? -b[i] : b[i]);		\
+	double mc = c[i];				\
+	double md = (INV & 2 ? -d[i] : d[i]);		\
+	double fma = __builtin_fma (mb, mc, md);	\
+	if (a[i] != (i % 17 < 10 ? fma : 10.0))		\
+	  __builtin_abort ();				\
+	asm volatile ("" ::: "memory");			\
+      }							\
+  }
+
+#define FOR_EACH_INV(T) \
+  T (0) T (1) T (2) T (3)
+
+FOR_EACH_INV (DEF)
+
+int
+main (void)
+{
+  double a[N], b[N], c[N], d[N];
+  for (int i = 0; i < N; ++i)
+    {
+      b[i] = i % 17;
+      c[i] = i % 9 + 11;
+      d[i] = i % 13 + 14;
+      asm volatile ("" ::: "memory");
+    }
+  FOR_EACH_INV (TEST)
+  return 0;
+}
+
+/* { dg-final { scan-tree-dump-times { = \.COND_FMA } 1 "optimized" { target vect_double_cond_arith } } } */
+/* { dg-final { scan-tree-dump-times { = \.COND_FMS } 1 "optimized" { target vect_double_cond_arith } } } */
+/* { dg-final { scan-tree-dump-times { = \.COND_FNMA } 1 "optimized" { target vect_double_cond_arith } } } */
+/* { dg-final { scan-tree-dump-times { = \.COND_FNMS } 1 "optimized" { target vect_double_cond_arith } } } */