diff mbox series

tree-optimization/101267 - fix SLP vect with masked operations

Message ID 62q8s5q8-s764-852q-6qo9-598r17s49s9@fhfr.qr
State New
Headers show
Series tree-optimization/101267 - fix SLP vect with masked operations | expand

Commit Message

Richard Biener June 30, 2021, 10:56 a.m. UTC
This fixes the missing handling of SLP operations with external/constant
masks; in the testcase these are in particular masked loads.  The
patch adjusts the vect_check_scalar_mask API to match the
SLP-compatible vect_is_simple_use API and also accounts
for the special handling of masked loads in SLP discovery.

The issue is likely latent.

Lightly tested as fixing the 521.wrf_r build and being clean
on vect.exp and i386.exp on x86_64.

Full bootstrap and regtest running on x86_64-unknown-linux-gnu,
I'll push it unless I hear otherwise.

I'm quite sure that SLP masked operations test coverage is weak, though.
Maybe somebody can throw it at SVE[2], which should expose more
masking (but possibly not SLP - I don't know about the state of
SLP and masking with respect to SVE).

Thanks,
Richard.

2021-06-30  Richard Biener  <rguenther@suse.de>

	PR tree-optimization/101267
	* tree-vect-stmts.c (vect_check_scalar_mask): Adjust
	API and use SLP compatible interface of vect_is_simple_use.
	Reject not vectorized SLP defs for callers that do not support
	that.
	(vect_check_store_rhs): Handle masked stores and pass down
	the appropriate operator index.
	(vectorizable_call): Adjust.
	(vectorizable_store): Likewise.
	(vectorizable_load): Likewise.  Handle SLP peculiarity of
	masked loads.
	(vect_is_simple_use): Remove special-casing of masked stores.

	* gfortran.dg/pr101267.f90: New testcase.
---
 gcc/testsuite/gfortran.dg/pr101267.f90 | 23 +++++++
 gcc/tree-vect-stmts.c                  | 92 +++++++++++++++-----------
 2 files changed, 77 insertions(+), 38 deletions(-)
 create mode 100644 gcc/testsuite/gfortran.dg/pr101267.f90
diff mbox series

Patch

diff --git a/gcc/testsuite/gfortran.dg/pr101267.f90 b/gcc/testsuite/gfortran.dg/pr101267.f90
new file mode 100644
index 00000000000..12723cf9c22
--- /dev/null
+++ b/gcc/testsuite/gfortran.dg/pr101267.f90
@@ -0,0 +1,23 @@ 
+! { dg-do compile }
+! { dg-options "-Ofast" }
+! { dg-additional-options "-march=znver2" { target x86_64-*-* i?86-*-* } }
+   SUBROUTINE sfddagd( regime, znt,ite ,jte )
+   REAL, DIMENSION( ime, IN) :: regime, znt
+   REAL, DIMENSION( ite, jte) :: wndcor_u 
+   LOGICAL wrf_dm_on_monitor
+   IF( int4 == 1 ) THEN
+     DO j=jts,jtf
+      DO i=itsu,itf
+       reg =   regime(i,  j) 
+       IF( reg > 10.0 ) THEN
+         znt0 = znt(i-1,  j) + znt(i,  j) 
+         IF( znt0 <= 0.2) THEN
+           wndcor_u(i,j) = 0.2
+         ENDIF
+       ENDIF
+      ENDDO
+     ENDDO
+     IF ( wrf_dm_on_monitor()) THEN
+     ENDIF
+   ENDIF
+   END
diff --git a/gcc/tree-vect-stmts.c b/gcc/tree-vect-stmts.c
index 4ee11b2041a..e590f34d75d 100644
--- a/gcc/tree-vect-stmts.c
+++ b/gcc/tree-vect-stmts.c
@@ -2439,17 +2439,31 @@  get_load_store_type (vec_info  *vinfo, stmt_vec_info stmt_info,
   return true;
 }
 
-/* Return true if boolean argument MASK is suitable for vectorizing
-   conditional operation STMT_INFO.  When returning true, store the type
-   of the definition in *MASK_DT_OUT and the type of the vectorized mask
-   in *MASK_VECTYPE_OUT.  */
+/* Return true if boolean argument at MASK_INDEX is suitable for vectorizing
+   conditional operation STMT_INFO.  When returning true, store the mask
+   in *MASK, the type of its definition in *MASK_DT_OUT, the type of the
+   vectorized mask in *MASK_VECTYPE_OUT and the SLP node corresponding
+   to the mask in *MASK_NODE if MASK_NODE is not NULL.  */
 
 static bool
-vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
-			vect_def_type *mask_dt_out,
-			tree *mask_vectype_out)
+vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info,
+			slp_tree slp_node, unsigned mask_index,
+			tree *mask, slp_tree *mask_node,
+			vect_def_type *mask_dt_out, tree *mask_vectype_out)
 {
-  if (!VECT_SCALAR_BOOLEAN_TYPE_P (TREE_TYPE (mask)))
+  enum vect_def_type mask_dt;
+  tree mask_vectype;
+  slp_tree mask_node_1;
+  if (!vect_is_simple_use (vinfo, stmt_info, slp_node, mask_index,
+			   mask, &mask_node_1, &mask_dt, &mask_vectype))
+    {
+      if (dump_enabled_p ())
+	dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
+			 "mask use not simple.\n");
+      return false;
+    }
+
+  if (!VECT_SCALAR_BOOLEAN_TYPE_P (TREE_TYPE (*mask)))
     {
       if (dump_enabled_p ())
 	dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
@@ -2457,7 +2471,7 @@  vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
       return false;
     }
 
