diff mbox series

[v1,3/3] x86: Fix overflow bug in wcsnlen-sse4_1 and wcsnlen-avx2

Message ID 20210609205257.123944-3-goldstein.w.n@gmail.com
State New
Headers show
Series [v1,1/3] String: Add additional overflow tests for strnlen, memchr, and strncat | expand

Commit Message

Noah Goldstein June 9, 2021, 8:52 p.m. UTC
This commit fixes the bug mentioned in the previous commit.

The previous implementations of wmemchr in these files relied
on maxlen * sizeof(wchar_t) which was not guranteed by the standard.

The new overflow tests added in the previous commit now
pass (As well as all the other tests).

Signed-off-by: Noah Goldstein <goldstein.w.n@gmail.com>
---
Its possible there is a room for a speedup in strnlen-avx2
and strnlen-evex if we check for overflow first and jump to
strlen. This allows for end pointers to be used as opposed
to tracking length which will save some ALU / code size.
 sysdeps/x86_64/multiarch/strlen-avx2.S | 130 ++++++++++++++++++-------
 sysdeps/x86_64/strlen.S                |  14 ++-
 2 files changed, 106 insertions(+), 38 deletions(-)
diff mbox series

Patch

diff --git a/sysdeps/x86_64/multiarch/strlen-avx2.S b/sysdeps/x86_64/multiarch/strlen-avx2.S
index bd2e6ee44a..b282a75613 100644
--- a/sysdeps/x86_64/multiarch/strlen-avx2.S
+++ b/sysdeps/x86_64/multiarch/strlen-avx2.S
@@ -44,21 +44,21 @@ 
 
 # define VEC_SIZE 32
 # define PAGE_SIZE 4096
+# define CHAR_PER_VEC	(VEC_SIZE / CHAR_SIZE)
 
 	.section SECTION(.text),"ax",@progbits
 ENTRY (STRLEN)
 # ifdef USE_AS_STRNLEN
 	/* Check zero length.  */
+#  ifdef __ILP32__
+	/* Clear upper bits.  */
+	and	%RSI_LP, %RSI_LP
+#  else
 	test	%RSI_LP, %RSI_LP
+#  endif
 	jz	L(zero)
 	/* Store max len in R8_LP before adjusting if using WCSLEN.  */
 	mov	%RSI_LP, %R8_LP
-#  ifdef USE_AS_WCSLEN
-	shl	$2, %RSI_LP
-#  elif defined __ILP32__
-	/* Clear the upper 32 bits.  */
-	movl	%esi, %esi
-#  endif
 # endif
 	movl	%edi, %eax
 	movq	%rdi, %rdx
