diff mbox series

[2/2] libstdc++: Ensure valid UTF-8 in std::vprint_unicode

Message ID 20231117160320.1513815-2-jwakely@redhat.com
State New
Headers show
Series [1/2] libstdc++: Implement C++23 <print> header [PR107760] | expand

Commit Message

Jonathan Wakely Nov. 17, 2023, 3:54 p.m. UTC
This is a naive implementation of the UTF-8 validation algorithm, which
could definitely be optimized. But it's faster than using
std::codecvt_utf8 and checking the result of that, which is the only
existing code we have to do it in the library.

As the TODO suggests, we could do the UTF-8 to UTF-16 conversion at the
same time. But that is only needed for Windows and as I said in the 1/2
email, the output for Windows seems to be broken currently anyway and I
can't test it properly.

-- >8 --

libstdc++-v3/ChangeLog:

	* include/bits/locale_conv.h (__to_valid_utf8): New function.
	* include/std/ostream (vprint_unicode): Use it.
	* include/std/print (vprint_unicode): Use it.
---
 libstdc++-v3/include/bits/locale_conv.h | 104 ++++++++++++++++++++++++
 libstdc++-v3/include/std/ostream        |  74 +++++++++++------
 libstdc++-v3/include/std/print          |   8 +-
 3 files changed, 160 insertions(+), 26 deletions(-)
diff mbox series

Patch

diff --git a/libstdc++-v3/include/bits/locale_conv.h b/libstdc++-v3/include/bits/locale_conv.h
index 284142a360a..f6ade1d0395 100644
--- a/libstdc++-v3/include/bits/locale_conv.h
+++ b/libstdc++-v3/include/bits/locale_conv.h
@@ -624,6 +624,110 @@  _GLIBCXX_END_NAMESPACE_CXX11
       bool			_M_always_noconv;
     };
 
+#if __cplusplus >= 202002L
+  template<typename _CharT = char>
+  bool
+  __to_valid_utf8(string& __s)
+  {
+    // TODO if _CharT is wchar_t then transcode at the same time.
+
+    unsigned __seen = 0, __needed = 0;
+    unsigned char __lo_bound = 0x80, __hi_bound = 0xBF;
+    size_t __errors = 0;
+
+    auto __q = __s.data(), __eoq = __q + __s.size();
+    while (__q != __eoq)
+      {
+	unsigned char __byte = *__q;
+	if (__needed == 0)
+	  {
+	    if (__byte <= 0x7F)      // 0x00 to 0x7F
+	      {
+		while (++__q != __eoq && (unsigned char)*__q <= 0x7F)
+		  { } // Fast forward to the next non-ASCII character.
+		continue;
+	      }
+	    else if (__byte < 0xC2)
+	      {
+		*__q = 0xFF;
+		++__errors;
+	      }
+	    else if (__byte <= 0xDF) // 0xC2 to 0xDF
+	      {
+		__needed = 1;
+	      }
+	    else if (__byte <= 0xEF) // 0xE0 to 0xEF
+	      {
+		if (__byte == 0xE0)
+		  __lo_bound = 0xA0;
+		else if (__byte == 0xED)
+		  __hi_bound = 0x9F;
+
+		__needed = 2;
+	      }
+	    else if (__byte <= 0xF4) // 0xF0 to 0xF4
+	      {
+		if (__byte == 0xF0)
+		  __lo_bound = 0x90;
+		else if (__byte == 0xF4)
+		  __hi_bound = 0x8F;
+
+		__needed = 3;
+	      }
+	    else
+	      {
+		*__q = 0xFF;
+		++__errors;
+	      }
+	  }
+	else
+	  {
+	    if (__byte < __lo_bound || __byte > __hi_bound)
+	      {
+		*(__q - __seen - 1) = 0xFF;
+		__builtin_memset(__q - __seen, 0xFE, __seen);
+		++__errors;
+		__needed = __seen = 0;
+		__lo_bound = 0x80;
+		__hi_bound = 0xBF;
+		continue; // Reprocess the current character.
+	      }
+
+	    __lo_bound = 0x80;
+	    __hi_bound = 0xBF;
+	    ++__seen;
+	    if (__seen == __needed)
+	      __needed = __seen = 0;
+	  }
+	__q++;
+      }
+
+    if (__needed)
+      {
+	// The string ends with an incomplete multibyte sequence.
+	if (__seen)
+	  __s.resize(__s.size() - __seen);
+	__s.back() = 0xFF;
+	++__errors;
+      }
+
+    if (__errors == 0)
+      return true;
+
+    string __s2;
+    __s2.reserve(__s.size() + __errors * 2);
+    for (unsigned char __byte : __s)
+      {
+	if (__byte == 0xFF)
+	  __s2 += "\uFFFD";
+	else if (__byte != 0xFE)
+	  __s2 += (char)__byte;
+      }
+    __s = std::move(__s2);
+    return false;
+  }
+#endif // C++20
+
   /// @} group locales
 
 _GLIBCXX_END_NAMESPACE_VERSION
