diff mbox series

[2/3] stdlib: Handle various corner cases in the fallback heapsort for qsort

Message ID 832cc3db076991ad9b4e2f2c1f85133d335181a3.1700246487.git.fweimer@redhat.com
State New
Headers show
Series Various qsort fixes | expand

Commit Message

Florian Weimer Nov. 17, 2023, 6:44 p.m. UTC
The previous implementation did not consistently apply the rule that
the child nodes of node K are at 2 * K + 1 and 2 * K + 2, or
that the parent node is at (K - 1) / 2.

Add an internal test that targets the heapsort implementation
directly.

Reported-by: Stepan Golosunov <stepan@golosunov.pp.ru>
---
 stdlib/Makefile     |   1 +
 stdlib/qsort.c      |  55 ++++++++++++------
 stdlib/tst-qsort4.c | 134 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 173 insertions(+), 17 deletions(-)
 create mode 100644 stdlib/tst-qsort4.c

Comments

Adhemerval Zanella Netto Nov. 20, 2023, 8:02 p.m. UTC | #1
On 17/11/23 15:44, Florian Weimer wrote:
> The previous implementation did not consistently apply the rule that
> the child nodes of node K are at 2 * K + 1 and 2 * K + 2, or
> that the parent node is at (K - 1) / 2.
> 
> Add an internal test that targets the heapsort implementation
> directly.
> 
> Reported-by: Stepan Golosunov <stepan@golosunov.pp.ru>

LGTM, two minor suggestions below.

Reviewed-by: Adhemerval Zanella  <adhemerval.zanella@linaro.org>

> ---
>  stdlib/Makefile     |   1 +
>  stdlib/qsort.c      |  55 ++++++++++++------
>  stdlib/tst-qsort4.c | 134 ++++++++++++++++++++++++++++++++++++++++++++
>  3 files changed, 173 insertions(+), 17 deletions(-)
>  create mode 100644 stdlib/tst-qsort4.c
> 
> diff --git a/stdlib/Makefile b/stdlib/Makefile
> index 6af606136e..48688f6a27 100644
> --- a/stdlib/Makefile
> +++ b/stdlib/Makefile
> @@ -261,6 +261,7 @@ tests := \
>    # tests
>  
>  tests-internal := \
> +  tst-qsort4 \
>    tst-strtod1i \
>    tst-strtod3 \
>    tst-strtod4 \
> diff --git a/stdlib/qsort.c b/stdlib/qsort.c
> index 6d0c4447ec..a2f9e916ef 100644
> --- a/stdlib/qsort.c
> +++ b/stdlib/qsort.c
> @@ -125,29 +125,44 @@ pop (stack_node *top, char **lo, char **hi, size_t *depth)
>    return top;
>  }
>  
> -/* NB: N is inclusive bound for BASE.  */
> +/* Establish the heap condition at index K, that is, the key at K will
> +   not be less than either of its children, at 2 * K + 1 and 2 * K + 2
> +   (if they exist).  N is the last valid index. */
>  static inline void
>  siftdown (void *base, size_t size, size_t k, size_t n,
>  	  enum swap_type_t swap_type, __compar_d_fn_t cmp, void *arg)
>  {
> -  while (k <= n / 2)
> +  /* There can only be a heap condition violation if there are
> +     children.  */
> +  while (2 * k + 1 <= n)
>      {
> -      size_t j = 2 * k;
> +      /* Left child.  */
> +      size_t j = 2 * k + 1;
> +      /* If the right child is larger, use it.  */
>        if (j < n && cmp (base + (j * size), base + ((j + 1) * size), arg) < 0)
>  	j++;
>  
> +      /* If k is already >= to its children, we are done.  */
>        if (j == k || cmp (base + (k * size), base + (j * size), arg) >= 0)
>  	break;
>  
> +      /* Heal the violation.  */
>        do_swap (base + (size * j), base + (k * size), size, swap_type);
> +
> +      /* Swapping with j may have introduced a violation at j.  Fix
> +	 it in the next loop iteration.  */
>        k = j;
>      }
>  }
>  

Ok.

