diff mbox series

[v3,4/4] x86: Optimize {str|wcs}rchr-evex

Message ID 20220422015230.3241772-4-goldstein.w.n@gmail.com
State New
Headers show
Series [v3,1/4] benchtests: Improve bench-strrchr | expand

Commit Message

Noah Goldstein April 22, 2022, 1:52 a.m. UTC
The new code unrolls the main loop slightly without adding too much
overhead and minimizes the comparisons for the search CHAR.

Geometric Mean of all benchmarks New / Old: 0.755
See email for all results.

Full xcheck passes on x86_64 with and without multiarch enabled.
---
 sysdeps/x86_64/multiarch/strrchr-evex.S | 471 +++++++++++++++---------
 1 file changed, 290 insertions(+), 181 deletions(-)

Comments

H.J. Lu April 22, 2022, 7:04 p.m. UTC | #1
On Thu, Apr 21, 2022 at 6:52 PM Noah Goldstein <goldstein.w.n@gmail.com> wrote:
>
> The new code unrolls the main loop slightly without adding too much
> overhead and minimizes the comparisons for the search CHAR.
>
> Geometric Mean of all benchmarks New / Old: 0.755
> See email for all results.
>
> Full xcheck passes on x86_64 with and without multiarch enabled.
> ---
>  sysdeps/x86_64/multiarch/strrchr-evex.S | 471 +++++++++++++++---------
>  1 file changed, 290 insertions(+), 181 deletions(-)
>
> diff --git a/sysdeps/x86_64/multiarch/strrchr-evex.S b/sysdeps/x86_64/multiarch/strrchr-evex.S
> index adeddaed32..8014c285b3 100644
> --- a/sysdeps/x86_64/multiarch/strrchr-evex.S
> +++ b/sysdeps/x86_64/multiarch/strrchr-evex.S
> @@ -24,242 +24,351 @@
>  #  define STRRCHR      __strrchr_evex
>  # endif
>
> -# define VMOVU         vmovdqu64
> -# define VMOVA         vmovdqa64
> +# define VMOVU vmovdqu64
> +# define VMOVA vmovdqa64
>
>  # ifdef USE_AS_WCSRCHR
> +#  define SHIFT_REG    esi
> +
> +#  define kunpck       kunpckbw
> +#  define kmov_2x      kmovd
> +#  define maskz_2x     ecx
> +#  define maskm_2x     eax
> +#  define CHAR_SIZE    4
> +#  define VPMIN        vpminud
> +#  define VPTESTN      vptestnmd
>  #  define VPBROADCAST  vpbroadcastd
> -#  define VPCMP                vpcmpd
> -#  define SHIFT_REG    r8d
> +#  define VPCMP        vpcmpd
>  # else
> +#  define SHIFT_REG    edi
> +
> +#  define kunpck       kunpckdq
> +#  define kmov_2x      kmovq
> +#  define maskz_2x     rcx
> +#  define maskm_2x     rax
> +
> +#  define CHAR_SIZE    1
> +#  define VPMIN        vpminub
> +#  define VPTESTN      vptestnmb
>  #  define VPBROADCAST  vpbroadcastb
> -#  define VPCMP                vpcmpb
> -#  define SHIFT_REG    ecx
> +#  define VPCMP        vpcmpb
>  # endif
>
>  # define XMMZERO       xmm16
>  # define YMMZERO       ymm16
>  # define YMMMATCH      ymm17
> -# define YMM1          ymm18
> +# define YMMSAVE       ymm18
> +
> +# define YMM1  ymm19
> +# define YMM2  ymm20
> +# define YMM3  ymm21
> +# define YMM4  ymm22
> +# define YMM5  ymm23
> +# define YMM6  ymm24
> +# define YMM7  ymm25
> +# define YMM8  ymm26
>
> -# define VEC_SIZE      32
>
> -       .section .text.evex,"ax",@progbits
> -ENTRY (STRRCHR)
> -       movl    %edi, %ecx
> +# define VEC_SIZE      32
> +# define PAGE_SIZE     4096
> +       .section .text.evex, "ax", @progbits
> +ENTRY(STRRCHR)
> +       movl    %edi, %eax
>         /* Broadcast CHAR to YMMMATCH.  */
>         VPBROADCAST %esi, %YMMMATCH
>
> -       vpxorq  %XMMZERO, %XMMZERO, %XMMZERO
> -
> -       /* Check if we may cross page boundary with one vector load.  */
> -       andl    $(2 * VEC_SIZE - 1), %ecx
> -       cmpl    $VEC_SIZE, %ecx
> -       ja      L(cros_page_boundary)
> +       andl    $(PAGE_SIZE - 1), %eax
> +       cmpl    $(PAGE_SIZE - VEC_SIZE), %eax
> +       jg      L(cross_page_boundary)
>
> +L(page_cross_continue):
>         VMOVU   (%rdi), %YMM1
> -
> -       /* Each bit in K0 represents a null byte in YMM1.  */
> -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> -       /* Each bit in K1 represents a CHAR in YMM1.  */
> -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> +       /* k0 has a 1 for each zero CHAR in YMM1.  */
> +       VPTESTN %YMM1, %YMM1, %k0
>         kmovd   %k0, %ecx
> -       kmovd   %k1, %eax
> -
> -       addq    $VEC_SIZE, %rdi
> -
> -       testl   %eax, %eax
> -       jnz     L(first_vec)
> -
>         testl   %ecx, %ecx
> -       jnz     L(return_null)
> -
> -       andq    $-VEC_SIZE, %rdi
> -       xorl    %edx, %edx
> -       jmp     L(aligned_loop)
> -
> -       .p2align 4
> -L(first_vec):
> -       /* Check if there is a null byte.  */
> -       testl   %ecx, %ecx
> -       jnz     L(char_and_nul_in_first_vec)
> -
> -       /* Remember the match and keep searching.  */
> -       movl    %eax, %edx
> -       movq    %rdi, %rsi
> -       andq    $-VEC_SIZE, %rdi
> -       jmp     L(aligned_loop)
> -
> -       .p2align 4
> -L(cros_page_boundary):
> -       andl    $(VEC_SIZE - 1), %ecx
> -       andq    $-VEC_SIZE, %rdi
> +       jz      L(aligned_more)
> +       /* fallthrough: zero CHAR in first VEC.  */
>
> +       /* K1 has a 1 for each search CHAR match in YMM1.  */
> +       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> +       kmovd   %k1, %eax
> +       /* Build mask up until first zero CHAR (used to mask of
> +          potential search CHAR matches past the end of the string).
> +        */
> +       blsmskl %ecx, %ecx
> +       andl    %ecx, %eax
> +       jz      L(ret0)
> +       /* Get last match (the `andl` removed any out of bounds
> +          matches).  */
> +       bsrl    %eax, %eax
>  # ifdef USE_AS_WCSRCHR
> -       /* NB: Divide shift count by 4 since each bit in K1 represent 4
> -          bytes.  */
> -       movl    %ecx, %SHIFT_REG
> -       sarl    $2, %SHIFT_REG
> +       leaq    (%rdi, %rax, CHAR_SIZE), %rax
> +# else
> +       addq    %rdi, %rax
>  # endif
> +L(ret0):
> +       ret
>
> -       VMOVA   (%rdi), %YMM1
> -
> -       /* Each bit in K0 represents a null byte in YMM1.  */
> -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> -       /* Each bit in K1 represents a CHAR in YMM1.  */
> +       /* Returns for first vec x1/x2/x3 have hard coded backward
> +          search path for earlier matches.  */
> +       .p2align 4,, 6
> +L(first_vec_x1):
> +       VPCMP   $0, %YMMMATCH, %YMM2, %k1
> +       kmovd   %k1, %eax
> +       blsmskl %ecx, %ecx
> +       /* eax non-zero if search CHAR in range.  */
> +       andl    %ecx, %eax
> +       jnz     L(first_vec_x1_return)
> +
> +       /* fallthrough: no match in YMM2 then need to check for earlier
> +          matches (in YMM1).  */
> +       .p2align 4,, 4
> +L(first_vec_x0_test):
>         VPCMP   $0, %YMMMATCH, %YMM1, %k1
> -       kmovd   %k0, %edx
>         kmovd   %k1, %eax
> -
> -       shrxl   %SHIFT_REG, %edx, %edx
> -       shrxl   %SHIFT_REG, %eax, %eax
> -       addq    $VEC_SIZE, %rdi
> -
> -       /* Check if there is a CHAR.  */
>         testl   %eax, %eax
> -       jnz     L(found_char)
> -
> -       testl   %edx, %edx
> -       jnz     L(return_null)
> -
> -       jmp     L(aligned_loop)
> -
> -       .p2align 4
> -L(found_char):
> -       testl   %edx, %edx
> -       jnz     L(char_and_nul)
> -
> -       /* Remember the match and keep searching.  */
> -       movl    %eax, %edx
> -       leaq    (%rdi, %rcx), %rsi
> +       jz      L(ret1)
> +       bsrl    %eax, %eax
> +# ifdef USE_AS_WCSRCHR
> +       leaq    (%rsi, %rax, CHAR_SIZE), %rax
> +# else
> +       addq    %rsi, %rax
> +# endif
> +L(ret1):
> +       ret
>
> -       .p2align 4
> -L(aligned_loop):
> -       VMOVA   (%rdi), %YMM1
> -       addq    $VEC_SIZE, %rdi
> +       .p2align 4,, 10
> +L(first_vec_x1_or_x2):
> +       VPCMP   $0, %YMM3, %YMMMATCH, %k3
> +       VPCMP   $0, %YMM2, %YMMMATCH, %k2
> +       /* K2 and K3 have 1 for any search CHAR match. Test if any
> +          matches between either of them. Otherwise check YMM1.  */
> +       kortestd %k2, %k3
> +       jz      L(first_vec_x0_test)
> +
> +       /* Guranteed that YMM2 and YMM3 are within range so merge the
> +          two bitmasks then get last result.  */
> +       kunpck  %k2, %k3, %k3
> +       kmovq   %k3, %rax
> +       bsrq    %rax, %rax
> +       leaq    (VEC_SIZE)(%r8, %rax, CHAR_SIZE), %rax
> +       ret
>
> -       /* Each bit in K0 represents a null byte in YMM1.  */
> -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> -       /* Each bit in K1 represents a CHAR in YMM1.  */
> -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> -       kmovd   %k0, %ecx
> +       .p2align 4,, 6
> +L(first_vec_x3):
> +       VPCMP   $0, %YMMMATCH, %YMM4, %k1
>         kmovd   %k1, %eax
> -       orl     %eax, %ecx
> -       jnz     L(char_nor_null)
> +       blsmskl %ecx, %ecx
> +       /* If no search CHAR match in range check YMM1/YMM2/YMM3.  */
> +       andl    %ecx, %eax
> +       jz      L(first_vec_x1_or_x2)
> +       bsrl    %eax, %eax
> +       leaq    (VEC_SIZE * 3)(%rdi, %rax, CHAR_SIZE), %rax
> +       ret
>
> -       VMOVA   (%rdi), %YMM1
> -       add     $VEC_SIZE, %rdi
> +       .p2align 4,, 6
> +L(first_vec_x0_x1_test):
> +       VPCMP   $0, %YMMMATCH, %YMM2, %k1
> +       kmovd   %k1, %eax
> +       /* Check YMM2 for last match first. If no match try YMM1.  */
> +       testl   %eax, %eax
> +       jz      L(first_vec_x0_test)
> +       .p2align 4,, 4
> +L(first_vec_x1_return):
> +       bsrl    %eax, %eax
> +       leaq    (VEC_SIZE)(%rdi, %rax, CHAR_SIZE), %rax
> +       ret
>
> -       /* Each bit in K0 represents a null byte in YMM1.  */
> -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> -       /* Each bit in K1 represents a CHAR in YMM1.  */
> -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> -       kmovd   %k0, %ecx
> +       .p2align 4,, 10
> +L(first_vec_x2):
> +       VPCMP   $0, %YMMMATCH, %YMM3, %k1
>         kmovd   %k1, %eax
> -       orl     %eax, %ecx
> -       jnz     L(char_nor_null)
> +       blsmskl %ecx, %ecx
> +       /* Check YMM3 for last match first. If no match try YMM2/YMM1.
> +        */
> +       andl    %ecx, %eax
> +       jz      L(first_vec_x0_x1_test)
> +       bsrl    %eax, %eax
> +       leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
> +       ret
>
> -       VMOVA   (%rdi), %YMM1
> -       addq    $VEC_SIZE, %rdi
>
> -       /* Each bit in K0 represents a null byte in YMM1.  */
> -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> -       /* Each bit in K1 represents a CHAR in YMM1.  */
> -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> +       .p2align 4
> +L(aligned_more):
> +       /* Need to keep original pointer incase YMM1 has last match.  */
> +       movq    %rdi, %rsi
> +       andq    $-VEC_SIZE, %rdi
> +       VMOVU   VEC_SIZE(%rdi), %YMM2
> +       VPTESTN %YMM2, %YMM2, %k0
>         kmovd   %k0, %ecx
> -       kmovd   %k1, %eax
> -       orl     %eax, %ecx
> -       jnz     L(char_nor_null)
> +       testl   %ecx, %ecx
> +       jnz     L(first_vec_x1)
>
> -       VMOVA   (%rdi), %YMM1
> -       addq    $VEC_SIZE, %rdi
> +       VMOVU   (VEC_SIZE * 2)(%rdi), %YMM3
> +       VPTESTN %YMM3, %YMM3, %k0
> +       kmovd   %k0, %ecx
> +       testl   %ecx, %ecx
> +       jnz     L(first_vec_x2)
>
> -       /* Each bit in K0 represents a null byte in YMM1.  */
> -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> -       /* Each bit in K1 represents a CHAR in YMM1.  */
> -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> +       VMOVU   (VEC_SIZE * 3)(%rdi), %YMM4
> +       VPTESTN %YMM4, %YMM4, %k0
>         kmovd   %k0, %ecx
> -       kmovd   %k1, %eax
> -       orl     %eax, %ecx
> -       jz      L(aligned_loop)
> +       movq    %rdi, %r8
> +       testl   %ecx, %ecx
> +       jnz     L(first_vec_x3)
>
> +       andq    $-(VEC_SIZE * 2), %rdi
>         .p2align 4
> -L(char_nor_null):
> -       /* Find a CHAR or a null byte in a loop.  */
> +L(first_aligned_loop):
> +       /* Preserve YMM1, YMM2, YMM3, and YMM4 until we can gurantee
> +          they don't store a match.  */
> +       VMOVA   (VEC_SIZE * 4)(%rdi), %YMM5
> +       VMOVA   (VEC_SIZE * 5)(%rdi), %YMM6
> +
> +       VPCMP   $0, %YMM5, %YMMMATCH, %k2
> +       vpxord  %YMM6, %YMMMATCH, %YMM7
> +
> +       VPMIN   %YMM5, %YMM6, %YMM8
> +       VPMIN   %YMM8, %YMM7, %YMM7
> +
> +       VPTESTN %YMM7, %YMM7, %k1
> +       subq    $(VEC_SIZE * -2), %rdi
> +       kortestd %k1, %k2
> +       jz      L(first_aligned_loop)
> +
> +       VPCMP   $0, %YMM6, %YMMMATCH, %k3
> +       VPTESTN %YMM8, %YMM8, %k1
> +       ktestd  %k1, %k1
> +       jz      L(second_aligned_loop_prep)
> +
> +       kortestd %k2, %k3
> +       jnz     L(return_first_aligned_loop)
> +
> +       .p2align 4,, 6
> +L(first_vec_x1_or_x2_or_x3):
> +       VPCMP   $0, %YMM4, %YMMMATCH, %k4
> +       kmovd   %k4, %eax
>         testl   %eax, %eax
> -       jnz     L(match)
> -L(return_value):
> -       testl   %edx, %edx
> -       jz      L(return_null)
> -       movl    %edx, %eax
> -       movq    %rsi, %rdi
> +       jz      L(first_vec_x1_or_x2)
>         bsrl    %eax, %eax
> -# ifdef USE_AS_WCSRCHR
> -       /* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
> -       leaq    -VEC_SIZE(%rdi, %rax, 4), %rax
> -# else
> -       leaq    -VEC_SIZE(%rdi, %rax), %rax
> -# endif
> +       leaq    (VEC_SIZE * 3)(%r8, %rax, CHAR_SIZE), %rax
>         ret
>
> -       .p2align 4
> -L(match):
> -       /* Find a CHAR.  Check if there is a null byte.  */
> -       kmovd   %k0, %ecx
> -       testl   %ecx, %ecx
> -       jnz     L(find_nul)
> +       .p2align 4,, 8
> +L(return_first_aligned_loop):
> +       VPTESTN %YMM5, %YMM5, %k0
> +       kunpck  %k0, %k1, %k0
> +       kmov_2x %k0, %maskz_2x
> +
> +       blsmsk  %maskz_2x, %maskz_2x
> +       kunpck  %k2, %k3, %k3
> +       kmov_2x %k3, %maskm_2x
> +       and     %maskz_2x, %maskm_2x
> +       jz      L(first_vec_x1_or_x2_or_x3)
>
> -       /* Remember the match and keep searching.  */
> -       movl    %eax, %edx
> +       bsr     %maskm_2x, %maskm_2x
> +       leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
> +       ret
> +
> +       .p2align 4
> +       /* We can throw away the work done for the first 4x checks here
> +          as we have a later match. This is the 'fast' path persay.
> +        */
> +L(second_aligned_loop_prep):
> +L(second_aligned_loop_set_furthest_match):
>         movq    %rdi, %rsi
> -       jmp     L(aligned_loop)
> +       kunpck  %k2, %k3, %k4
>
>         .p2align 4
> -L(find_nul):
> -       /* Mask out any matching bits after the null byte.  */
> -       movl    %ecx, %r8d
> -       subl    $1, %r8d
> -       xorl    %ecx, %r8d
> -       andl    %r8d, %eax
> -       testl   %eax, %eax
> -       /* If there is no CHAR here, return the remembered one.  */
> -       jz      L(return_value)
> -       bsrl    %eax, %eax
> +L(second_aligned_loop):
> +       VMOVU   (VEC_SIZE * 4)(%rdi), %YMM1
> +       VMOVU   (VEC_SIZE * 5)(%rdi), %YMM2
> +
> +       VPCMP   $0, %YMM1, %YMMMATCH, %k2
> +       vpxord  %YMM2, %YMMMATCH, %YMM3
> +
> +       VPMIN   %YMM1, %YMM2, %YMM4
> +       VPMIN   %YMM3, %YMM4, %YMM3
> +
> +       VPTESTN %YMM3, %YMM3, %k1
> +       subq    $(VEC_SIZE * -2), %rdi
> +       kortestd %k1, %k2
> +       jz      L(second_aligned_loop)
> +
> +       VPCMP   $0, %YMM2, %YMMMATCH, %k3
> +       VPTESTN %YMM4, %YMM4, %k1
> +       ktestd  %k1, %k1
> +       jz      L(second_aligned_loop_set_furthest_match)
> +
> +       kortestd %k2, %k3
> +       /* branch here because there is a significant advantage interms
> +          of output dependency chance in using edx.  */
> +       jnz     L(return_new_match)
> +L(return_old_match):
> +       kmovq   %k4, %rax
> +       bsrq    %rax, %rax
> +       leaq    (VEC_SIZE * 2)(%rsi, %rax, CHAR_SIZE), %rax
> +       ret
> +
> +L(return_new_match):
> +       VPTESTN %YMM1, %YMM1, %k0
> +       kunpck  %k0, %k1, %k0
> +       kmov_2x %k0, %maskz_2x
> +
> +       blsmsk  %maskz_2x, %maskz_2x
> +       kunpck  %k2, %k3, %k3
> +       kmov_2x %k3, %maskm_2x
> +       and     %maskz_2x, %maskm_2x
> +       jz      L(return_old_match)
> +
> +       bsr     %maskm_2x, %maskm_2x
> +       leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
> +       ret
> +
> +L(cross_page_boundary):
> +       /* eax contains all the page offset bits of src (rdi). `xor rdi,
> +          rax` sets pointer will all page offset bits cleared so
> +          offset of (PAGE_SIZE - VEC_SIZE) will get last aligned VEC
> +          before page cross (guranteed to be safe to read). Doing this
> +          as opposed to `movq %rdi, %rax; andq $-VEC_SIZE, %rax` saves
> +          a bit of code size.  */
> +       xorq    %rdi, %rax
> +       VMOVU   (PAGE_SIZE - VEC_SIZE)(%rax), %YMM1
> +       VPTESTN %YMM1, %YMM1, %k0
> +       kmovd   %k0, %ecx
> +
> +       /* Shift out zero CHAR matches that are before the begining of
> +          src (rdi).  */
>  # ifdef USE_AS_WCSRCHR
> -       /* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
> -       leaq    -VEC_SIZE(%rdi, %rax, 4), %rax
> -# else
> -       leaq    -VEC_SIZE(%rdi, %rax), %rax
> +       movl    %edi, %esi
> +       andl    $(VEC_SIZE - 1), %esi
> +       shrl    $2, %esi
>  # endif
> -       ret
> +       shrxl   %SHIFT_REG, %ecx, %ecx
>
> -       .p2align 4
> -L(char_and_nul):
> -       /* Find both a CHAR and a null byte.  */
> -       addq    %rcx, %rdi
> -       movl    %edx, %ecx
> -L(char_and_nul_in_first_vec):
> -       /* Mask out any matching bits after the null byte.  */
> -       movl    %ecx, %r8d
> -       subl    $1, %r8d
> -       xorl    %ecx, %r8d
> -       andl    %r8d, %eax
> -       testl   %eax, %eax
> -       /* Return null pointer if the null byte comes first.  */
> -       jz      L(return_null)
> +       testl   %ecx, %ecx
> +       jz      L(page_cross_continue)
> +
> +       /* Found zero CHAR so need to test for search CHAR.  */
> +       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> +       kmovd   %k1, %eax
> +       /* Shift out search CHAR matches that are before the begining of
> +          src (rdi).  */
> +       shrxl   %SHIFT_REG, %eax, %eax
> +
> +       /* Check if any search CHAR match in range.  */
> +       blsmskl %ecx, %ecx
> +       andl    %ecx, %eax
> +       jz      L(ret3)
>         bsrl    %eax, %eax
>  # ifdef USE_AS_WCSRCHR
> -       /* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
> -       leaq    -VEC_SIZE(%rdi, %rax, 4), %rax
> +       leaq    (%rdi, %rax, CHAR_SIZE), %rax
>  # else
> -       leaq    -VEC_SIZE(%rdi, %rax), %rax
> +       addq    %rdi, %rax
>  # endif
> +L(ret3):
>         ret
>
> -       .p2align 4
> -L(return_null):
> -       xorl    %eax, %eax
> -       ret
> -
> -END (STRRCHR)
> +END(STRRCHR)
>  #endif
> --
> 2.25.1
>

LGTM.

Reviewed-by: H.J. Lu <hjl.tools@gmail.com>

Thanks.
Sunil Pandey May 12, 2022, 8:16 p.m. UTC | #2
On Fri, Apr 22, 2022 at 12:08 PM H.J. Lu via Libc-alpha
<libc-alpha@sourceware.org> wrote:
>
> On Thu, Apr 21, 2022 at 6:52 PM Noah Goldstein <goldstein.w.n@gmail.com> wrote:
> >
> > The new code unrolls the main loop slightly without adding too much
> > overhead and minimizes the comparisons for the search CHAR.
> >
> > Geometric Mean of all benchmarks New / Old: 0.755
> > See email for all results.
> >
> > Full xcheck passes on x86_64 with and without multiarch enabled.
> > ---
> >  sysdeps/x86_64/multiarch/strrchr-evex.S | 471 +++++++++++++++---------
> >  1 file changed, 290 insertions(+), 181 deletions(-)
> >
> > diff --git a/sysdeps/x86_64/multiarch/strrchr-evex.S b/sysdeps/x86_64/multiarch/strrchr-evex.S
> > index adeddaed32..8014c285b3 100644
> > --- a/sysdeps/x86_64/multiarch/strrchr-evex.S
> > +++ b/sysdeps/x86_64/multiarch/strrchr-evex.S
> > @@ -24,242 +24,351 @@
> >  #  define STRRCHR      __strrchr_evex
> >  # endif
> >
> > -# define VMOVU         vmovdqu64
> > -# define VMOVA         vmovdqa64
> > +# define VMOVU vmovdqu64
> > +# define VMOVA vmovdqa64
> >
> >  # ifdef USE_AS_WCSRCHR
> > +#  define SHIFT_REG    esi
> > +
> > +#  define kunpck       kunpckbw
> > +#  define kmov_2x      kmovd
> > +#  define maskz_2x     ecx
> > +#  define maskm_2x     eax
> > +#  define CHAR_SIZE    4
> > +#  define VPMIN        vpminud
> > +#  define VPTESTN      vptestnmd
> >  #  define VPBROADCAST  vpbroadcastd
> > -#  define VPCMP                vpcmpd
> > -#  define SHIFT_REG    r8d
> > +#  define VPCMP        vpcmpd
> >  # else
> > +#  define SHIFT_REG    edi
> > +
> > +#  define kunpck       kunpckdq
> > +#  define kmov_2x      kmovq
> > +#  define maskz_2x     rcx
> > +#  define maskm_2x     rax
> > +
> > +#  define CHAR_SIZE    1
> > +#  define VPMIN        vpminub
> > +#  define VPTESTN      vptestnmb
> >  #  define VPBROADCAST  vpbroadcastb
> > -#  define VPCMP                vpcmpb
> > -#  define SHIFT_REG    ecx
> > +#  define VPCMP        vpcmpb
> >  # endif
> >
> >  # define XMMZERO       xmm16
> >  # define YMMZERO       ymm16
> >  # define YMMMATCH      ymm17
> > -# define YMM1          ymm18
> > +# define YMMSAVE       ymm18
> > +
> > +# define YMM1  ymm19
> > +# define YMM2  ymm20
> > +# define YMM3  ymm21
> > +# define YMM4  ymm22
> > +# define YMM5  ymm23
> > +# define YMM6  ymm24
> > +# define YMM7  ymm25
> > +# define YMM8  ymm26
> >
> > -# define VEC_SIZE      32
> >
> > -       .section .text.evex,"ax",@progbits
> > -ENTRY (STRRCHR)
> > -       movl    %edi, %ecx
> > +# define VEC_SIZE      32
> > +# define PAGE_SIZE     4096
> > +       .section .text.evex, "ax", @progbits
> > +ENTRY(STRRCHR)
> > +       movl    %edi, %eax
> >         /* Broadcast CHAR to YMMMATCH.  */
> >         VPBROADCAST %esi, %YMMMATCH
> >
> > -       vpxorq  %XMMZERO, %XMMZERO, %XMMZERO
> > -
> > -       /* Check if we may cross page boundary with one vector load.  */
> > -       andl    $(2 * VEC_SIZE - 1), %ecx
> > -       cmpl    $VEC_SIZE, %ecx
> > -       ja      L(cros_page_boundary)
> > +       andl    $(PAGE_SIZE - 1), %eax
> > +       cmpl    $(PAGE_SIZE - VEC_SIZE), %eax
> > +       jg      L(cross_page_boundary)
> >
> > +L(page_cross_continue):
> >         VMOVU   (%rdi), %YMM1
> > -
> > -       /* Each bit in K0 represents a null byte in YMM1.  */
> > -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> > -       /* Each bit in K1 represents a CHAR in YMM1.  */
> > -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > +       /* k0 has a 1 for each zero CHAR in YMM1.  */
> > +       VPTESTN %YMM1, %YMM1, %k0
> >         kmovd   %k0, %ecx
> > -       kmovd   %k1, %eax
> > -
> > -       addq    $VEC_SIZE, %rdi
> > -
> > -       testl   %eax, %eax
> > -       jnz     L(first_vec)
> > -
> >         testl   %ecx, %ecx
> > -       jnz     L(return_null)
> > -
> > -       andq    $-VEC_SIZE, %rdi
> > -       xorl    %edx, %edx
> > -       jmp     L(aligned_loop)
> > -
> > -       .p2align 4
> > -L(first_vec):
> > -       /* Check if there is a null byte.  */
> > -       testl   %ecx, %ecx
> > -       jnz     L(char_and_nul_in_first_vec)
> > -
> > -       /* Remember the match and keep searching.  */
> > -       movl    %eax, %edx
> > -       movq    %rdi, %rsi
> > -       andq    $-VEC_SIZE, %rdi
> > -       jmp     L(aligned_loop)
> > -
> > -       .p2align 4
> > -L(cros_page_boundary):
> > -       andl    $(VEC_SIZE - 1), %ecx
> > -       andq    $-VEC_SIZE, %rdi
> > +       jz      L(aligned_more)
> > +       /* fallthrough: zero CHAR in first VEC.  */
> >
> > +       /* K1 has a 1 for each search CHAR match in YMM1.  */
> > +       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > +       kmovd   %k1, %eax
> > +       /* Build mask up until first zero CHAR (used to mask of
> > +          potential search CHAR matches past the end of the string).
> > +        */
> > +       blsmskl %ecx, %ecx
> > +       andl    %ecx, %eax
> > +       jz      L(ret0)
> > +       /* Get last match (the `andl` removed any out of bounds
> > +          matches).  */
> > +       bsrl    %eax, %eax
> >  # ifdef USE_AS_WCSRCHR
> > -       /* NB: Divide shift count by 4 since each bit in K1 represent 4
> > -          bytes.  */
> > -       movl    %ecx, %SHIFT_REG
> > -       sarl    $2, %SHIFT_REG
> > +       leaq    (%rdi, %rax, CHAR_SIZE), %rax
> > +# else
> > +       addq    %rdi, %rax
> >  # endif
> > +L(ret0):
> > +       ret
> >
> > -       VMOVA   (%rdi), %YMM1
> > -
> > -       /* Each bit in K0 represents a null byte in YMM1.  */
> > -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> > -       /* Each bit in K1 represents a CHAR in YMM1.  */
> > +       /* Returns for first vec x1/x2/x3 have hard coded backward
> > +          search path for earlier matches.  */
> > +       .p2align 4,, 6
> > +L(first_vec_x1):
> > +       VPCMP   $0, %YMMMATCH, %YMM2, %k1
> > +       kmovd   %k1, %eax
> > +       blsmskl %ecx, %ecx
> > +       /* eax non-zero if search CHAR in range.  */
> > +       andl    %ecx, %eax
> > +       jnz     L(first_vec_x1_return)
> > +
> > +       /* fallthrough: no match in YMM2 then need to check for earlier
> > +          matches (in YMM1).  */
> > +       .p2align 4,, 4
> > +L(first_vec_x0_test):
> >         VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > -       kmovd   %k0, %edx
> >         kmovd   %k1, %eax
> > -
> > -       shrxl   %SHIFT_REG, %edx, %edx
> > -       shrxl   %SHIFT_REG, %eax, %eax
> > -       addq    $VEC_SIZE, %rdi
> > -
> > -       /* Check if there is a CHAR.  */
> >         testl   %eax, %eax
> > -       jnz     L(found_char)
> > -
> > -       testl   %edx, %edx
> > -       jnz     L(return_null)
> > -
> > -       jmp     L(aligned_loop)
> > -
> > -       .p2align 4
> > -L(found_char):
> > -       testl   %edx, %edx
> > -       jnz     L(char_and_nul)
> > -
> > -       /* Remember the match and keep searching.  */
> > -       movl    %eax, %edx
> > -       leaq    (%rdi, %rcx), %rsi
> > +       jz      L(ret1)
> > +       bsrl    %eax, %eax
> > +# ifdef USE_AS_WCSRCHR
> > +       leaq    (%rsi, %rax, CHAR_SIZE), %rax
> > +# else
> > +       addq    %rsi, %rax
> > +# endif
> > +L(ret1):
> > +       ret
> >
> > -       .p2align 4
> > -L(aligned_loop):
> > -       VMOVA   (%rdi), %YMM1
> > -       addq    $VEC_SIZE, %rdi
> > +       .p2align 4,, 10
> > +L(first_vec_x1_or_x2):
> > +       VPCMP   $0, %YMM3, %YMMMATCH, %k3
> > +       VPCMP   $0, %YMM2, %YMMMATCH, %k2
> > +       /* K2 and K3 have 1 for any search CHAR match. Test if any
> > +          matches between either of them. Otherwise check YMM1.  */
> > +       kortestd %k2, %k3
> > +       jz      L(first_vec_x0_test)
> > +
> > +       /* Guranteed that YMM2 and YMM3 are within range so merge the
> > +          two bitmasks then get last result.  */
> > +       kunpck  %k2, %k3, %k3
> > +       kmovq   %k3, %rax
> > +       bsrq    %rax, %rax
> > +       leaq    (VEC_SIZE)(%r8, %rax, CHAR_SIZE), %rax
> > +       ret
> >
> > -       /* Each bit in K0 represents a null byte in YMM1.  */
> > -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> > -       /* Each bit in K1 represents a CHAR in YMM1.  */
> > -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > -       kmovd   %k0, %ecx
> > +       .p2align 4,, 6
> > +L(first_vec_x3):
> > +       VPCMP   $0, %YMMMATCH, %YMM4, %k1
> >         kmovd   %k1, %eax
> > -       orl     %eax, %ecx
> > -       jnz     L(char_nor_null)
> > +       blsmskl %ecx, %ecx
> > +       /* If no search CHAR match in range check YMM1/YMM2/YMM3.  */
> > +       andl    %ecx, %eax
> > +       jz      L(first_vec_x1_or_x2)
> > +       bsrl    %eax, %eax
> > +       leaq    (VEC_SIZE * 3)(%rdi, %rax, CHAR_SIZE), %rax
> > +       ret
> >
> > -       VMOVA   (%rdi), %YMM1
> > -       add     $VEC_SIZE, %rdi
> > +       .p2align 4,, 6
> > +L(first_vec_x0_x1_test):
> > +       VPCMP   $0, %YMMMATCH, %YMM2, %k1
> > +       kmovd   %k1, %eax
> > +       /* Check YMM2 for last match first. If no match try YMM1.  */
> > +       testl   %eax, %eax
> > +       jz      L(first_vec_x0_test)
> > +       .p2align 4,, 4
> > +L(first_vec_x1_return):
> > +       bsrl    %eax, %eax
> > +       leaq    (VEC_SIZE)(%rdi, %rax, CHAR_SIZE), %rax
> > +       ret
> >
> > -       /* Each bit in K0 represents a null byte in YMM1.  */
> > -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> > -       /* Each bit in K1 represents a CHAR in YMM1.  */
> > -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > -       kmovd   %k0, %ecx
> > +       .p2align 4,, 10
> > +L(first_vec_x2):
> > +       VPCMP   $0, %YMMMATCH, %YMM3, %k1
> >         kmovd   %k1, %eax
> > -       orl     %eax, %ecx
> > -       jnz     L(char_nor_null)
> > +       blsmskl %ecx, %ecx
> > +       /* Check YMM3 for last match first. If no match try YMM2/YMM1.
> > +        */
> > +       andl    %ecx, %eax
> > +       jz      L(first_vec_x0_x1_test)
> > +       bsrl    %eax, %eax
> > +       leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
> > +       ret
> >
> > -       VMOVA   (%rdi), %YMM1
> > -       addq    $VEC_SIZE, %rdi
> >
> > -       /* Each bit in K0 represents a null byte in YMM1.  */
> > -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> > -       /* Each bit in K1 represents a CHAR in YMM1.  */
> > -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > +       .p2align 4
> > +L(aligned_more):
> > +       /* Need to keep original pointer incase YMM1 has last match.  */
> > +       movq    %rdi, %rsi
> > +       andq    $-VEC_SIZE, %rdi
> > +       VMOVU   VEC_SIZE(%rdi), %YMM2
> > +       VPTESTN %YMM2, %YMM2, %k0
> >         kmovd   %k0, %ecx
> > -       kmovd   %k1, %eax
> > -       orl     %eax, %ecx
> > -       jnz     L(char_nor_null)
> > +       testl   %ecx, %ecx
> > +       jnz     L(first_vec_x1)
> >
> > -       VMOVA   (%rdi), %YMM1
> > -       addq    $VEC_SIZE, %rdi
> > +       VMOVU   (VEC_SIZE * 2)(%rdi), %YMM3
> > +       VPTESTN %YMM3, %YMM3, %k0
> > +       kmovd   %k0, %ecx
> > +       testl   %ecx, %ecx
> > +       jnz     L(first_vec_x2)
> >
> > -       /* Each bit in K0 represents a null byte in YMM1.  */
> > -       VPCMP   $0, %YMMZERO, %YMM1, %k0
> > -       /* Each bit in K1 represents a CHAR in YMM1.  */
> > -       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > +       VMOVU   (VEC_SIZE * 3)(%rdi), %YMM4
> > +       VPTESTN %YMM4, %YMM4, %k0
> >         kmovd   %k0, %ecx
> > -       kmovd   %k1, %eax
> > -       orl     %eax, %ecx
> > -       jz      L(aligned_loop)
> > +       movq    %rdi, %r8
> > +       testl   %ecx, %ecx
> > +       jnz     L(first_vec_x3)
> >
> > +       andq    $-(VEC_SIZE * 2), %rdi
> >         .p2align 4
> > -L(char_nor_null):
> > -       /* Find a CHAR or a null byte in a loop.  */
> > +L(first_aligned_loop):
> > +       /* Preserve YMM1, YMM2, YMM3, and YMM4 until we can gurantee
> > +          they don't store a match.  */
> > +       VMOVA   (VEC_SIZE * 4)(%rdi), %YMM5
> > +       VMOVA   (VEC_SIZE * 5)(%rdi), %YMM6
> > +
> > +       VPCMP   $0, %YMM5, %YMMMATCH, %k2
> > +       vpxord  %YMM6, %YMMMATCH, %YMM7
> > +
> > +       VPMIN   %YMM5, %YMM6, %YMM8
> > +       VPMIN   %YMM8, %YMM7, %YMM7
> > +
> > +       VPTESTN %YMM7, %YMM7, %k1
> > +       subq    $(VEC_SIZE * -2), %rdi
> > +       kortestd %k1, %k2
> > +       jz      L(first_aligned_loop)
> > +
> > +       VPCMP   $0, %YMM6, %YMMMATCH, %k3
> > +       VPTESTN %YMM8, %YMM8, %k1
> > +       ktestd  %k1, %k1
> > +       jz      L(second_aligned_loop_prep)
> > +
> > +       kortestd %k2, %k3
> > +       jnz     L(return_first_aligned_loop)
> > +
> > +       .p2align 4,, 6
> > +L(first_vec_x1_or_x2_or_x3):
> > +       VPCMP   $0, %YMM4, %YMMMATCH, %k4
> > +       kmovd   %k4, %eax
> >         testl   %eax, %eax
> > -       jnz     L(match)
> > -L(return_value):
> > -       testl   %edx, %edx
> > -       jz      L(return_null)
> > -       movl    %edx, %eax
> > -       movq    %rsi, %rdi
> > +       jz      L(first_vec_x1_or_x2)
> >         bsrl    %eax, %eax
> > -# ifdef USE_AS_WCSRCHR
> > -       /* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
> > -       leaq    -VEC_SIZE(%rdi, %rax, 4), %rax
> > -# else
> > -       leaq    -VEC_SIZE(%rdi, %rax), %rax
> > -# endif
> > +       leaq    (VEC_SIZE * 3)(%r8, %rax, CHAR_SIZE), %rax
> >         ret
> >
> > -       .p2align 4
> > -L(match):
> > -       /* Find a CHAR.  Check if there is a null byte.  */
> > -       kmovd   %k0, %ecx
> > -       testl   %ecx, %ecx
> > -       jnz     L(find_nul)
> > +       .p2align 4,, 8
> > +L(return_first_aligned_loop):
> > +       VPTESTN %YMM5, %YMM5, %k0
> > +       kunpck  %k0, %k1, %k0
> > +       kmov_2x %k0, %maskz_2x
> > +
> > +       blsmsk  %maskz_2x, %maskz_2x
> > +       kunpck  %k2, %k3, %k3
> > +       kmov_2x %k3, %maskm_2x
> > +       and     %maskz_2x, %maskm_2x
> > +       jz      L(first_vec_x1_or_x2_or_x3)
> >
> > -       /* Remember the match and keep searching.  */
> > -       movl    %eax, %edx
> > +       bsr     %maskm_2x, %maskm_2x
> > +       leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
> > +       ret
> > +
> > +       .p2align 4
> > +       /* We can throw away the work done for the first 4x checks here
> > +          as we have a later match. This is the 'fast' path persay.
> > +        */
> > +L(second_aligned_loop_prep):
> > +L(second_aligned_loop_set_furthest_match):
> >         movq    %rdi, %rsi
> > -       jmp     L(aligned_loop)
> > +       kunpck  %k2, %k3, %k4
> >
> >         .p2align 4
> > -L(find_nul):
> > -       /* Mask out any matching bits after the null byte.  */
> > -       movl    %ecx, %r8d
> > -       subl    $1, %r8d
> > -       xorl    %ecx, %r8d
> > -       andl    %r8d, %eax
> > -       testl   %eax, %eax
> > -       /* If there is no CHAR here, return the remembered one.  */
> > -       jz      L(return_value)
> > -       bsrl    %eax, %eax
> > +L(second_aligned_loop):
> > +       VMOVU   (VEC_SIZE * 4)(%rdi), %YMM1
> > +       VMOVU   (VEC_SIZE * 5)(%rdi), %YMM2
> > +
> > +       VPCMP   $0, %YMM1, %YMMMATCH, %k2
> > +       vpxord  %YMM2, %YMMMATCH, %YMM3
> > +
> > +       VPMIN   %YMM1, %YMM2, %YMM4
> > +       VPMIN   %YMM3, %YMM4, %YMM3
> > +
> > +       VPTESTN %YMM3, %YMM3, %k1
> > +       subq    $(VEC_SIZE * -2), %rdi
> > +       kortestd %k1, %k2
> > +       jz      L(second_aligned_loop)
> > +
> > +       VPCMP   $0, %YMM2, %YMMMATCH, %k3
> > +       VPTESTN %YMM4, %YMM4, %k1
> > +       ktestd  %k1, %k1
> > +       jz      L(second_aligned_loop_set_furthest_match)
> > +
> > +       kortestd %k2, %k3
> > +       /* branch here because there is a significant advantage interms
> > +          of output dependency chance in using edx.  */
> > +       jnz     L(return_new_match)
> > +L(return_old_match):
> > +       kmovq   %k4, %rax
> > +       bsrq    %rax, %rax
> > +       leaq    (VEC_SIZE * 2)(%rsi, %rax, CHAR_SIZE), %rax
> > +       ret
> > +
> > +L(return_new_match):
> > +       VPTESTN %YMM1, %YMM1, %k0
> > +       kunpck  %k0, %k1, %k0
> > +       kmov_2x %k0, %maskz_2x
> > +
> > +       blsmsk  %maskz_2x, %maskz_2x
> > +       kunpck  %k2, %k3, %k3
> > +       kmov_2x %k3, %maskm_2x
> > +       and     %maskz_2x, %maskm_2x
> > +       jz      L(return_old_match)
> > +
> > +       bsr     %maskm_2x, %maskm_2x
> > +       leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
> > +       ret
> > +
> > +L(cross_page_boundary):
> > +       /* eax contains all the page offset bits of src (rdi). `xor rdi,
> > +          rax` sets pointer will all page offset bits cleared so
> > +          offset of (PAGE_SIZE - VEC_SIZE) will get last aligned VEC
> > +          before page cross (guranteed to be safe to read). Doing this
> > +          as opposed to `movq %rdi, %rax; andq $-VEC_SIZE, %rax` saves
> > +          a bit of code size.  */
> > +       xorq    %rdi, %rax
> > +       VMOVU   (PAGE_SIZE - VEC_SIZE)(%rax), %YMM1
> > +       VPTESTN %YMM1, %YMM1, %k0
> > +       kmovd   %k0, %ecx
> > +
> > +       /* Shift out zero CHAR matches that are before the begining of
> > +          src (rdi).  */
> >  # ifdef USE_AS_WCSRCHR
> > -       /* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
> > -       leaq    -VEC_SIZE(%rdi, %rax, 4), %rax
> > -# else
> > -       leaq    -VEC_SIZE(%rdi, %rax), %rax
> > +       movl    %edi, %esi
> > +       andl    $(VEC_SIZE - 1), %esi
> > +       shrl    $2, %esi
> >  # endif
> > -       ret
> > +       shrxl   %SHIFT_REG, %ecx, %ecx
> >
> > -       .p2align 4
> > -L(char_and_nul):
> > -       /* Find both a CHAR and a null byte.  */
> > -       addq    %rcx, %rdi
> > -       movl    %edx, %ecx
> > -L(char_and_nul_in_first_vec):
> > -       /* Mask out any matching bits after the null byte.  */
> > -       movl    %ecx, %r8d
> > -       subl    $1, %r8d
> > -       xorl    %ecx, %r8d
> > -       andl    %r8d, %eax
> > -       testl   %eax, %eax
> > -       /* Return null pointer if the null byte comes first.  */
> > -       jz      L(return_null)
> > +       testl   %ecx, %ecx
> > +       jz      L(page_cross_continue)
> > +
> > +       /* Found zero CHAR so need to test for search CHAR.  */
> > +       VPCMP   $0, %YMMMATCH, %YMM1, %k1
> > +       kmovd   %k1, %eax
> > +       /* Shift out search CHAR matches that are before the begining of
> > +          src (rdi).  */
> > +       shrxl   %SHIFT_REG, %eax, %eax
> > +
> > +       /* Check if any search CHAR match in range.  */
> > +       blsmskl %ecx, %ecx
> > +       andl    %ecx, %eax
> > +       jz      L(ret3)
> >         bsrl    %eax, %eax
> >  # ifdef USE_AS_WCSRCHR
> > -       /* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
> > -       leaq    -VEC_SIZE(%rdi, %rax, 4), %rax
> > +       leaq    (%rdi, %rax, CHAR_SIZE), %rax
> >  # else
> > -       leaq    -VEC_SIZE(%rdi, %rax), %rax
> > +       addq    %rdi, %rax
> >  # endif
> > +L(ret3):
> >         ret
> >
> > -       .p2align 4
> > -L(return_null):
> > -       xorl    %eax, %eax
> > -       ret
> > -
> > -END (STRRCHR)
> > +END(STRRCHR)
> >  #endif
> > --
> > 2.25.1
> >
>
> LGTM.
>
> Reviewed-by: H.J. Lu <hjl.tools@gmail.com>
>
> Thanks.
>
> --
> H.J.

I would like to backport this patch to release branches.
Any comments or objections?

--Sunil
diff mbox series

Patch

diff --git a/sysdeps/x86_64/multiarch/strrchr-evex.S b/sysdeps/x86_64/multiarch/strrchr-evex.S
index adeddaed32..8014c285b3 100644
--- a/sysdeps/x86_64/multiarch/strrchr-evex.S
+++ b/sysdeps/x86_64/multiarch/strrchr-evex.S
@@ -24,242 +24,351 @@ 
 #  define STRRCHR	__strrchr_evex
 # endif
 
-# define VMOVU		vmovdqu64
-# define VMOVA		vmovdqa64
+# define VMOVU	vmovdqu64
+# define VMOVA	vmovdqa64
 
 # ifdef USE_AS_WCSRCHR
+#  define SHIFT_REG	esi
+
+#  define kunpck	kunpckbw
+#  define kmov_2x	kmovd
+#  define maskz_2x	ecx
+#  define maskm_2x	eax
+#  define CHAR_SIZE	4
+#  define VPMIN	vpminud
+#  define VPTESTN	vptestnmd
 #  define VPBROADCAST	vpbroadcastd
-#  define VPCMP		vpcmpd
-#  define SHIFT_REG	r8d
+#  define VPCMP	vpcmpd
 # else
+#  define SHIFT_REG	edi
+
+#  define kunpck	kunpckdq
+#  define kmov_2x	kmovq
+#  define maskz_2x	rcx
+#  define maskm_2x	rax
+
+#  define CHAR_SIZE	1
+#  define VPMIN	vpminub
+#  define VPTESTN	vptestnmb
 #  define VPBROADCAST	vpbroadcastb
-#  define VPCMP		vpcmpb
-#  define SHIFT_REG	ecx
+#  define VPCMP	vpcmpb
 # endif
 
 # define XMMZERO	xmm16
 # define YMMZERO	ymm16
 # define YMMMATCH	ymm17
-# define YMM1		ymm18
+# define YMMSAVE	ymm18
+
+# define YMM1	ymm19
+# define YMM2	ymm20
+# define YMM3	ymm21
+# define YMM4	ymm22
+# define YMM5	ymm23
+# define YMM6	ymm24
+# define YMM7	ymm25
+# define YMM8	ymm26
 
-# define VEC_SIZE	32
 
-	.section .text.evex,"ax",@progbits
-ENTRY (STRRCHR)
-	movl	%edi, %ecx
+# define VEC_SIZE	32
+# define PAGE_SIZE	4096
+	.section .text.evex, "ax", @progbits
+ENTRY(STRRCHR)
+	movl	%edi, %eax
 	/* Broadcast CHAR to YMMMATCH.  */
 	VPBROADCAST %esi, %YMMMATCH
 
