Patchwork [U-Boot,v2,2/3] Add safe vsnprintf and snprintf library functions

login
register
mail settings
Submitter Simon Glass
Date Oct. 10, 2011, 7:22 p.m.
Message ID <1318274551-25335-3-git-send-email-sjg@chromium.org>
Download mbox | patch
Permalink /patch/118823/
State New, archived
Headers show

Comments

Simon Glass - Oct. 10, 2011, 7:22 p.m.
From: Sonny Rao <sonnyrao@chromium.org>

From: Sonny Rao <sonnyrao@chromium.org>

These functions are useful in U-Boot because they allow a graceful failure
rather than an unpredictable stack overflow when printf() buffers are
exceeded.

Mostly copied from the Linux kernel. I copied vscnprintf and
scnprintf so we can change printf and vprintf to use the safe
implementation but still return the correct values.

Signed-off-by: Sonny Rao <sonnyrao@chromium.org>
---
 include/common.h |    8 ++-
 lib/vsprintf.c   |  316 ++++++++++++++++++++++++++++++++++++++++++------------
 2 files changed, 256 insertions(+), 68 deletions(-)

Patch

diff --git a/include/common.h b/include/common.h
index eb19a44..4e74855 100644
--- a/include/common.h
+++ b/include/common.h
@@ -699,9 +699,15 @@  unsigned long long	simple_strtoull(const char *cp,char **endp,unsigned int base)
 long	simple_strtol(const char *cp,char **endp,unsigned int base);
 void	panic(const char *fmt, ...)
 		__attribute__ ((format (__printf__, 1, 2), noreturn));
-int	sprintf(char * buf, const char *fmt, ...)
+int	sprintf(char *buf, const char *fmt, ...)
 		__attribute__ ((format (__printf__, 2, 3)));
+int	snprintf(char *buf, size_t size, const char *fmt, ...)
+		__attribute__ ((format (__printf__, 3, 4)));
+int	scnprintf(char *buf, size_t size, const char *fmt, ...)
+		__attribute__ ((format (__printf__, 3, 4)));
 int	vsprintf(char *buf, const char *fmt, va_list args);
+int	vsnprintf(char *buf, size_t size, const char *fmt, va_list args);
+int	vscnprintf(char *buf, size_t size, const char *fmt, va_list args);
 
 /* lib/strmhz.c */
 char *	strmhz(char *buf, unsigned long hz);
diff --git a/lib/vsprintf.c b/lib/vsprintf.c
index 79dead3..bac6f30 100644
--- a/lib/vsprintf.c
+++ b/lib/vsprintf.c
@@ -16,6 +16,7 @@ 
 #include <errno.h>
 
 #include <common.h>
+#include <limits.h>
 #if !defined (CONFIG_PANIC_HANG)
 #include <command.h>
 #endif
@@ -289,7 +290,8 @@  static noinline char* put_dec(char *buf, unsigned NUM_TYPE num)
 #define SMALL	32		/* Must be 32 == 0x20 */
 #define SPECIAL	64		/* 0x */
 
