diff mbox series

[COMMITTED,18/22] ada: Implement fast modulo reduction for nonbinary modular multiplication

Message ID 20240621085819.2485987-18-poulhies@adacore.com
State New
Headers show
Series [COMMITTED,01/22] ada: Spurious style error with mutiple square brackets | expand

Commit Message

Marc Poulhiès June 21, 2024, 8:58 a.m. UTC
From: Eric Botcazou <ebotcazou@adacore.com>

This implements modulo reduction for nonbinary modular multiplication with
small moduli by means of the standard division-free algorithm also used in
the optimizer, but with fewer constraints and therefore better results.

For the sake of consistency, it is also used for the 'Mod attribute of the
same modular types and, more generally, for the Mod (and Rem) operators of
unsigned types if the second operand is static and not a power of two.

gcc/ada/

	* gcc-interface/gigi.h (fast_modulo_reduction): Declare.
	* gcc-interface/trans.cc (gnat_to_gnu) <N_Op_Mod>: In the unsigned
	case, call fast_modulo_reduction for {FLOOR,TRUNC}_MOD_EXPR if the
	RHS is a constant and not a power of two, and the precision is not
	larger than the word size.
	* gcc-interface/utils2.cc: Include expmed.h.
	(fast_modulo_reduction): New function.
	(nonbinary_modular_operation): Call fast_modulo_reduction for the
	multiplication if the precision is not larger than the word size.

Tested on x86_64-pc-linux-gnu, committed on master.

---
 gcc/ada/gcc-interface/gigi.h    |   5 ++
 gcc/ada/gcc-interface/trans.cc  |  17 ++++++
 gcc/ada/gcc-interface/utils2.cc | 102 +++++++++++++++++++++++++++++++-
 3 files changed, 121 insertions(+), 3 deletions(-)
diff mbox series

Patch

diff --git a/gcc/ada/gcc-interface/gigi.h b/gcc/ada/gcc-interface/gigi.h
index 6ed74d6879e..40f3f0d3d13 100644
--- a/gcc/ada/gcc-interface/gigi.h
+++ b/gcc/ada/gcc-interface/gigi.h
@@ -1040,6 +1040,11 @@  extern bool simple_constant_p (Entity_Id gnat_entity);
 /* Return the size of TYPE, which must be a positive power of 2.  */
 extern unsigned int resolve_atomic_size (tree type);
 
+/* Try to compute the reduction of OP modulo MODULUS in PRECISION bits with a
+   division-free algorithm.  Return NULL_TREE if this is not easily doable.  */
+extern tree fast_modulo_reduction (tree op, tree modulus,
+				   unsigned int precision);
+
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/gcc/ada/gcc-interface/trans.cc b/gcc/ada/gcc-interface/trans.cc
index e68fb3fd776..7c5282602b2 100644
--- a/gcc/ada/gcc-interface/trans.cc
+++ b/gcc/ada/gcc-interface/trans.cc
@@ -7317,6 +7317,23 @@  gnat_to_gnu (Node_Id gnat_node)
 	  gnu_result
 	    = build_binary_op_trapv (code, gnu_type, gnu_lhs, gnu_rhs,
 				     gnat_node);