-	vpxorq	%XMMZERO, %XMMZERO, %XMMZERO
-
-	/* Check if we may cross page boundary with one vector load.  */
-	andl	$(2 * VEC_SIZE - 1), %ecx
-	cmpl	$VEC_SIZE, %ecx
-	ja	L(cros_page_boundary)
+	andl	$(PAGE_SIZE - 1), %eax
+	cmpl	$(PAGE_SIZE - VEC_SIZE), %eax
+	jg	L(cross_page_boundary)
 
+L(page_cross_continue):
 	VMOVU	(%rdi), %YMM1
-
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	/* k0 has a 1 for each zero CHAR in YMM1.  */
+	VPTESTN	%YMM1, %YMM1, %k0
 	kmovd	%k0, %ecx
-	kmovd	%k1, %eax
-
-	addq	$VEC_SIZE, %rdi
-
-	testl	%eax, %eax
-	jnz	L(first_vec)
-
 	testl	%ecx, %ecx
-	jnz	L(return_null)
-
-	andq	$-VEC_SIZE, %rdi
-	xorl	%edx, %edx
-	jmp	L(aligned_loop)
-
-	.p2align 4
-L(first_vec):
-	/* Check if there is a null byte.  */
-	testl	%ecx, %ecx
-	jnz	L(char_and_nul_in_first_vec)
-
-	/* Remember the match and keep searching.  */
-	movl	%eax, %edx
-	movq	%rdi, %rsi
-	andq	$-VEC_SIZE, %rdi
-	jmp	L(aligned_loop)
-
-	.p2align 4
-L(cros_page_boundary):
-	andl	$(VEC_SIZE - 1), %ecx
-	andq	$-VEC_SIZE, %rdi
+	jz	L(aligned_more)
+	/* fallthrough: zero CHAR in first VEC.  */
 