> +/* Establish the heap condition for the indices 0 to N (inclusive).  */
>  static inline void
>  heapify (void *base, size_t size, size_t n, enum swap_type_t swap_type,
>  	 __compar_d_fn_t cmp, void *arg)
>  {
> +  /* If n is odd, k = n / 2 has a left child at n, so this is the
> +     largest index that can have a heap condition violation regarding
> +     its children.  */
>    size_t k = n / 2;
>    while (1)
>      {
> @@ -157,32 +172,38 @@ heapify (void *base, size_t size, size_t n, enum swap_type_t swap_type,
>      }
>  }
>  
> -/* A non-recursive heapsort, used on introsort implementation as a fallback
> -   routine with worst-case performance of O(nlog n) and worst-case space
> -   complexity of O(1).  It sorts the array starting at BASE and ending at
> -   END, with each element of SIZE bytes.  The SWAP_TYPE is the callback
> -   function used to swap elements, and CMP is the function used to compare
> -   elements.   */
> +/* A non-recursive heapsort, used on introsort implementation as a
> +   fallback routine with worst-case performance of O(nlog n) and
> +   worst-case space complexity of O(1).  It sorts the array starting
> +   at BASE and ending at END (inclusive), with each element of SIZE
> +   bytes.  The SWAP_TYPE is the callback function used to swap
> +   elements, and CMP is the function used to compare elements.  */
>  static void
>  heapsort_r (void *base, void *end, size_t size, enum swap_type_t swap_type,
>  	    __compar_d_fn_t cmp, void *arg)
>  {
> -  const size_t count = ((uintptr_t) end - (uintptr_t) base) / size;
> -
> -  if (count < 2)
> +  size_t n = ((uintptr_t) end - (uintptr_t) base) / size;
> +  if (n <= 1)
> +    /* Handled by insertion sort.  */
>      return;
>  
> -  size_t n = count - 1;
> -
>    /* Build the binary heap, largest value at the base[0].  */
>    heapify (base, size, n, swap_type, cmp, arg);
>  
> -  /* On each iteration base[0:n] is the binary heap, while base[n:count]
> -     is sorted.  */
> -  while (n > 0)
> +  while (true)
>      {
> +      /* Indices 0 .. n contain the binary heap.  Extract the largest
> +	 element put it into the final position in the array.  */
>        do_swap (base, base + (n * size), size, swap_type);
> +
> +      /* The heap is now one element shorter.  */
>        n--;
> +      if (n == 0)
> +	break;
> +
> +      /* By swapping in elements 0 and the previous value of n (now at
> +	 n + 1), we likely introduced a heap condition violation.  Fix
> +	 it for the reduced heap.  */
>        siftdown (base, size, 0, n, swap_type, cmp, arg);
>      }
>  }

Ok.

> diff --git a/stdlib/tst-qsort4.c b/stdlib/tst-qsort4.c
> new file mode 100644
> index 0000000000..d5b8d05a91
> --- /dev/null
> +++ b/stdlib/tst-qsort4.c
> @@ -0,0 +1,134 @@
> +/* Test the heapsort implementation behind qsort.
> +   Copyright (C) 2023 Free Software Foundation, Inc.
> +   This file is part of the GNU C Library.
> +
> +   The GNU C Library is free software; you can redistribute it and/or
> +   modify it under the terms of the GNU Lesser General Public
> +   License as published by the Free Software Foundation; either
> +   version 2.1 of the License, or (at your option) any later version.
> +
> +   The GNU C Library is distributed in the hope that it will be useful,
> +   but WITHOUT ANY WARRANTY; without even the implied warranty of
> +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
> +   Lesser General Public License for more details.
> +
> +   You should have received a copy of the GNU Lesser General Public
> +   License along with the GNU C Library; if not, see
> +   <http://www.gnu.org/licenses/>.  */
> +
> +#include "qsort.c"
> +
> +#include <stdio.h>
> +#include <support/check.h>
> +#include <support/support.h>
> +
> +static int
> +cmp (const void *a1, const void *b1, void *closure)
> +{
> +  const signed char *a = a1;
> +  const signed char *b = b1;
> +  return *a - *b;

Ok, the current test inputs won't trigger underflow.  Maybe make it
future-proof by using '(*a > *b) - (*a < *b)' instead.

> +}
> +
> +/* Wrapper around heapsort_r that set ups the required variables.  */
> +static void
> +heapsort_wrapper (void *const pbase, size_t total_elems, size_t size,
> +                  __compar_d_fn_t cmp, void *arg)
> +{
> +  char *base_ptr = (char *) pbase;
> +  char *lo = base_ptr;
> +  char *hi = &lo[size * (total_elems - 1)];
> +
> +  if (total_elems <= 1)
> +    /* Avoid lossage with unsigned arithmetic below.  */
> +    return;
> +
> +  enum swap_type_t swap_type;
> +  if (is_aligned (pbase, size, 8))
> +    swap_type = SWAP_WORDS_64;
> +  else if (is_aligned (pbase, size, 4))
> +    swap_type = SWAP_WORDS_32;
> +  else
> +    swap_type = SWAP_BYTES;
> +  heapsort_r (lo, hi, size, swap_type, cmp, arg);
> +}
> +
> +static void
> +check_one_sort (signed char *array, int length)
> +{
> +  signed char *copy = xmalloc (length);
> +  memcpy (copy, array, length);
> +  heapsort_wrapper (copy, length, 1, cmp, NULL);
> +
> +  /* Verify that the result is sorted.  */
> +  for (int i = 1; i < length; ++i)
> +    if (copy[i] < copy[i - 1])
> +      {
> +        support_record_failure ();
> +        printf ("error: sorting failure for length %d at offset %d\n",
> +                length, i - 1);
> +        printf ("input:");
> +        for (int i = 0; i < length; ++i)
> +          printf (" %d", array[i]);
> +        printf ("\noutput:");
> +        for (int i = 0; i < length; ++i)
> +          printf (" %d", copy[i]);
> +        putchar ('\n');
> +        break;
> +      }
> +
> +  /* Verify that no elements went away or were added.  */
> +  {
> +    int expected_counts[256];

Maybe use UCHAR_MAX+1 here.

> +    for (int i = 0; i < length; ++i)
> +      ++expected_counts[array[i] & 0xff];
> +    int actual_counts[256];
> +    for (int i = 0; i < length; ++i)
> +      ++actual_counts[copy[i] & 0xff];
> +    for (int i = 0; i < 256; ++i)
> +      TEST_COMPARE (expected_counts[i], expected_counts[i]);
> +  }
> +
> +  free (copy);
> +}
> +
> +/* Enumerate all possible combinations of LENGTH elements.  */
> +static void
> +check_combinations (int length, signed char *start, int offset)
> +{
> +  if (offset == length)
> +    check_one_sort (start, length);
> +  else
> +    for (int i = 0; i < length; ++i)
> +      {
> +        start[offset] = i;
> +        check_combinations(length, start, offset + 1);
> +      }
> +}
> +
> +static int
> +do_test (void)
> +{
> +  /* A random permutation of 20 values.  */
> +  check_one_sort ((signed char[20]) {5, 12, 16, 10, 14, 11, 9, 13, 8, 15,
> +                                     0, 17, 3, 7, 1, 18, 2, 19, 4, 6}, 20);
> +
> +
> +  /* A permutation that appeared during adversial testing for the

s/adversial/adversarial

> +     quicksort pass.  */
> +  check_one_sort ((signed char[16]) {15, 3, 4, 2, 1, 0, 8, 7, 6, 5, 14,
> +                                     13, 12, 11, 10, 9}, 16);
> +
> +  /* Array lengths 2 and less are not handled by heapsort_r and
> +     deferred to insertion sort.  */
> +  for (int i = 3; i <= 8; ++i)
> +    {
> +      signed char *buf = xmalloc (i);
> +      check_combinations (i, buf, 0);
> +      free (buf);
> +    }
> +
> +  return 0;
> +}
> +
> +#include <support/test-driver.c>
diff mbox series

Patch

diff --git a/stdlib/Makefile b/stdlib/Makefile
index 6af606136e..48688f6a27 100644
--- a/stdlib/Makefile
+++ b/stdlib/Makefile
@@ -261,6 +261,7 @@  tests := \
   # tests
 
 tests-internal := \
+  tst-qsort4 \
   tst-strtod1i \
   tst-strtod3 \
   tst-strtod4 \
diff --git a/stdlib/qsort.c b/stdlib/qsort.c
index 6d0c4447ec..a2f9e916ef 100644
--- a/stdlib/qsort.c
+++ b/stdlib/qsort.c
@@ -125,29 +125,44 @@  pop (stack_node *top, char **lo, char **hi, size_t *depth)
   return top;
 }
 