-static char *number(char *buf, unsigned NUM_TYPE num, int base, int size, int precision, int type)
+static char *number(char *buf, char *end, unsigned NUM_TYPE num,
+		int base, int size, int precision, int type)
 {
 	/* we are called with base 8, 10 or 16, only, thus don't need "G..."  */
 	static const char digits[16] = "0123456789ABCDEF"; /* "GHIJKLMNOPQRSTUVWXYZ"; */
@@ -351,37 +353,63 @@  static char *number(char *buf, unsigned NUM_TYPE num, int base, int size, int pr
 		precision = i;
 	/* leading space padding */
 	size -= precision;
-	if (!(type & (ZEROPAD+LEFT)))
-		while(--size >= 0)
-			*buf++ = ' ';
+	if (!(type & (ZEROPAD+LEFT))) {
+		while (--size >= 0) {
+			if (buf < end)
+				*buf = ' ';
+			++buf;
+		}
+	}
 	/* sign */
-	if (sign)
-		*buf++ = sign;
+	if (sign) {
+		if (buf < end)
+			*buf = sign;
+		++buf;
+	}
 	/* "0x" / "0" prefix */
 	if (need_pfx) {
-		*buf++ = '0';
-		if (base == 16)
-			*buf++ = ('X' | locase);
+		if (buf < end)
+			*buf = '0';
+		++buf;
+		if (base == 16) {
+			if (buf < end)
+				*buf = ('X' | locase);
+			++buf;
+		}
 	}
 	/* zero or space padding */
 	if (!(type & LEFT)) {
 		char c = (type & ZEROPAD) ? '0' : ' ';
-		while (--size >= 0)
-			*buf++ = c;
+		while (--size >= 0) {
+			if (buf < end)
+				*buf = c;
+			++buf;
+		}
 	}
 	/* hmm even more zero padding? */
-	while (i <= --precision)
-		*buf++ = '0';
+	while (i <= --precision) {
+		if (buf < end)
+			*buf = '0';
+		++buf;
+	}
 	/* actual digits of result */
-	while (--i >= 0)
-		*buf++ = tmp[i];
+	while (--i >= 0) {
+		if (buf < end)
+			*buf = tmp[i];
+		++buf;
+
+	}
 	/* trailing space padding */
-	while (--size >= 0)
-		*buf++ = ' ';
+	while (--size >= 0) {
+		if (buf < end)
+			*buf = ' ';
+		++buf;
+	}
 	return buf;
 }
 
-static char *string(char *buf, char *s, int field_width, int precision, int flags)
+static char *string(char *buf, char *end, char *s, int field_width,
+		int precision, int flags)
 {
 	int len, i;
 
@@ -391,17 +419,26 @@  static char *string(char *buf, char *s, int field_width, int precision, int flag
 	len = strnlen(s, precision);
 
 	if (!(flags & LEFT))
-		while (len < field_width--)
-			*buf++ = ' ';
-	for (i = 0; i < len; ++i)
-		*buf++ = *s++;
-	while (len < field_width--)
-		*buf++ = ' ';
+		while (len < field_width--) {
+			if (buf < end)
+				*buf = ' ';
+			++buf;
+		}
+	for (i = 0; i < len; ++i) {
+		if (buf < end)
+			*buf = *s++;
+		++buf;
+	}
+	while (len < field_width--) {
+		if (buf < end)
+			*buf = ' ';
+		++buf;
+	}
 	return buf;
 }
 
 #ifdef CONFIG_CMD_NET
-static char *mac_address_string(char *buf, u8 *addr, int field_width,
+static char *mac_address_string(char *buf, char *end, u8 *addr, int field_width,
 				int precision, int flags)
 {
 	char mac_addr[6 * 3]; /* (6 * 2 hex digits), 5 colons and trailing zero */
@@ -415,10 +452,11 @@  static char *mac_address_string(char *buf, u8 *addr, int field_width,
 	}
 	*p = '\0';
 
-	return string(buf, mac_addr, field_width, precision, flags & ~SPECIAL);
+	return string(buf, end, mac_addr, field_width, precision,
+		      flags & ~SPECIAL);
 }
 
-static char *ip6_addr_string(char *buf, u8 *addr, int field_width,
+static char *ip6_addr_string(char *buf, char *end, u8 *addr, int field_width,
 			 int precision, int flags)
 {
 	char ip6_addr[8 * 5]; /* (8 * 4 hex digits), 7 colons and trailing zero */
@@ -433,10 +471,11 @@  static char *ip6_addr_string(char *buf, u8 *addr, int field_width,
 	}
 	*p = '\0';
 
-	return string(buf, ip6_addr, field_width, precision, flags & ~SPECIAL);
+	return string(buf, end, ip6_addr, field_width, precision,
+		      flags & ~SPECIAL);
 }
 
-static char *ip4_addr_string(char *buf, u8 *addr, int field_width,
+static char *ip4_addr_string(char *buf, char *end, u8 *addr, int field_width,
 			 int precision, int flags)
 {
 	char ip4_addr[4 * 4]; /* (4 * 3 decimal digits), 3 dots and trailing zero */
@@ -454,7 +493,8 @@  static char *ip4_addr_string(char *buf, u8 *addr, int field_width,
 	}
 	*p = '\0';
 
-	return string(buf, ip4_addr, field_width, precision, flags & ~SPECIAL);
+	return string(buf, end, ip4_addr, field_width, precision,
+		      flags & ~SPECIAL);
 }
 #endif
 
@@ -476,10 +516,12 @@  static char *ip4_addr_string(char *buf, u8 *addr, int field_width,
  * function pointers are really function descriptors, which contain a
  * pointer to the real address.
  */
-static char *pointer(const char *fmt, char *buf, void *ptr, int field_width, int precision, int flags)
+static char *pointer(const char *fmt, char *buf, char *end, void *ptr,
+		int field_width, int precision, int flags)
 {
 	if (!ptr)
-		return string(buf, "(null)", field_width, precision, flags);
+		return string(buf, end, "(null)", field_width, precision,
+			      flags);
 
 #ifdef CONFIG_CMD_NET
 	switch (*fmt) {
@@ -487,15 +529,18 @@  static char *pointer(const char *fmt, char *buf, void *ptr, int field_width, int
 		flags |= SPECIAL;
 		/* Fallthrough */
 	case 'M':
-		return mac_address_string(buf, ptr, field_width, precision, flags);
+		return mac_address_string(buf, end, ptr, field_width,
+					  precision, flags);
 	case 'i':
 		flags |= SPECIAL;
 		/* Fallthrough */
 	case 'I':
 		if (fmt[1] == '6')
-			return ip6_addr_string(buf, ptr, field_width, precision, flags);
+			return ip6_addr_string(buf, end, ptr, field_width,
+					       precision, flags);
 		if (fmt[1] == '4')
-			return ip4_addr_string(buf, ptr, field_width, precision, flags);
+			return ip4_addr_string(buf, end, ptr, field_width,
+					       precision, flags);
 		flags &= ~SPECIAL;
 		break;
 	}
@@ -505,31 +550,35 @@  static char *pointer(const char *fmt, char *buf, void *ptr, int field_width, int
 		field_width = 2*sizeof(void *);
 		flags |= ZEROPAD;
 	}
-	return number(buf, (unsigned long) ptr, 16, field_width, precision, flags);
+	return number(buf, end, (unsigned long)ptr, 16, field_width,
+		      precision, flags);
 }
 
 /**
- * vsprintf - Format a string and place it in a buffer
- * @buf: The buffer to place the result into
- * @fmt: The format string to use
- * @args: Arguments for the format string
+ * Format a string and place it in a buffer (base function)
  *
- * This function follows C99 vsprintf, but has some extensions:
+ * @param buf	The buffer to place the result into
+ * @param size	The size of the buffer, including the trailing null space
+ * @param fmt	The format string to use
+ * @param args	Arguments for the format string
+ *
+ * This function follows C99 vsnprintf, but has some extensions:
  * %pS output the name of a text symbol
  * %pF output the name of a function pointer
  * %pR output the address range in a struct resource
  *
- * The function returns the number of characters written
- * into @buf.
+ * The function returns the number of characters which would be
+ * generated for the given input, excluding the trailing '\0',
+ * as per ISO C99.
  *
  * Call this function if you are already dealing with a va_list.
- * You probably want sprintf() instead.
+ * You probably want snprintf() instead.
  */
-int vsprintf(char *buf, const char *fmt, va_list args)
+int vsnprintf(char *buf, size_t size, const char *fmt, va_list args)
 {
 	unsigned NUM_TYPE num;
 	int base;
-	char *str;
+	char *str, *end;
 
 	int flags;		/* flags to number() */
 
@@ -542,10 +591,19 @@  int vsprintf(char *buf, const char *fmt, va_list args)
 				/* 't' added for ptrdiff_t */
 
 	str = buf;
+	end = str + size;
+
+	/* Make sure end is always >= buf */
+	if (end < buf) {
+		end = ((void *)-1);
+		size = end - buf;
+	}
 
 	for (; *fmt ; ++fmt) {
 		if (*fmt != '%') {
-			*str++ = *fmt;
+			if (str < end)
+				*str = *fmt;
+			++str;
 			continue;
 		}
 
@@ -606,21 +664,32 @@  int vsprintf(char *buf, const char *fmt, va_list args)
 		base = 10;
 
 		switch (*fmt) {
-		case 'c':
+		case 'c': {
+			char c;
 			if (!(flags & LEFT))
-				while (--field_width > 0)
-					*str++ = ' ';
-			*str++ = (unsigned char) va_arg(args, int);
-			while (--field_width > 0)
-				*str++ = ' ';
+				while (--field_width > 0) {
+					if (str < end)
+						*str = ' ';
+					++str;
+				}
+			c = (unsigned char) va_arg(args, int);
+			if (str < end)
+				*str = c;
+			++str;
+			while (--field_width > 0) {
+				if (str < end)
+					*str = ' ';
+				++str;
+			}
 			continue;
-
+		}
 		case 's':
-			str = string(str, va_arg(args, char *), field_width, precision, flags);
+			str = string(str, end, va_arg(args, char *),
+				     field_width, precision, flags);
 			continue;
 
 		case 'p':
-			str = pointer(fmt+1, str,
+			str = pointer(fmt+1, str, end,
 					va_arg(args, void *),
 					field_width, precision, flags);
 			/* Skip all alphanumeric pointer suffixes */
@@ -639,7 +708,9 @@  int vsprintf(char *buf, const char *fmt, va_list args)
 			continue;
 
 		case '%':
-			*str++ = '%';
+			if (str < end)
+				*str = '%';
+			++str;
 			continue;
 
 		/* integer number formats - set up the flags and "break" */
@@ -660,10 +731,14 @@  int vsprintf(char *buf, const char *fmt, va_list args)
 			break;
 
 		default:
-			*str++ = '%';
-			if (*fmt)
-				*str++ = *fmt;
-			else
+			if (str < end)
+				*str = '%';
+			++str;
+			if (*fmt) {
+				if (str < end)
+					*str = *fmt;
+				++str;
+			} else
 				--fmt;
 			continue;
 		}
@@ -686,17 +761,124 @@  int vsprintf(char *buf, const char *fmt, va_list args)
 			if (flags & SIGN)
 				num = (signed int) num;
 		}
-		str = number(str, num, base, field_width, precision, flags);
+		str = number(str, end, num, base, field_width, precision,
+			     flags);
+	}
+
+	if (size > 0) {
+		if (str < end)
+			*str = '\0';
+		else
+			end[-1] = '\0';
 	}
-	*str = '\0';
+	/* the trailing null byte doesn't count towards the total */
 	return str-buf;
 }
 
 /**
- * sprintf - Format a string and place it in a buffer
- * @buf: The buffer to place the result into
- * @fmt: The format string to use
- * @...: Arguments for the format string
+ * Format a string and place it in a buffer (va_list version)
+ *
+ * @param buf	The buffer to place the result into
+ * @param size	The size of the buffer, including the trailing null space
+ * @param fmt	The format string to use
+ * @param args	Arguments for the format string
+ * @return the number of characters which have been written into
+ * the @buf not including the trailing '\0'. If @size is == 0 the function
+ * returns 0.
+ *
+ * If you're not already dealing with a va_list consider using scnprintf().
+ *
+ * See the vsprintf() documentation for format string extensions over C99.
+ */
+int vscnprintf(char *buf, size_t size, const char *fmt, va_list args)
+{
+	int i;
+
+	i = vsnprintf(buf, size, fmt, args);
+
+	if (likely(i < size))
+		return i;
+	if (size != 0)
+		return size - 1;
+	return 0;
+}
+
+/**
+ * Format a string and place it in a buffer
+ *
+ * @param buf	The buffer to place the result into
+ * @param size	The size of the buffer, including the trailing null space
+ * @param fmt	The format string to use
+ * @param ...	Arguments for the format string
+ * @return the number of characters which would be
+ * generated for the given input, excluding the trailing null,
+ * as per ISO C99.  If the return is greater than or equal to
+ * @size, the resulting string is truncated.
+ *
+ * See the vsprintf() documentation for format string extensions over C99.
+ */
+int snprintf(char *buf, size_t size, const char *fmt, ...)
+{
+	va_list args;
+	int i;
+
+	va_start(args, fmt);
+	i = vsnprintf(buf, size, fmt, args);
+	va_end(args);
+
+	return i;
+}
+
+/**
+ * Format a string and place it in a buffer
+ *
+ * @param buf	The buffer to place the result into
+ * @param size	The size of the buffer, including the trailing null space
+ * @param fmt	The format string to use
+ * @param ...	Arguments for the format string
+ *
+ * The return value is the number of characters written into @buf not including
+ * the trailing '\0'. If @size is == 0 the function returns 0.
+ *
+ * See the vsprintf() documentation for format string extensions over C99.
+ */
+
+int scnprintf(char *buf, size_t size, const char *fmt, ...)
+{
+	va_list args;
+	int i;
+
+	va_start(args, fmt);
+	i = vscnprintf(buf, size, fmt, args);
+	va_end(args);
+
+	return i;
+}
+
+/**
+ * Format a string and place it in a buffer (va_list version)
+ *
+ * @param buf	The buffer to place the result into
+ * @param fmt	The format string to use
+ * @param args	Arguments for the format string
+ *
+ * The function returns the number of characters written
+ * into @buf. Use vsnprintf() or vscnprintf() in order to avoid
+ * buffer overflows.
+ *
+ * If you're not already dealing with a va_list consider using sprintf().
+ */
+int vsprintf(char *buf, const char *fmt, va_list args)
+{
+	return vsnprintf(buf, INT_MAX, fmt, args);
+}
+
+/**
+ * Format a string and place it in a buffer
+ *
+ * @param buf	The buffer to place the result into
+ * @param fmt	The format string to use
+ * @param ...	Arguments for the format string
  *
  * The function returns the number of characters written
  * into @buf.