+	/* K1 has a 1 for each search CHAR match in YMM1.  */
+	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	kmovd	%k1, %eax
+	/* Build mask up until first zero CHAR (used to mask of
+	   potential search CHAR matches past the end of the string).
+	 */
+	blsmskl	%ecx, %ecx
+	andl	%ecx, %eax
+	jz	L(ret0)
+	/* Get last match (the `andl` removed any out of bounds
+	   matches).  */
+	bsrl	%eax, %eax
 # ifdef USE_AS_WCSRCHR
-	/* NB: Divide shift count by 4 since each bit in K1 represent 4
-	   bytes.  */
-	movl	%ecx, %SHIFT_REG
-	sarl	$2, %SHIFT_REG
+	leaq	(%rdi, %rax, CHAR_SIZE), %rax
+# else
+	addq	%rdi, %rax
 # endif
+L(ret0):
+	ret
 
-	VMOVA	(%rdi), %YMM1
-
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
+	/* Returns for first vec x1/x2/x3 have hard coded backward
+	   search path for earlier matches.  */
+	.p2align 4,, 6
+L(first_vec_x1):
+	VPCMP	$0, %YMMMATCH, %YMM2, %k1
+	kmovd	%k1, %eax
+	blsmskl	%ecx, %ecx
+	/* eax non-zero if search CHAR in range.  */
+	andl	%ecx, %eax
+	jnz	L(first_vec_x1_return)
+
+	/* fallthrough: no match in YMM2 then need to check for earlier
+	   matches (in YMM1).  */
+	.p2align 4,, 4
+L(first_vec_x0_test):
 	VPCMP	$0, %YMMMATCH, %YMM1, %k1