diff --git a/libstdc++-v3/include/std/ostream b/libstdc++-v3/include/std/ostream
index e81c39a7c80..760aaa206da 100644
--- a/libstdc++-v3/include/std/ostream
+++ b/libstdc++-v3/include/std/ostream
@@ -917,42 +917,68 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
   inline void
   vprint_unicode(ostream& __os, string_view __fmt, format_args __args)
   {
-    // TODO: diagnose invalid UTF-8 code units
-#ifdef _WIN32
-    int __fd_for_console(std::streambuf*);
-    void __write_utf16_to_console(int, string);
-
-    // If stream refers to a terminal convert to UTF-16 and use WriteConsoleW.
-    if (int __fd = __fd_for_console(__os.rdbuf()); __fd >= 0)
+    ostream::sentry __cerb(__os);
+    if (__cerb)
       {
-	ostream::sentry __cerb(__os);
-	if (__cerb)
+	string __out = std::vformat(__fmt, __args);
+	std::__to_valid_utf8(__out);
+
+#ifdef _WIN32
+	int __fd_for_console(std::streambuf*);
+	void __write_utf16_to_console(int, string);
+
+	// If stream refers to a terminal output UTF-16 using WriteConsoleW.
+	if (int __fd = __fd_for_console(__os.rdbuf()); __fd >= 0)
 	  {
-	    string __out = std::vformat(__fmt, __args);
 	    ios_base::iostate __err = ios_base::goodbit;
 	    __try
-	      {
-		if (__os.rdbuf()->pubsync() == -1)
-		  __err = ios::badbit;
-		else if (__write_utf16_to_console(__fd, __out))
-		  __err = ios::badbit;
-	      }
+	    {
+	      if (__os.rdbuf()->pubsync() == -1)
+		__err = ios::badbit;
+	      else if (__write_utf16_to_console(__fd, __out))
+		__err = ios::badbit;
+	    }
 	    __catch(const __cxxabiv1::__forced_unwind&)
-	      {
-		__os._M_setstate(ios_base::badbit);
-		__throw_exception_again;
-	      }
+	    {
+	      __os._M_setstate(ios_base::badbit);
+	      __throw_exception_again;
+	    }
 	    __catch(...)
-	      { __os._M_setstate(ios_base::badbit); }
+	    { __os._M_setstate(ios_base::badbit); }
 
 	    if (__err)
 	      __os.setstate(__err);
+	    return;
 	  }
-      }
 #endif
-    std::vprint_nonunicode(__os, __fmt, __args);
-  }
 
+	__try
+	  {
+	    const streamsize __w = __os.width();
+	    const streamsize __n = __out.size();
+	    if (__w > __n)
+	      {
+		const bool __left
+		  = (__os.flags() & ios_base::adjustfield) == ios_base::left;
+		if (!__left)
+		  std::__ostream_fill(__os, __w - __n);
+		if (__os.good())
+		  std::__ostream_write(__os, __out.data(), __n);
+		if (__left && __os.good())
+		  std::__ostream_fill(__os, __w - __n);
+	      }
+	    else
+	      std::__ostream_write(__os, __out.data(), __n);
+	  }
+	__catch(const __cxxabiv1::__forced_unwind&)
+	  {
+	    __os._M_setstate(ios_base::badbit);
+	    __throw_exception_again;
+	  }
+	__catch(...)
+	  { __os._M_setstate(ios_base::badbit); }
+      }
+  }
 
   template<typename... _Args>
     inline void
diff --git a/libstdc++-v3/include/std/print b/libstdc++-v3/include/std/print
index 75e78841247..096b97b1ef7 100644
--- a/libstdc++-v3/include/std/print
+++ b/libstdc++-v3/include/std/print
@@ -62,7 +62,9 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
   inline void
   vprint_unicode(FILE* __stream, string_view __fmt, format_args __args)
   {
-    // TODO: diagnose invalid UTF-8 code units
+    string __out = std::vformat(__fmt, __args);
+    std::__to_valid_utf8(__out);
+
 #ifdef _WIN32
     int __fd_for_console(FILE*);
     void __write_utf16_to_console(int, string);
@@ -82,7 +84,9 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	_GLIBCXX_THROW_OR_ABORT(system_error(__e, "std::vprint_unicode"));
       }
 #endif
-    std::vprint_nonunicode(__stream, __fmt, __args);
+
+    if (std::fwrite(__out.data(), 1, __out.size(), __stream) != __out.size())
+      __throw_system_error(EIO);
   }
 
   template<typename... _Args>