@@ -72,10 +72,10 @@  ENTRY (STRLEN)
 
 	/* Check the first VEC_SIZE bytes.  */
 	VPCMPEQ	(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 # ifdef USE_AS_STRNLEN
 	/* If length < VEC_SIZE handle special.  */
-	cmpq	$VEC_SIZE, %rsi
+	cmpq	$CHAR_PER_VEC, %rsi
 	jbe	L(first_vec_x0)
 # endif
 	/* If empty continue to aligned_more. Otherwise return bit
@@ -84,6 +84,7 @@  ENTRY (STRLEN)
 	jz	L(aligned_more)
 	tzcntl	%eax, %eax
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 # endif
 	VZEROUPPER_RETURN
@@ -97,9 +98,14 @@  L(zero):
 L(first_vec_x0):
 	/* Set bit for max len so that tzcnt will return min of max len
 	   and position of first match.  */
+#  ifdef USE_AS_WCSLEN
+	/* NB: Multiply length by 4 to get byte count.  */
+	sall	$2, %esi
+#  endif
 	btsq	%rsi, %rax
 	tzcntl	%eax, %eax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 #  endif
 	VZEROUPPER_RETURN
@@ -113,14 +119,19 @@  L(first_vec_x1):
 # ifdef USE_AS_STRNLEN
 	/* Use ecx which was computed earlier to compute correct value.
 	 */
+#  ifdef USE_AS_WCSLEN
+	leal	-(VEC_SIZE * 4 + 1)(%rax, %rcx, 4), %eax
+#  else
 	subl	$(VEC_SIZE * 4 + 1), %ecx
 	addl	%ecx, %eax
+#  endif
 # else
 	subl	%edx, %edi
 	incl	%edi
 	addl	%edi, %eax
 # endif
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 # endif
 	VZEROUPPER_RETURN
@@ -133,14 +144,19 @@  L(first_vec_x2):
 # ifdef USE_AS_STRNLEN
 	/* Use ecx which was computed earlier to compute correct value.
 	 */
+#  ifdef USE_AS_WCSLEN
+	leal	-(VEC_SIZE * 3 + 1)(%rax, %rcx, 4), %eax
+#  else
 	subl	$(VEC_SIZE * 3 + 1), %ecx
 	addl	%ecx, %eax
+#  endif
 # else
 	subl	%edx, %edi
 	addl	$(VEC_SIZE + 1), %edi
 	addl	%edi, %eax
 # endif
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 # endif
 	VZEROUPPER_RETURN
@@ -153,14 +169,19 @@  L(first_vec_x3):
 # ifdef USE_AS_STRNLEN
 	/* Use ecx which was computed earlier to compute correct value.
 	 */
+#  ifdef USE_AS_WCSLEN
+	leal	-(VEC_SIZE * 2 + 1)(%rax, %rcx, 4), %eax
+#  else
 	subl	$(VEC_SIZE * 2 + 1), %ecx
 	addl	%ecx, %eax
+#  endif
 # else
 	subl	%edx, %edi
 	addl	$(VEC_SIZE * 2 + 1), %edi
 	addl	%edi, %eax
 # endif
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 # endif
 	VZEROUPPER_RETURN
@@ -173,14 +194,19 @@  L(first_vec_x4):
 # ifdef USE_AS_STRNLEN
 	/* Use ecx which was computed earlier to compute correct value.
 	 */
+#  ifdef USE_AS_WCSLEN
+	leal	-(VEC_SIZE * 1 + 1)(%rax, %rcx, 4), %eax
+#  else
 	subl	$(VEC_SIZE + 1), %ecx
 	addl	%ecx, %eax
+#  endif
 # else
 	subl	%edx, %edi
 	addl	$(VEC_SIZE * 3 + 1), %edi
 	addl	%edi, %eax
 # endif
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 # endif
 	VZEROUPPER_RETURN
@@ -195,10 +221,14 @@  L(cross_page_continue):
 	/* Check the first 4 * VEC_SIZE.  Only one VEC_SIZE at a time
 	   since data is only aligned to VEC_SIZE.  */
 # ifdef USE_AS_STRNLEN
-	/* + 1 because rdi is aligned to VEC_SIZE - 1. + CHAR_SIZE because
-	   it simplies the logic in last_4x_vec_or_less.  */
+	/* + 1 because rdi is aligned to VEC_SIZE - 1. + CHAR_SIZE
+	   because it simplies the logic in last_4x_vec_or_less.  */
 	leaq	(VEC_SIZE * 4 + CHAR_SIZE + 1)(%rdi), %rcx
 	subq	%rdx, %rcx
+#  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get the wchar_t count.  */
+	sarl	$2, %ecx
+#  endif
 # endif
 	/* Load first VEC regardless.  */
 	VPCMPEQ	1(%rdi), %ymm0, %ymm1
@@ -207,34 +237,38 @@  L(cross_page_continue):
 	subq	%rcx, %rsi
 	jb	L(last_4x_vec_or_less)
 # endif
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	testl	%eax, %eax
 	jnz	L(first_vec_x1)
 
 	VPCMPEQ	(VEC_SIZE + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	testl	%eax, %eax
 	jnz	L(first_vec_x2)
 
 	VPCMPEQ	(VEC_SIZE * 2 + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	testl	%eax, %eax
 	jnz	L(first_vec_x3)
 
 	VPCMPEQ	(VEC_SIZE * 3 + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	testl	%eax, %eax
 	jnz	L(first_vec_x4)
 
 	/* Align data to VEC_SIZE * 4 - 1.  */
 # ifdef USE_AS_STRNLEN
 	/* Before adjusting length check if at last VEC_SIZE * 4.  */
-	cmpq	$(VEC_SIZE * 4 - 1), %rsi
+	cmpq	$(CHAR_PER_VEC * 4 - 1), %rsi
 	jbe	L(last_4x_vec_or_less_load)
 	incq	%rdi
 	movl	%edi, %ecx
 	orq	$(VEC_SIZE * 4 - 1), %rdi
 	andl	$(VEC_SIZE * 4 - 1), %ecx
+#  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get the wchar_t count.  */
+	sarl	$2, %ecx
+#  endif
 	/* Readjust length.  */
 	addq	%rcx, %rsi
 # else
@@ -246,13 +280,13 @@  L(cross_page_continue):
 L(loop_4x_vec):
 # ifdef USE_AS_STRNLEN
 	/* Break if at end of length.  */
-	subq	$(VEC_SIZE * 4), %rsi
+	subq	$(CHAR_PER_VEC * 4), %rsi
 	jb	L(last_4x_vec_or_less_cmpeq)
 # endif
-	/* Save some code size by microfusing VPMINU with the load. Since
-	   the matches in ymm2/ymm4 can only be returned if there where no
-	   matches in ymm1/ymm3 respectively there is no issue with overlap.
-	 */
+	/* Save some code size by microfusing VPMINU with the load.
+	   Since the matches in ymm2/ymm4 can only be returned if there
+	   where no matches in ymm1/ymm3 respectively there is no issue
+	   with overlap.  */
 	vmovdqa	1(%rdi), %ymm1
 	VPMINU	(VEC_SIZE + 1)(%rdi), %ymm1, %ymm2
 	vmovdqa	(VEC_SIZE * 2 + 1)(%rdi), %ymm3
@@ -260,7 +294,7 @@  L(loop_4x_vec):
 
 	VPMINU	%ymm2, %ymm4, %ymm5
 	VPCMPEQ	%ymm5, %ymm0, %ymm5
-	vpmovmskb	%ymm5, %ecx
+	vpmovmskb %ymm5, %ecx
 
 	subq	$-(VEC_SIZE * 4), %rdi
 	testl	%ecx, %ecx
@@ -268,27 +302,28 @@  L(loop_4x_vec):
 
 
 	VPCMPEQ	%ymm1, %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	subq	%rdx, %rdi
 	testl	%eax, %eax
 	jnz	L(last_vec_return_x0)
 
 	VPCMPEQ	%ymm2, %ymm0, %ymm2
-	vpmovmskb	%ymm2, %eax
+	vpmovmskb %ymm2, %eax
 	testl	%eax, %eax
 	jnz	L(last_vec_return_x1)
 
 	/* Combine last 2 VEC.  */
 	VPCMPEQ	%ymm3, %ymm0, %ymm3
-	vpmovmskb	%ymm3, %eax
-	/* rcx has combined result from all 4 VEC. It will only be used if
-	   the first 3 other VEC all did not contain a match.  */
+	vpmovmskb %ymm3, %eax
+	/* rcx has combined result from all 4 VEC. It will only be used
+	   if the first 3 other VEC all did not contain a match.  */
 	salq	$32, %rcx
 	orq	%rcx, %rax
 	tzcntq	%rax, %rax
 	subq	$(VEC_SIZE * 2 - 1), %rdi
 	addq	%rdi, %rax
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 # endif
 	VZEROUPPER_RETURN
@@ -297,15 +332,19 @@  L(loop_4x_vec):
 # ifdef USE_AS_STRNLEN
 	.p2align 4
 L(last_4x_vec_or_less_load):
-	/* Depending on entry adjust rdi / prepare first VEC in ymm1.  */
+	/* Depending on entry adjust rdi / prepare first VEC in ymm1.
+	 */
 	subq	$-(VEC_SIZE * 4), %rdi
 L(last_4x_vec_or_less_cmpeq):
 	VPCMPEQ	1(%rdi), %ymm0, %ymm1
 L(last_4x_vec_or_less):
-
-	vpmovmskb	%ymm1, %eax
-	/* If remaining length > VEC_SIZE * 2. This works if esi is off by
-	   VEC_SIZE * 4.  */
+#  ifdef USE_AS_WCSLEN
+	/* NB: Multiply length by 4 to get byte count.  */
+	sall	$2, %esi
+#  endif
+	vpmovmskb %ymm1, %eax
+	/* If remaining length > VEC_SIZE * 2. This works if esi is off
+	   by VEC_SIZE * 4.  */
 	testl	$(VEC_SIZE * 2), %esi
 	jnz	L(last_4x_vec)
 
@@ -320,7 +359,7 @@  L(last_4x_vec_or_less):
 	jb	L(max)
 
 	VPCMPEQ	(VEC_SIZE + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	tzcntl	%eax, %eax
 	/* Check the end of data.  */
 	cmpl	%eax, %esi
@@ -329,6 +368,7 @@  L(last_4x_vec_or_less):
 	addl	$(VEC_SIZE + 1), %eax
 	addq	%rdi, %rax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 #  endif
 	VZEROUPPER_RETURN
@@ -340,6 +380,7 @@  L(last_vec_return_x0):
 	subq	$(VEC_SIZE * 4 - 1), %rdi
 	addq	%rdi, %rax
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 # endif
 	VZEROUPPER_RETURN
@@ -350,6 +391,7 @@  L(last_vec_return_x1):
 	subq	$(VEC_SIZE * 3 - 1), %rdi
 	addq	%rdi, %rax
 # ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 # endif
 	VZEROUPPER_RETURN
@@ -366,6 +408,7 @@  L(last_vec_x1_check):
 	incl	%eax
 	addq	%rdi, %rax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 #  endif
 	VZEROUPPER_RETURN
@@ -381,14 +424,14 @@  L(last_4x_vec):
 	jnz	L(last_vec_x1)
 
 	VPCMPEQ	(VEC_SIZE + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	testl	%eax, %eax
 	jnz	L(last_vec_x2)
 
 	/* Normalize length.  */
 	andl	$(VEC_SIZE * 4 - 1), %esi
 	VPCMPEQ	(VEC_SIZE * 2 + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	testl	%eax, %eax
 	jnz	L(last_vec_x3)
 
@@ -396,7 +439,7 @@  L(last_4x_vec):
 	jb	L(max)
 
 	VPCMPEQ	(VEC_SIZE * 3 + 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	tzcntl	%eax, %eax
 	/* Check the end of data.  */
 	cmpl	%eax, %esi
@@ -405,6 +448,7 @@  L(last_4x_vec):
 	addl	$(VEC_SIZE * 3 + 1), %eax
 	addq	%rdi, %rax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 #  endif
 	VZEROUPPER_RETURN
@@ -419,6 +463,7 @@  L(last_vec_x1):
 	incl	%eax
 	addq	%rdi, %rax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 #  endif
 	VZEROUPPER_RETURN
@@ -432,6 +477,7 @@  L(last_vec_x2):
 	addl	$(VEC_SIZE + 1), %eax
 	addq	%rdi, %rax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 #  endif
 	VZEROUPPER_RETURN
@@ -447,6 +493,7 @@  L(last_vec_x3):
 	addl	$(VEC_SIZE * 2 + 1), %eax
 	addq	%rdi, %rax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
 	shrq	$2, %rax
 #  endif
 	VZEROUPPER_RETURN
@@ -455,13 +502,13 @@  L(max_end):
 	VZEROUPPER_RETURN
 # endif
 
-	/* Cold case for crossing page with first load.	 */
+	/* Cold case for crossing page with first load.  */
 	.p2align 4
 L(cross_page_boundary):
 	/* Align data to VEC_SIZE - 1.  */
 	orq	$(VEC_SIZE - 1), %rdi
 	VPCMPEQ	-(VEC_SIZE - 1)(%rdi), %ymm0, %ymm1
-	vpmovmskb	%ymm1, %eax
+	vpmovmskb %ymm1, %eax
 	/* Remove the leading bytes. sarxl only uses bits [5:0] of COUNT
 	   so no need to manually mod rdx.  */
 	sarxl	%edx, %eax, %eax
@@ -470,6 +517,10 @@  L(cross_page_boundary):
 	jnz	L(cross_page_less_vec)
 	leaq	1(%rdi), %rcx
 	subq	%rdx, %rcx
+#  ifdef USE_AS_WCSLEN
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
+	shrl	$2, %ecx
+#  endif
 	/* Check length.  */
 	cmpq	%rsi, %rcx
 	jb	L(cross_page_continue)
@@ -479,6 +530,7 @@  L(cross_page_boundary):
 	jz	L(cross_page_continue)
 	tzcntl	%eax, %eax
 #  ifdef USE_AS_WCSLEN
+	/* NB: Divide length by 4 to get wchar_t count.  */
 	shrl	$2, %eax
 #  endif
 # endif
@@ -489,6 +541,10 @@  L(return_vzeroupper):
 	.p2align 4
 L(cross_page_less_vec):
 	tzcntl	%eax, %eax
+#  ifdef USE_AS_WCSLEN
+	/* NB: Multiply length by 4 to get byte count.  */
+	sall	$2, %esi
+#  endif
 	cmpq	%rax, %rsi
 	cmovb	%esi, %eax
 #  ifdef USE_AS_WCSLEN
diff --git a/sysdeps/x86_64/strlen.S b/sysdeps/x86_64/strlen.S
index d223ea1700..3fc6734910 100644
--- a/sysdeps/x86_64/strlen.S
+++ b/sysdeps/x86_64/strlen.S
@@ -65,12 +65,24 @@  ENTRY(strlen)
 	ret
 L(n_nonzero):
 # ifdef AS_WCSLEN
-	shl	$2, %RSI_LP
+/* Check for overflow from maxlen * sizeof(wchar_t). If it would
+   overflow the only way this program doesn't have undefined behavior 
+   is if there is a null terminator in valid memory so strlen will 
+   suffice.  */
+	mov	%RSI_LP, %R10_LP
+	sar	$62, %R10_LP
+	test	%R10_LP, %R10_LP
+	jnz	__wcslen_sse2
+	sal	$2, %RSI_LP
 # endif
 
 /* Initialize long lived registers.  */
 
 	add	%RDI_LP, %RSI_LP
+# ifdef AS_WCSLEN
+/* Check for overflow again from s + maxlen * sizeof(wchar_t).  */
+	jbe	__wcslen_sse2
+# endif
 	mov	%RSI_LP, %R10_LP
 	and	$-64, %R10_LP
 	mov	%RSI_LP, %R11_LP