-	kmovd	%k0, %edx
 	kmovd	%k1, %eax
-
-	shrxl	%SHIFT_REG, %edx, %edx
-	shrxl	%SHIFT_REG, %eax, %eax
-	addq	$VEC_SIZE, %rdi
-
-	/* Check if there is a CHAR.  */
 	testl	%eax, %eax
-	jnz	L(found_char)
-
-	testl	%edx, %edx
-	jnz	L(return_null)
-
-	jmp	L(aligned_loop)
-
-	.p2align 4
-L(found_char):
-	testl	%edx, %edx
-	jnz	L(char_and_nul)
-
-	/* Remember the match and keep searching.  */
-	movl	%eax, %edx
-	leaq	(%rdi, %rcx), %rsi
+	jz	L(ret1)
+	bsrl	%eax, %eax
+# ifdef USE_AS_WCSRCHR
+	leaq	(%rsi, %rax, CHAR_SIZE), %rax
+# else
+	addq	%rsi, %rax
+# endif
+L(ret1):
+	ret
 
-	.p2align 4
-L(aligned_loop):
-	VMOVA	(%rdi), %YMM1
-	addq	$VEC_SIZE, %rdi
+	.p2align 4,, 10
+L(first_vec_x1_or_x2):
+	VPCMP	$0, %YMM3, %YMMMATCH, %k3
+	VPCMP	$0, %YMM2, %YMMMATCH, %k2
+	/* K2 and K3 have 1 for any search CHAR match. Test if any
+	   matches between either of them. Otherwise check YMM1.  */
+	kortestd %k2, %k3
+	jz	L(first_vec_x0_test)
+
+	/* Guranteed that YMM2 and YMM3 are within range so merge the
+	   two bitmasks then get last result.  */
+	kunpck	%k2, %k3, %k3
+	kmovq	%k3, %rax
+	bsrq	%rax, %rax
+	leaq	(VEC_SIZE)(%r8, %rax, CHAR_SIZE), %rax
+	ret
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
-	kmovd	%k0, %ecx
+	.p2align 4,, 6
+L(first_vec_x3):
+	VPCMP	$0, %YMMMATCH, %YMM4, %k1
 	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jnz	L(char_nor_null)
