Minor tweaks to code computing modular multiplicative inverse

Message ID 6040019.lOV4Wx5bFT@fomalhaut
State New
Series Minor tweaks to code computing modular multiplicative inverse

Commit Message

Eric Botcazou April 29, 2024, 7:14 a.m. UTC
Hi,

this removes the last parameter of choose_multiplier, which is unused, adds
another assertion, and adds more details to the description and various
comments.  Likewise for the closely related invert_mod2n, except that its
last parameter is kept.

Tested on x86-64/Linux, OK for the mainline?


2024-04-29  Eric Botcazou  <ebotcazou@adacore.com>

	* expmed.h (choose_multiplier): Tweak description and remove last
	parameter.
	* expmed.cc (choose_multiplier): Likewise.  Add assertion for the
	third parameter and add details to various comments.
	(invert_mod2n): Tweak description and add assertion for the first
	parameter.
	(expand_divmod): Adjust calls to choose_multiplier.
	* tree-vect-generic.cc (expand_vector_divmod): Likewise.
	* tree-vect-patterns.cc (vect_recog_divmod_pattern): Likewise.

Comments

Jeff Law April 29, 2024, 3:15 p.m. UTC | #1
On 4/29/24 1:14 AM, Eric Botcazou wrote:
> Hi,
> 
> this removes the last parameter of choose_multiplier, which is unused, adds
> another assertion, and adds more details to the description and various
> comments.  Likewise for the closely related invert_mod2n, except that its
> last parameter is kept.
> 
> Tested on x86-64/Linux, OK for the mainline?
> 
> 
> 2024-04-29  Eric Botcazou  <ebotcazou@adacore.com>
> 
> 	* expmed.h (choose_multiplier): Tweak description and remove last
> 	parameter.
> 	* expmed.cc (choose_multiplier): Likewise.  Add assertion for the
> 	third parameter and add details to various comments.
> 	(invert_mod2n): Tweak description and add assertion for the first
> 	parameter.
> 	(expand_divmod): Adjust calls to choose_multiplier.
> 	* tree-vect-generic.cc (expand_vector_divmod): Likewise.
> 	* tree-vect-patterns.cc (vect_recog_divmod_pattern): Likewise.
OK.  Consider waiting to commit though as we want to make it easy to 
cherry pick patches over to the release branch if needed.

Jeff
Eric Botcazou April 29, 2024, 3:52 p.m. UTC | #2
> OK.  Consider waiting to commit though as we want to make it easy to
> cherry pick patches over to the release branch if needed.

Sure.  There are a couple more changes on top of it, but all can wait a bit.
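
For context, choose_multiplier implements the classic Granlund-Montgomery
scheme of replacing a division by a constant D with a multiplication by a
precomputed N + 1 bit approximation of 2**K / D followed by right shifts.
The sketch below mirrors that computation with unsigned __int128 arithmetic
in place of wide_int; it is an illustration under simplified assumptions
(unsigned division only, D > 1, N <= 32, a GCC-style compiler providing
__int128 and __builtin_clzll), not the code from expmed.cc, and the toy_
names are invented for the example.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Toy version of choose_multiplier: compute an n + 1 bit approximation
   m to 2**(n + lgup) / d such that x / d == ((m * x) >> n) >> post_shift
   for every unsigned x of at most PRECISION bits.  Returns the most
   significant bit of m and stores its low n bits in *mult.  */

static int
toy_choose_multiplier (uint64_t d, int n, int precision,
		       uint64_t *mult, int *post_shift)
{
  assert (d > 1 && n <= 32 && precision <= n);

  const unsigned __int128 one = 1;
  int lgup = 64 - __builtin_clzll (d - 1);	/* ceil(log2(d)) */
  assert (lgup <= precision);

  /* mlow and mhigh bracket the exact quotient 2**(n + lgup) / d.  */
  unsigned __int128 mlow = (one << (n + lgup)) / d;
  unsigned __int128 mhigh
    = ((one << (n + lgup)) + (one << (n + lgup - precision))) / d;

  /* Reduce to lowest terms: drop low-order bits as long as mlow and
     mhigh still differ, decreasing the post shift accordingly.  */
  int shift = lgup;
  while (shift > 0 && (mlow >> 1) < (mhigh >> 1))
    {
      mlow >>= 1;
      mhigh >>= 1;
      shift--;
    }

  *mult = (uint64_t) (mhigh & ((one << n) - 1));
  *post_shift = shift;
  return (int) (mhigh >> n);	/* 0 or 1 */
}

int
main (void)
{
  uint64_t ml;
  int post_shift;
  int mh = toy_choose_multiplier (7, 32, 32, &ml, &post_shift);
  printf ("mh = %d, ml = %#llx, post_shift = %d\n",
	  mh, (unsigned long long) ml, post_shift);

  /* Spot-check the multiply-and-shift sequence against real division.  */
  unsigned __int128 m = ((unsigned __int128) mh << 32) | ml;
  for (uint64_t x = 0; x <= UINT32_MAX; x += 65537)
    assert ((uint64_t) ((m * x >> 32) >> post_shift) == x / 7);
  return 0;
}

For D = 7 and N = PRECISION = 32, lgup = 3, mlow = floor(2**35/7)
= 4908534052 and mhigh = 4908534053, so the exact value 2**35/7 lies
strictly between them; the reduction loop does not fire and the result is
the familiar mh = 1, ml = 0x24924925, post_shift = 3, i.e.
x / 7 == ((2**32 + 0x24924925) * x >> 32) >> 3 for all 32-bit x.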

Patch

diff --git a/gcc/expmed.cc b/gcc/expmed.cc
index 4ec035e4843..60f65c7acc5 100644
--- a/gcc/expmed.cc
+++ b/gcc/expmed.cc
@@ -3689,50 +3689,62 @@  expand_widening_mult (machine_mode mode, rtx op0, rtx op1, rtx target,
 		       unsignedp, OPTAB_LIB_WIDEN);
 }
 
-/* Choose a minimal N + 1 bit approximation to 1/D that can be used to
-   replace division by D, and put the least significant N bits of the result
-   in *MULTIPLIER_PTR and return the most significant bit.
+/* Choose a minimal N + 1 bit approximation to 2**K / D that can be used to
+   replace division by D, put the least significant N bits of the result in
+   *MULTIPLIER_PTR, the value K - N in *POST_SHIFT_PTR, and return the most
+   significant bit.
 
    The width of operations is N (should be <= HOST_BITS_PER_WIDE_INT), the
-   needed precision is in PRECISION (should be <= N).
+   needed precision is PRECISION (should be <= N).
 
-   PRECISION should be as small as possible so this function can choose
-   multiplier more freely.
+   PRECISION should be as small as possible so this function can choose the
+   multiplier more freely.  If PRECISION is <= N - 1, the most significant
+   bit returned by the function will be zero.
 
-   The rounded-up logarithm of D is placed in *lgup_ptr.  A shift count that
-   is to be used for a final right shift is placed in *POST_SHIFT_PTR.
-
-   Using this function, x/D will be equal to (x * m) >> (*POST_SHIFT_PTR),
-   where m is the full HOST_BITS_PER_WIDE_INT + 1 bit multiplier.  */
+   Using this function, x / D is equal to (x*m) / 2**N >> (*POST_SHIFT_PTR),
+   where m is the full N + 1 bit multiplier.  */
 
 unsigned HOST_WIDE_INT
 choose_multiplier (unsigned HOST_WIDE_INT d, int n, int precision,
 		   unsigned HOST_WIDE_INT *multiplier_ptr,
-		   int *post_shift_ptr, int *lgup_ptr)
+		   int *post_shift_ptr)
 {
   int lgup, post_shift;
-  int pow, pow2;
+  int pow1, pow2;
 
-  /* lgup = ceil(log2(divisor)); */
+  /* lgup = ceil(log2(d)) */
+  /* Assuming d > 1, we have d >= 2^(lgup-1) + 1 */
   lgup = ceil_log2 (d);
 
   gcc_assert (lgup <= n);
+  gcc_assert (lgup <= precision);
 
-  pow = n + lgup;
+  pow1 = n + lgup;
   pow2 = n + lgup - precision;
 
-  /* mlow = 2^(N + lgup)/d */
-  wide_int val = wi::set_bit_in_zero (pow, HOST_BITS_PER_DOUBLE_INT);
+  /* mlow = 2^(n + lgup)/d */
+  /* Trivially from above we have mlow < 2^(n+1) */
+  wide_int val = wi::set_bit_in_zero (pow1, HOST_BITS_PER_DOUBLE_INT);
   wide_int mlow = wi::udiv_trunc (val, d);
 
-  /* mhigh = (2^(N + lgup) + 2^(N + lgup - precision))/d */
+  /* mhigh = (2^(n + lgup) + 2^(n + lgup - precision))/d */
+  /* From above we have mhigh < 2^(n+1) assuming lgup <= precision */
+  /* From precision <= n, the difference between the numerators of mhigh and
+     mlow is >= 2^lgup >= d.  Therefore the difference of the quotients in
+     the Euclidean division by d is at least 1, so we have mlow < mhigh and
+     the exact value of 2^(n + lgup)/d lies in the interval [mlow; mhigh).  */
   val |= wi::set_bit_in_zero (pow2, HOST_BITS_PER_DOUBLE_INT);
   wide_int mhigh = wi::udiv_trunc (val, d);
 
-  /* If precision == N, then mlow, mhigh exceed 2^N
-     (but they do not exceed 2^(N+1)).  */
-
   /* Reduce to lowest terms.  */
+  /* If precision <= n - 1, then the difference between the numerators of
+     mhigh and mlow is >= 2^(lgup + 1) >= 2 * 2^lgup >= 2 * d.  Therefore
+     the difference of the quotients in the Euclidean division by d is at
+     least 2, which means that mhigh and mlow differ by at least one bit
+     not in the last place.  The conclusion is that the first iteration of
+     the loop below completes and shifts mhigh and mlow by 1 bit, which in
+     particular means that mhigh < 2^n, that is to say, the most significant
+     bit in the n + 1 bit value is zero.  */
   for (post_shift = lgup; post_shift > 0; post_shift--)
     {
       unsigned HOST_WIDE_INT ml_lo = wi::extract_uhwi (mlow, 1,
@@ -3747,7 +3759,7 @@  choose_multiplier (unsigned HOST_WIDE_INT d, int n, int precision,
     }
 
   *post_shift_ptr = post_shift;
-  *lgup_ptr = lgup;
+
   if (n < HOST_BITS_PER_WIDE_INT)
     {
       unsigned HOST_WIDE_INT mask = (HOST_WIDE_INT_1U << n) - 1;
@@ -3761,31 +3773,32 @@  choose_multiplier (unsigned HOST_WIDE_INT d, int n, int precision,
     }
 }
 
-/* Compute the inverse of X mod 2**n, i.e., find Y such that X * Y is
-   congruent to 1 (mod 2**N).  */
+/* Compute the inverse of X mod 2**N, i.e., find Y such that X * Y is congruent
+   to 1 modulo 2**N, assuming that X is odd.  Bézout's lemma guarantees that Y
+   exists for any given positive N.  */
 
 static unsigned HOST_WIDE_INT
 invert_mod2n (unsigned HOST_WIDE_INT x, int n)
 {
-  /* Solve x*y == 1 (mod 2^n), where x is odd.  Return y.  */
+  gcc_assert ((x & 1) == 1);
 
-  /* The algorithm notes that the choice y = x satisfies
-     x*y == 1 mod 2^3, since x is assumed odd.
-     Each iteration doubles the number of bits of significance in y.  */
+  /* The algorithm notes that the choice Y = X satisfies X*Y == 1 mod 2^3,
+     since X is odd.  Then each iteration doubles the number of bits of
+     significance in Y.  */
 
-  unsigned HOST_WIDE_INT mask;
+  const unsigned HOST_WIDE_INT mask
+    = (n == HOST_BITS_PER_WIDE_INT
+       ? HOST_WIDE_INT_M1U
+       : (HOST_WIDE_INT_1U << n) - 1);
   unsigned HOST_WIDE_INT y = x;
   int nbit = 3;
 
-  mask = (n == HOST_BITS_PER_WIDE_INT
-	  ? HOST_WIDE_INT_M1U
-	  : (HOST_WIDE_INT_1U << n) - 1);
-
   while (nbit < n)
     {
       y = y * (2 - x*y) & mask;		/* Modulo 2^N */
       nbit *= 2;
     }
+
   return y;
 }
 
@@ -4443,7 +4456,6 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 	      {
 		unsigned HOST_WIDE_INT mh, ml;
 		int pre_shift, post_shift;
-		int dummy;
 		wide_int wd = rtx_mode_t (op1, int_mode);
 		unsigned HOST_WIDE_INT d = wd.to_uhwi ();
 
@@ -4476,10 +4488,9 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 		    else
 		      {
 			/* Find a suitable multiplier and right shift count
-			   instead of multiplying with D.  */
-
+			   instead of directly dividing by D.  */
 			mh = choose_multiplier (d, size, size,
-						&ml, &post_shift, &dummy);
+						&ml, &post_shift);
 
 			/* If the suggested multiplier is more than SIZE bits,
 			   we can do better for even divisors, using an
@@ -4489,7 +4500,7 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 			    pre_shift = ctz_or_zero (d);
 			    mh = choose_multiplier (d >> pre_shift, size,
 						    size - pre_shift,
-						    &ml, &post_shift, &dummy);
+						    &ml, &post_shift);
 			    gcc_assert (!mh);
 			  }
 			else
@@ -4561,7 +4572,7 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 	    else		/* TRUNC_DIV, signed */
 	      {
 		unsigned HOST_WIDE_INT ml;
-		int lgup, post_shift;
+		int post_shift;
 		rtx mlr;
 		HOST_WIDE_INT d = INTVAL (op1);
 		unsigned HOST_WIDE_INT abs_d;
@@ -4657,7 +4668,7 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 		else if (size <= HOST_BITS_PER_WIDE_INT)
 		  {
 		    choose_multiplier (abs_d, size, size - 1,
-				       &ml, &post_shift, &lgup);
+				       &ml, &post_shift);
 		    if (ml < HOST_WIDE_INT_1U << (size - 1))
 		      {
 			rtx t1, t2, t3;
@@ -4748,7 +4759,7 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 	    scalar_int_mode int_mode = as_a <scalar_int_mode> (compute_mode);
 	    int size = GET_MODE_BITSIZE (int_mode);
 	    unsigned HOST_WIDE_INT mh, ml;
-	    int pre_shift, lgup, post_shift;
+	    int pre_shift, post_shift;
 	    HOST_WIDE_INT d = INTVAL (op1);
 
 	    if (d > 0)
@@ -4778,7 +4789,7 @@  expand_divmod (int rem_flag, enum tree_code code, machine_mode mode,
 		    rtx t1, t2, t3, t4;
 
 		    mh = choose_multiplier (d, size, size - 1,
-					    &ml, &post_shift, &lgup);
+					    &ml, &post_shift);
 		    gcc_assert (!mh);
 
 		    if (post_shift < BITS_PER_WORD
diff --git a/gcc/expmed.h b/gcc/expmed.h
index bf3c4097f33..f5375c84f25 100644
--- a/gcc/expmed.h
+++ b/gcc/expmed.h
@@ -694,12 +694,13 @@  extern rtx emit_store_flag_force (rtx, enum rtx_code, rtx, rtx,
 
 extern void canonicalize_comparison (machine_mode, enum rtx_code *, rtx *);
 
-/* Choose a minimal N + 1 bit approximation to 1/D that can be used to
-   replace division by D, and put the least significant N bits of the result
-   in *MULTIPLIER_PTR and return the most significant bit.  */
+/* Choose a minimal N + 1 bit approximation to 2**K / D that can be used to
+   replace division by D, put the least significant N bits of the result in
+   *MULTIPLIER_PTR, the value K - N in *POST_SHIFT_PTR, and return the most
+   significant bit.  */
 extern unsigned HOST_WIDE_INT choose_multiplier (unsigned HOST_WIDE_INT, int,
 						 int, unsigned HOST_WIDE_INT *,
-						 int *, int *);
+						 int *);
 
 #ifdef TREE_CODE
 extern rtx expand_variable_shift (enum tree_code, machine_mode,
diff --git a/gcc/tree-vect-generic.cc b/gcc/tree-vect-generic.cc
index ab640096ca2..ea0069f7a67 100644
--- a/gcc/tree-vect-generic.cc
+++ b/gcc/tree-vect-generic.cc
@@ -551,7 +551,6 @@  expand_vector_divmod (gimple_stmt_iterator *gsi, tree type, tree op0,
   int *shift_temps = post_shifts + nunits;
   unsigned HOST_WIDE_INT *mulc = XALLOCAVEC (unsigned HOST_WIDE_INT, nunits);
   int prec = TYPE_PRECISION (TREE_TYPE (type));
-  int dummy_int;
   unsigned int i;
   signop sign_p = TYPE_SIGN (TREE_TYPE (type));
   unsigned HOST_WIDE_INT mask = GET_MODE_MASK (TYPE_MODE (TREE_TYPE (type)));
@@ -609,11 +608,11 @@  expand_vector_divmod (gimple_stmt_iterator *gsi, tree type, tree op0,
 	      continue;
 	    }
 
-	  /* Find a suitable multiplier and right shift count
-	     instead of multiplying with D.  */
-	  mh = choose_multiplier (d, prec, prec, &ml, &post_shift, &dummy_int);
+	  /* Find a suitable multiplier and right shift count instead of
+	     directly dividing by D.  */
+	  mh = choose_multiplier (d, prec, prec, &ml, &post_shift);
 
-	  /* If the suggested multiplier is more than SIZE bits, we can
+	  /* If the suggested multiplier is more than PREC bits, we can
 	     do better for even divisors, using an initial right shift.  */
 	  if ((mh != 0 && (d & 1) == 0)
 	      || (!has_vector_shift && pre_shift != -1))
@@ -655,7 +654,7 @@  expand_vector_divmod (gimple_stmt_iterator *gsi, tree type, tree op0,
 		    }
 		  mh = choose_multiplier (d >> pre_shift, prec,
 					  prec - pre_shift,
-					  &ml, &post_shift, &dummy_int);
+					  &ml, &post_shift);
 		  gcc_assert (!mh);
 		  pre_shifts[i] = pre_shift;
 		}
@@ -699,7 +698,7 @@  expand_vector_divmod (gimple_stmt_iterator *gsi, tree type, tree op0,
 	    }
 
 	  choose_multiplier (abs_d, prec, prec - 1, &ml,
-			     &post_shift, &dummy_int);
+			     &post_shift);
 	  if (ml >= HOST_WIDE_INT_1U << (prec - 1))
 	    {
 	      this_mode = 4 + (d < 0);
diff --git a/gcc/tree-vect-patterns.cc b/gcc/tree-vect-patterns.cc
index 87c2acff386..8e8de5ea3a5 100644
--- a/gcc/tree-vect-patterns.cc
+++ b/gcc/tree-vect-patterns.cc
@@ -4535,7 +4535,7 @@  vect_recog_divmod_pattern (vec_info *vinfo,
   enum tree_code rhs_code;
   optab optab;
   tree q, cst;
-  int dummy_int, prec;
+  int prec;
 
   if (!is_gimple_assign (last_stmt))
     return NULL;
@@ -4795,17 +4795,17 @@  vect_recog_divmod_pattern (vec_info *vinfo,
 	/* FIXME: Can transform this into oprnd0 >= oprnd1 ? 1 : 0.  */
 	return NULL;
 
-      /* Find a suitable multiplier and right shift count
-	 instead of multiplying with D.  */
-      mh = choose_multiplier (d, prec, prec, &ml, &post_shift, &dummy_int);
+      /* Find a suitable multiplier and right shift count instead of
+	 directly dividing by D.  */
+      mh = choose_multiplier (d, prec, prec, &ml, &post_shift);
 
-      /* If the suggested multiplier is more than SIZE bits, we can do better
+      /* If the suggested multiplier is more than PREC bits, we can do better
 	 for even divisors, using an initial right shift.  */
       if (mh != 0 && (d & 1) == 0)
 	{
 	  pre_shift = ctz_or_zero (d);
 	  mh = choose_multiplier (d >> pre_shift, prec, prec - pre_shift,
-				  &ml, &post_shift, &dummy_int);
+				  &ml, &post_shift);
 	  gcc_assert (!mh);
 	}
       else
@@ -4924,7 +4924,7 @@  vect_recog_divmod_pattern (vec_info *vinfo,
 	/* This case is not handled correctly below.  */
 	return NULL;
 
-      choose_multiplier (abs_d, prec, prec - 1, &ml, &post_shift, &dummy_int);
+      choose_multiplier (abs_d, prec, prec - 1, &ml, &post_shift);
       if (ml >= HOST_WIDE_INT_1U << (prec - 1))
 	{
 	  add = true;
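
The other half of the patch touches invert_mod2n, which computes the
multiplicative inverse of an odd X modulo 2**N: starting from Y = X,
which is already exact modulo 2^3, each step Y <- Y * (2 - X*Y) doubles
the number of correct low-order bits, so the loop runs at most five times
for N = 64.  Below is a self-contained sketch of that iteration with
uint64_t standing in for HOST_WIDE_INT; toy_invert_mod2n is an invented
name and the code is an illustration, not the function from expmed.cc.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Toy version of invert_mod2n: compute y such that x * y == 1 (mod 2**n)
   for odd x and 0 < n <= 64.  */

static uint64_t
toy_invert_mod2n (uint64_t x, int n)
{
  assert ((x & 1) == 1 && n > 0 && n <= 64);

  const uint64_t mask
    = n == 64 ? UINT64_MAX : (UINT64_C (1) << n) - 1;

  /* y = x is exact modulo 2^3 because x*x == 1 (mod 8) for odd x.  */
  uint64_t y = x;

  /* Each iteration squares the precision: if x*y == 1 (mod 2^k), then
     x * y * (2 - x*y) == 1 - (1 - x*y)^2 == 1 (mod 2^2k).  */
  for (int nbit = 3; nbit < n; nbit *= 2)
    y = y * (2 - x * y) & mask;

  return y & mask;
}

int
main (void)
{
  /* Exhaustively check small odd values; uint64_t products wrap mod 2^64.  */
  for (uint64_t x = 1; x < 10000; x += 2)
    assert (x * toy_invert_mod2n (x, 64) == 1);

  printf ("inverse of 7 mod 2^64 is %#llx\n",
	  (unsigned long long) toy_invert_mod2n (7, 64));
  return 0;
}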