+
+	  /* For an unsigned modulo operation with nonbinary constant modulus,
+	     we first try to do a reduction by means of a (multiplier, shifter)
+	     pair in the needed precision up to the word size.  But not when
+	     optimizing for size, because it will be longer than a div+mul+sub
+	     sequence.  */
+        else if (!optimize_size
+		 && (code == FLOOR_MOD_EXPR || code == TRUNC_MOD_EXPR)
+		 && TYPE_UNSIGNED (gnu_type)
+		 && TYPE_PRECISION (gnu_type) <= BITS_PER_WORD
+		 && TREE_CODE (gnu_rhs) == INTEGER_CST
+		 && !integer_pow2p (gnu_rhs)
+		 && (gnu_expr
+		     = fast_modulo_reduction (gnu_lhs, gnu_rhs,
+					      TYPE_PRECISION (gnu_type))))
+	  gnu_result = gnu_expr;
+
 	else
 	  {
 	    /* Some operations, e.g. comparisons of arrays, generate complex
diff --git a/gcc/ada/gcc-interface/utils2.cc b/gcc/ada/gcc-interface/utils2.cc
index 70271cf2836..a37eccc4cfb 100644
--- a/gcc/ada/gcc-interface/utils2.cc
+++ b/gcc/ada/gcc-interface/utils2.cc
@@ -33,6 +33,7 @@ 
 #include "tree.h"
 #include "inchash.h"
 #include "builtins.h"
+#include "expmed.h"
 #include "fold-const.h"
 #include "stor-layout.h"
 #include "stringpool.h"
@@ -534,6 +535,91 @@  compare_fat_pointers (location_t loc, tree result_type, tree p1, tree p2)
 					   p1_array_is_null, same_bounds));
 }
 
+/* Try to compute the reduction of OP modulo MODULUS in PRECISION bits with a
+   division-free algorithm.  Return NULL_TREE if this is not easily doable.  */
+
+tree
+fast_modulo_reduction (tree op, tree modulus, unsigned int precision)
+{
+  const tree type = TREE_TYPE (op);
+  const unsigned int type_precision = TYPE_PRECISION (type);
+
+  /* The implementation is host-dependent for the time being.  */
+  if (type_precision <= HOST_BITS_PER_WIDE_INT)
+    {
+      const unsigned HOST_WIDE_INT d = tree_to_uhwi (modulus);
+      unsigned HOST_WIDE_INT ml, mh;
+      int pre_shift, post_shift;
+      tree t;
+
+      /* The trick is to replace the division by d with a multiply-and-shift
+	 sequence parameterized by a (multiplier, shifter) pair computed from
+	 d, the precision of the type and the needed precision:
+
+	   op / d = (op * multiplier) >> shifter
+
+         But choose_multiplier provides a slightly different interface:
+
+           op / d = (op h* multiplier) >> reduced_shifter
+
+         that makes things easier by using a high-part multiplication.  */
+      mh = choose_multiplier (d, type_precision, precision, &ml, &post_shift);
+
+      /* If the suggested multiplier is more than TYPE_PRECISION bits, we can
+	 do better for even divisors, using an initial right shift.  */
+      if (mh != 0 && (d & 1) == 0)
+	{
+	  pre_shift = ctz_or_zero (d);
+	  mh = choose_multiplier (d >> pre_shift, type_precision,
+				  precision - pre_shift, &ml, &post_shift);
+	}
+      else
+	pre_shift = 0;
+
+      /* If the suggested multiplier is still more than TYPE_PRECISION bits,
+	 try again with a larger type up to the word size.  */
+      if (mh != 0)
+	{
+	  if (type_precision < BITS_PER_WORD)
+	    {
+	      const scalar_int_mode m
+		= smallest_int_mode_for_size (type_precision + 1);
+	      tree new_type = gnat_type_for_mode (m, 1);
+	      op = fold_convert (new_type, op);
+	      modulus = fold_convert (new_type, modulus);
+	      t = fast_modulo_reduction (op, modulus, precision);
+	      if (t)
+		return fold_convert (type, t);
+	    }
+
+	  return NULL_TREE;
+	}
+
+      /* This computes op - (op / modulus) * modulus with PRECISION bits.  */
+      op = gnat_protect_expr (op);
+
+      /* t = op >> pre_shift
+	 t = t h* ml
+	 t = t >> post_shift
+	 t = t * modulus  */
+      if (pre_shift)
+	t = fold_build2 (RSHIFT_EXPR, type, op,
+			 build_int_cst (type, pre_shift));
+      else
+	t = op;
+      t = fold_build2 (MULT_HIGHPART_EXPR, type, t, build_int_cst (type, ml));
+      if (post_shift)
+	t = fold_build2 (RSHIFT_EXPR, type, t,
+			 build_int_cst (type, post_shift));
+      t = fold_build2 (MULT_EXPR, type, t, modulus);
+
+      return fold_build2 (MINUS_EXPR, type, op, t);
+    }
+
+  else
+    return NULL_TREE;
+}
+
 /* Compute the result of applying OP_CODE to LHS and RHS, where both are of
    TYPE.  We know that TYPE is a modular type with a nonbinary modulus.  */
 
@@ -543,7 +629,7 @@  nonbinary_modular_operation (enum tree_code op_code, tree type, tree lhs,
 {
   tree modulus = TYPE_MODULUS (type);
   unsigned precision = tree_floor_log2 (modulus) + 1;
-  tree op_type, result;
+  tree op_type, result, fmr;
 
   /* For the logical operations, we only need PRECISION bits.  For addition and
      subtraction, we need one more, and for multiplication twice as many.  */
@@ -576,9 +662,19 @@  nonbinary_modular_operation (enum tree_code op_code, tree type, tree lhs,
   if (op_code == MINUS_EXPR)
     result = fold_build2 (PLUS_EXPR, op_type, result, modulus);
 
-  /* For a multiplication, we have no choice but to use a modulo operation.  */
+  /* For a multiplication, we first try to do a modulo reduction by means of a
+     (multiplier, shifter) pair in the needed precision up to the word size, or
+     else we fall back to a standard modulo operation.  But not when optimizing
+     for size, because it will be longer than a div+mul+sub sequence.  */
   if (op_code == MULT_EXPR)
-    result = fold_build2 (TRUNC_MOD_EXPR, op_type, result, modulus);
+    {
+      if (!optimize_size
+	  && precision <= BITS_PER_WORD
+	  && (fmr = fast_modulo_reduction (result, modulus, precision)))
+	result = fmr;
+      else
+	result = fold_build2 (TRUNC_MOD_EXPR, op_type, result, modulus);
+    }
 
   /* For the other operations, subtract the modulus if we are >= it.  */
   else