+	blsmskl	%ecx, %ecx
+	/* If no search CHAR match in range check YMM1/YMM2/YMM3.  */
+	andl	%ecx, %eax
+	jz	L(first_vec_x1_or_x2)
+	bsrl	%eax, %eax
+	leaq	(VEC_SIZE * 3)(%rdi, %rax, CHAR_SIZE), %rax
+	ret
 
-	VMOVA	(%rdi), %YMM1
-	add	$VEC_SIZE, %rdi
+	.p2align 4,, 6
+L(first_vec_x0_x1_test):
+	VPCMP	$0, %YMMMATCH, %YMM2, %k1
+	kmovd	%k1, %eax
+	/* Check YMM2 for last match first. If no match try YMM1.  */
+	testl	%eax, %eax
+	jz	L(first_vec_x0_test)
+	.p2align 4,, 4
+L(first_vec_x1_return):
+	bsrl	%eax, %eax
+	leaq	(VEC_SIZE)(%rdi, %rax, CHAR_SIZE), %rax
+	ret
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
-	kmovd	%k0, %ecx
+	.p2align 4,, 10
+L(first_vec_x2):
+	VPCMP	$0, %YMMMATCH, %YMM3, %k1
 	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jnz	L(char_nor_null)
+	blsmskl	%ecx, %ecx
+	/* Check YMM3 for last match first. If no match try YMM2/YMM1.
+	 */
+	andl	%ecx, %eax
+	jz	L(first_vec_x0_x1_test)
+	bsrl	%eax, %eax
+	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
+	ret
 