-/* NB: N is inclusive bound for BASE.  */
+/* Establish the heap condition at index K, that is, the key at K will
+   not be less than either of its children, at 2 * K + 1 and 2 * K + 2
+   (if they exist).  N is the last valid index. */
 static inline void
 siftdown (void *base, size_t size, size_t k, size_t n,
 	  enum swap_type_t swap_type, __compar_d_fn_t cmp, void *arg)
 {
-  while (k <= n / 2)
+  /* There can only be a heap condition violation if there are
+     children.  */
+  while (2 * k + 1 <= n)
     {
-      size_t j = 2 * k;
+      /* Left child.  */
+      size_t j = 2 * k + 1;
+      /* If the right child is larger, use it.  */
       if (j < n && cmp (base + (j * size), base + ((j + 1) * size), arg) < 0)
 	j++;
 
+      /* If k is already >= to its children, we are done.  */
       if (j == k || cmp (base + (k * size), base + (j * size), arg) >= 0)
 	break;
 
+      /* Heal the violation.  */
       do_swap (base + (size * j), base + (k * size), size, swap_type);
+
+      /* Swapping with j may have introduced a violation at j.  Fix
+	 it in the next loop iteration.  */
       k = j;
     }
 }
 
+/* Establish the heap condition for the indices 0 to N (inclusive).  */
 static inline void
 heapify (void *base, size_t size, size_t n, enum swap_type_t swap_type,
 	 __compar_d_fn_t cmp, void *arg)
 {
+  /* If n is odd, k = n / 2 has a left child at n, so this is the
+     largest index that can have a heap condition violation regarding
+     its children.  */
   size_t k = n / 2;
   while (1)
     {
@@ -157,32 +172,38 @@  heapify (void *base, size_t size, size_t n, enum swap_type_t swap_type,
     }
 }
 
-/* A non-recursive heapsort, used on introsort implementation as a fallback
-   routine with worst-case performance of O(nlog n) and worst-case space
-   complexity of O(1).  It sorts the array starting at BASE and ending at
-   END, with each element of SIZE bytes.  The SWAP_TYPE is the callback
-   function used to swap elements, and CMP is the function used to compare
-   elements.   */
+/* A non-recursive heapsort, used on introsort implementation as a
+   fallback routine with worst-case performance of O(nlog n) and
+   worst-case space complexity of O(1).  It sorts the array starting
+   at BASE and ending at END (inclusive), with each element of SIZE
+   bytes.  The SWAP_TYPE is the callback function used to swap
+   elements, and CMP is the function used to compare elements.  */
 static void
 heapsort_r (void *base, void *end, size_t size, enum swap_type_t swap_type,
 	    __compar_d_fn_t cmp, void *arg)
 {
-  const size_t count = ((uintptr_t) end - (uintptr_t) base) / size;
-
-  if (count < 2)
+  size_t n = ((uintptr_t) end - (uintptr_t) base) / size;
+  if (n <= 1)
+    /* Handled by insertion sort.  */
     return;
 
-  size_t n = count - 1;
-
   /* Build the binary heap, largest value at the base[0].  */
   heapify (base, size, n, swap_type, cmp, arg);
 
-  /* On each iteration base[0:n] is the binary heap, while base[n:count]
-     is sorted.  */
-  while (n > 0)
+  while (true)
     {
+      /* Indices 0 .. n contain the binary heap.  Extract the largest
+	 element and put it into the final position in the array.  */
       do_swap (base, base + (n * size), size, swap_type);
+
+      /* The heap is now one element shorter.  */
       n--;
+      if (n == 0)
+	break;
+
+      /* By swapping in elements 0 and the previous value of n (now at
+	 n + 1), we likely introduced a heap condition violation.  Fix
+	 it for the reduced heap.  */
       siftdown (base, size, 0, n, swap_type, cmp, arg);
     }
 }
diff --git a/stdlib/tst-qsort4.c b/stdlib/tst-qsort4.c
new file mode 100644
index 0000000000..d5b8d05a91
--- /dev/null
+++ b/stdlib/tst-qsort4.c
@@ -0,0 +1,134 @@ 
+/* Test the heapsort implementation behind qsort.
+   Copyright (C) 2023 Free Software Foundation, Inc.
+   This file is part of the GNU C Library.
+
+   The GNU C Library is free software; you can redistribute it and/or
+   modify it under the terms of the GNU Lesser General Public
+   License as published by the Free Software Foundation; either
+   version 2.1 of the License, or (at your option) any later version.
+
+   The GNU C Library is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+   Lesser General Public License for more details.
+
+   You should have received a copy of the GNU Lesser General Public
+   License along with the GNU C Library; if not, see
+   <http://www.gnu.org/licenses/>.  */
+
+#include "qsort.c"
+
+#include <stdio.h>
+#include <support/check.h>
+#include <support/support.h>
+
+static int
+cmp (const void *a1, const void *b1, void *closure)
+{
+  const signed char *a = a1;
+  const signed char *b = b1;
+  return *a - *b;  /* Operands promote to int, so this cannot overflow.  */
+}
+
+/* Wrapper around heapsort_r that sets up the required variables.  */
+static void
+heapsort_wrapper (void *const pbase, size_t total_elems, size_t size,
+                  __compar_d_fn_t cmp, void *arg)
+{
+  if (total_elems <= 1)
+    /* Avoid lossage with unsigned arithmetic below.  */
+    return;
+
+  /* LO and HI point to the first and the last element (inclusive).  */
+  char *lo = (char *) pbase;
+  char *hi = &lo[size * (total_elems - 1)];
+
+  enum swap_type_t swap_type;
+  if (is_aligned (pbase, size, 8))
+    swap_type = SWAP_WORDS_64;
+  else if (is_aligned (pbase, size, 4))
+    swap_type = SWAP_WORDS_32;
+  else
+    swap_type = SWAP_BYTES;
+  heapsort_r (lo, hi, size, swap_type, cmp, arg);
+}
+
+/* Sort a copy of ARRAY (LENGTH elements) with heapsort_r and record a
+   test failure unless the result is an ordered permutation of ARRAY.  */
+static void
+check_one_sort (signed char *array, int length)
+{
+  signed char *copy = xmalloc (length);
+  memcpy (copy, array, length);
+  heapsort_wrapper (copy, length, 1, cmp, NULL);
+
+  /* Verify that the result is sorted.  */
+  for (int i = 1; i < length; ++i)
+    if (copy[i] < copy[i - 1])
+      {
+        support_record_failure ();
+        printf ("error: sorting failure for length %d at offset %d\n",
+                length, i - 1);
+        printf ("input:");
+        for (int j = 0; j < length; ++j)
+          printf (" %d", array[j]);
+        printf ("\noutput:");
+        for (int j = 0; j < length; ++j)
+          printf (" %d", copy[j]);
+        putchar ('\n');
+        break;
+      }
+
+  /* Verify that no elements went away or were added.  */
+  int expected_counts[256] = { 0 };
+  for (int i = 0; i < length; ++i)
+    ++expected_counts[array[i] & 0xff];
+  int actual_counts[256] = { 0 };
+  for (int i = 0; i < length; ++i)
+    ++actual_counts[copy[i] & 0xff];
+  for (int i = 0; i < 256; ++i)
+    TEST_COMPARE (expected_counts[i], actual_counts[i]);
+
+  free (copy);
+}
+
+/* Check all arrays of LENGTH elements with values in 0 .. LENGTH - 1.  */
+static void
+check_combinations (int length, signed char *start, int offset)
+{
+  if (offset == length)
+    check_one_sort (start, length);
+  else
+    for (int i = 0; i < length; ++i)
+      {
+        start[offset] = i;
+        check_combinations(length, start, offset + 1);
+      }
+}
+
+static int
+do_test (void)
+{
+  /* A random permutation of 20 values.  */
+  check_one_sort ((signed char[20]) {5, 12, 16, 10, 14, 11, 9, 13, 8, 15,
+                                     0, 17, 3, 7, 1, 18, 2, 19, 4, 6}, 20);
+
+
+  /* A permutation that appeared during adversarial testing for the
+     quicksort pass.  */
+  check_one_sort ((signed char[16]) {15, 3, 4, 2, 1, 0, 8, 7, 6, 5, 14,
+                                     13, 12, 11, 10, 9}, 16);
+
+  /* Array lengths 2 and less are not handled by heapsort_r and
+     deferred to insertion sort.  */
+  for (int i = 3; i <= 8; ++i)
+    {
+      signed char *buf = xmalloc (i);
+      check_combinations (i, buf, 0);
+      free (buf);
+    }
+
+  return 0;
+}
+
+#include <support/test-driver.c>