-  if (TREE_CODE (mask) != SSA_NAME)
+  if (TREE_CODE (*mask) != SSA_NAME)
     {
       if (dump_enabled_p ())
 	dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
@@ -2465,13 +2479,15 @@  vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
       return false;
     }
 
-  enum vect_def_type mask_dt;
-  tree mask_vectype;
-  if (!vect_is_simple_use (mask, vinfo, &mask_dt, &mask_vectype))
+  /* If the caller is not prepared for adjusting an external/constant
+     SLP mask vector type fail.  */
+  if (slp_node
+      && !mask_node
+      && SLP_TREE_DEF_TYPE (mask_node_1) != vect_internal_def)
     {
       if (dump_enabled_p ())
 	dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
-			 "mask use not simple.\n");
+			 "SLP mask argument is not vectorized.\n");
       return false;
     }
 
@@ -2501,6 +2517,8 @@  vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
 
   *mask_dt_out = mask_dt;
   *mask_vectype_out = mask_vectype;
+  if (mask_node)
+    *mask_node = mask_node_1;
   return true;
 }
 
@@ -2525,10 +2543,18 @@  vect_check_store_rhs (vec_info *vinfo, stmt_vec_info stmt_info,
       return false;
     }
 
+  unsigned op_no = 0;
+  if (gcall *call = dyn_cast <gcall *> (stmt_info->stmt))
+    {
+      if (gimple_call_internal_p (call)
+	  && internal_store_fn_p (gimple_call_internal_fn (call)))
+	op_no = internal_fn_stored_value_index (gimple_call_internal_fn (call));
+    }
+
   enum vect_def_type rhs_dt;
   tree rhs_vectype;
   slp_tree slp_op;
-  if (!vect_is_simple_use (vinfo, stmt_info, slp_node, 0,
+  if (!vect_is_simple_use (vinfo, stmt_info, slp_node, op_no,
 			   &rhs, &slp_op, &rhs_dt, &rhs_vectype))
     {
       if (dump_enabled_p ())
@@ -3163,9 +3189,8 @@  vectorizable_call (vec_info *vinfo,
     {
       if ((int) i == mask_opno)
 	{
-	  op = gimple_call_arg (stmt, i);
-	  if (!vect_check_scalar_mask (vinfo,
-				       stmt_info, op, &dt[i], &vectypes[i]))
+	  if (!vect_check_scalar_mask (vinfo, stmt_info, slp_node, mask_opno,
+				       &op, &slp_op[i], &dt[i], &vectypes[i]))
 	    return false;
 	  continue;
 	}
@@ -7213,13 +7238,10 @@  vectorizable_store (vec_info *vinfo,
 	}
 
       int mask_index = internal_fn_mask_index (ifn);
-      if (mask_index >= 0)
-	{
-	  mask = gimple_call_arg (call, mask_index);
-	  if (!vect_check_scalar_mask (vinfo, stmt_info, mask, &mask_dt,
-				       &mask_vectype))
-	    return false;
-	}
+      if (mask_index >= 0
+	  && !vect_check_scalar_mask (vinfo, stmt_info, slp_node, mask_index,
+				      &mask, NULL, &mask_dt, &mask_vectype))
+	return false;
     }
 
   op = vect_get_store_rhs (stmt_info);
@@ -8494,13 +8516,13 @@  vectorizable_load (vec_info *vinfo,
 	return false;
 
       int mask_index = internal_fn_mask_index (ifn);
-      if (mask_index >= 0)
-	{
-	  mask = gimple_call_arg (call, mask_index);
-	  if (!vect_check_scalar_mask (vinfo, stmt_info, mask, &mask_dt,
-				       &mask_vectype))
-	    return false;
-	}
+      if (mask_index >= 0
+	  && !vect_check_scalar_mask (vinfo, stmt_info, slp_node,
+				      /* ??? For SLP we only have operands for
+					 the mask operand.  */
+				      slp_node ? 0 : mask_index,
+				      &mask, NULL, &mask_dt, &mask_vectype))
+	return false;
     }
 
   tree vectype = STMT_VINFO_VECTYPE (stmt_info);
@@ -11484,13 +11506,7 @@  vect_is_simple_use (vec_info *vinfo, stmt_vec_info stmt, slp_tree slp_node,
 	    *op = gimple_op (ass, operand + 1);
 	}
       else if (gcall *call = dyn_cast <gcall *> (stmt->stmt))
-	{
-	  if (gimple_call_internal_p (call)
-	      && internal_store_fn_p (gimple_call_internal_fn (call)))
-	    operand = internal_fn_stored_value_index (gimple_call_internal_fn
-									(call));
-	  *op = gimple_call_arg (call, operand);
-	}
+	*op = gimple_call_arg (call, operand);
       else
 	gcc_unreachable ();
       return vect_is_simple_use (*op, vinfo, dt, vectype, def_stmt_info_out);