-	VMOVA	(%rdi), %YMM1
-	addq	$VEC_SIZE, %rdi
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	.p2align 4
+L(aligned_more):
+	/* Need to keep original pointer incase YMM1 has last match.  */
+	movq	%rdi, %rsi
+	andq	$-VEC_SIZE, %rdi
+	VMOVU	VEC_SIZE(%rdi), %YMM2
+	VPTESTN	%YMM2, %YMM2, %k0
 	kmovd	%k0, %ecx
-	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jnz	L(char_nor_null)
+	testl	%ecx, %ecx
+	jnz	L(first_vec_x1)
 
-	VMOVA	(%rdi), %YMM1
-	addq	$VEC_SIZE, %rdi
+	VMOVU	(VEC_SIZE * 2)(%rdi), %YMM3
+	VPTESTN	%YMM3, %YMM3, %k0
+	kmovd	%k0, %ecx
+	testl	%ecx, %ecx
+	jnz	L(first_vec_x2)
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	VMOVU	(VEC_SIZE * 3)(%rdi), %YMM4
+	VPTESTN	%YMM4, %YMM4, %k0
 	kmovd	%k0, %ecx
-	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jz	L(aligned_loop)
+	movq	%rdi, %r8
+	testl	%ecx, %ecx
+	jnz	L(first_vec_x3)
 
+	andq	$-(VEC_SIZE * 2), %rdi
 	.p2align 4
-L(char_nor_null):
-	/* Find a CHAR or a null byte in a loop.  */
+L(first_aligned_loop):
+	/* Preserve YMM1, YMM2, YMM3, and YMM4 until we can gurantee
+	   they don't store a match.  */
+	VMOVA	(VEC_SIZE * 4)(%rdi), %YMM5
+	VMOVA	(VEC_SIZE * 5)(%rdi), %YMM6
+
+	VPCMP	$0, %YMM5, %YMMMATCH, %k2
+	vpxord	%YMM6, %YMMMATCH, %YMM7
+
+	VPMIN	%YMM5, %YMM6, %YMM8
+	VPMIN	%YMM8, %YMM7, %YMM7
+
+	VPTESTN	%YMM7, %YMM7, %k1
+	subq	$(VEC_SIZE * -2), %rdi
+	kortestd %k1, %k2
+	jz	L(first_aligned_loop)
+
+	VPCMP	$0, %YMM6, %YMMMATCH, %k3
+	VPTESTN	%YMM8, %YMM8, %k1
+	ktestd	%k1, %k1
+	jz	L(second_aligned_loop_prep)
+
+	kortestd %k2, %k3
+	jnz	L(return_first_aligned_loop)
+
+	.p2align 4,, 6
+L(first_vec_x1_or_x2_or_x3):
+	VPCMP	$0, %YMM4, %YMMMATCH, %k4
+	kmovd	%k4, %eax
 	testl	%eax, %eax
-	jnz	L(match)
-L(return_value):
-	testl	%edx, %edx
-	jz	L(return_null)
-	movl	%edx, %eax
-	movq	%rsi, %rdi
+	jz	L(first_vec_x1_or_x2)
 	bsrl	%eax, %eax
-# ifdef USE_AS_WCSRCHR
-	/* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
-	leaq	-VEC_SIZE(%rdi, %rax, 4), %rax
-# else
-	leaq	-VEC_SIZE(%rdi, %rax), %rax
-# endif
+	leaq	(VEC_SIZE * 3)(%r8, %rax, CHAR_SIZE), %rax
 	ret
 
-	.p2align 4
-L(match):
-	/* Find a CHAR.  Check if there is a null byte.  */
-	kmovd	%k0, %ecx
-	testl	%ecx, %ecx
-	jnz	L(find_nul)
+	.p2align 4,, 8
+L(return_first_aligned_loop):
+	VPTESTN	%YMM5, %YMM5, %k0
+	kunpck	%k0, %k1, %k0
+	kmov_2x	%k0, %maskz_2x
+
+	blsmsk	%maskz_2x, %maskz_2x
+	kunpck	%k2, %k3, %k3
+	kmov_2x	%k3, %maskm_2x
+	and	%maskz_2x, %maskm_2x
+	jz	L(first_vec_x1_or_x2_or_x3)
 
-	/* Remember the match and keep searching.  */
-	movl	%eax, %edx
+	bsr	%maskm_2x, %maskm_2x
+	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
+	ret
+
+	.p2align 4
+	/* We can throw away the work done for the first 4x checks here
+	   as we have a later match. This is the 'fast' path persay.
+	 */
+L(second_aligned_loop_prep):
+L(second_aligned_loop_set_furthest_match):
 	movq	%rdi, %rsi
-	jmp	L(aligned_loop)
+	kunpck	%k2, %k3, %k4
 
 	.p2align 4
-L(find_nul):
-	/* Mask out any matching bits after the null byte.  */
-	movl	%ecx, %r8d
-	subl	$1, %r8d
-	xorl	%ecx, %r8d
-	andl	%r8d, %eax
-	testl	%eax, %eax
-	/* If there is no CHAR here, return the remembered one.  */
-	jz	L(return_value)
-	bsrl	%eax, %eax
+L(second_aligned_loop):
+	VMOVU	(VEC_SIZE * 4)(%rdi), %YMM1
+	VMOVU	(VEC_SIZE * 5)(%rdi), %YMM2
+
+	VPCMP	$0, %YMM1, %YMMMATCH, %k2
+	vpxord	%YMM2, %YMMMATCH, %YMM3
+
+	VPMIN	%YMM1, %YMM2, %YMM4
+	VPMIN	%YMM3, %YMM4, %YMM3
+
+	VPTESTN	%YMM3, %YMM3, %k1
+	subq	$(VEC_SIZE * -2), %rdi
+	kortestd %k1, %k2
+	jz	L(second_aligned_loop)
+
+	VPCMP	$0, %YMM2, %YMMMATCH, %k3
+	VPTESTN	%YMM4, %YMM4, %k1
+	ktestd	%k1, %k1
+	jz	L(second_aligned_loop_set_furthest_match)
+
+	kortestd %k2, %k3
+	/* branch here because there is a significant advantage interms
+	   of output dependency chance in using edx.  */
+	jnz	L(return_new_match)
+L(return_old_match):
+	kmovq	%k4, %rax
+	bsrq	%rax, %rax
+	leaq	(VEC_SIZE * 2)(%rsi, %rax, CHAR_SIZE), %rax
+	ret
+
+L(return_new_match):
+	VPTESTN	%YMM1, %YMM1, %k0
+	kunpck	%k0, %k1, %k0
+	kmov_2x	%k0, %maskz_2x
+
+	blsmsk	%maskz_2x, %maskz_2x
+	kunpck	%k2, %k3, %k3
+	kmov_2x	%k3, %maskm_2x
+	and	%maskz_2x, %maskm_2x
+	jz	L(return_old_match)
+
+	bsr	%maskm_2x, %maskm_2x
+	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
+	ret
+
+L(cross_page_boundary):
+	/* eax contains all the page offset bits of src (rdi). `xor rdi,
+	   rax` sets pointer will all page offset bits cleared so
+	   offset of (PAGE_SIZE - VEC_SIZE) will get last aligned VEC
+	   before page cross (guranteed to be safe to read). Doing this
+	   as opposed to `movq %rdi, %rax; andq $-VEC_SIZE, %rax` saves
+	   a bit of code size.  */
+	xorq	%rdi, %rax
+	VMOVU	(PAGE_SIZE - VEC_SIZE)(%rax), %YMM1
+	VPTESTN	%YMM1, %YMM1, %k0
+	kmovd	%k0, %ecx
+
+	/* Shift out zero CHAR matches that are before the begining of
+	   src (rdi).  */
 # ifdef USE_AS_WCSRCHR
-	/* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
-	leaq	-VEC_SIZE(%rdi, %rax, 4), %rax
-# else
-	leaq	-VEC_SIZE(%rdi, %rax), %rax
+	movl	%edi, %esi
+	andl	$(VEC_SIZE - 1), %esi
+	shrl	$2, %esi
 # endif
-	ret
+	shrxl	%SHIFT_REG, %ecx, %ecx
 
-	.p2align 4
-L(char_and_nul):
-	/* Find both a CHAR and a null byte.  */
-	addq	%rcx, %rdi
-	movl	%edx, %ecx
-L(char_and_nul_in_first_vec):
-	/* Mask out any matching bits after the null byte.  */
-	movl	%ecx, %r8d
-	subl	$1, %r8d
-	xorl	%ecx, %r8d
-	andl	%r8d, %eax
-	testl	%eax, %eax
-	/* Return null pointer if the null byte comes first.  */
-	jz	L(return_null)
+	testl	%ecx, %ecx
+	jz	L(page_cross_continue)
+
+	/* Found zero CHAR so need to test for search CHAR.  */
+	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	kmovd	%k1, %eax
+	/* Shift out search CHAR matches that are before the begining of
+	   src (rdi).  */
+	shrxl	%SHIFT_REG, %eax, %eax
+
+	/* Check if any search CHAR match in range.  */
+	blsmskl	%ecx, %ecx
+	andl	%ecx, %eax
+	jz	L(ret3)
 	bsrl	%eax, %eax
 # ifdef USE_AS_WCSRCHR
-	/* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
-	leaq	-VEC_SIZE(%rdi, %rax, 4), %rax
+	leaq	(%rdi, %rax, CHAR_SIZE), %rax
 # else
-	leaq	-VEC_SIZE(%rdi, %rax), %rax
+	addq	%rdi, %rax
 # endif
+L(ret3):
 	ret
 
-	.p2align 4
-L(return_null):
-	xorl	%eax, %eax
-	ret
-
-END (STRRCHR)
+END(STRRCHR)
 #endif