diff mbox series

[2/7] iov_iter: Add a general purpose iteration function

Message ID 166126394098.708021.10931745751914856461.stgit@warthog.procyon.org.uk
State New
Headers show
Series smb3: Add iter helpers and use iov_iters down to the network transport | expand

Commit Message

David Howells Aug. 23, 2022, 2:12 p.m. UTC
Add a function, iov_iter_scan(), to iterate over the buffers described by
an I/O iterator, kmapping and passing each contiguous chunk the the
supplied scanner function in turn, up to the requested amount of data or
until the scanner function returns an error.

This can be used, for example, to hash all the data in an iterator by
having the scanner function call the appropriate crypto update function.

Signed-off-by: David Howells <dhowells@redhat.com>
---

 include/linux/uio.h |    4 +++
 lib/iov_iter.c      |   66 +++++++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 67 insertions(+), 3 deletions(-)

Comments

Al Viro Aug. 26, 2022, 9:35 p.m. UTC | #1
On Tue, Aug 23, 2022 at 03:12:21PM +0100, David Howells wrote:

>  	size_t __maybe_unused off = 0;				\
>  	len = n;						\
>  	base = __p + i->iov_offset;				\
> -	len -= (STEP);						\
> -	i->iov_offset += len;					\
> -	n = len;						\
> +	do {							\
> +		len -= (STEP);					\
> +		i->iov_offset += len;				\
> +		n = len;					\
> +	} while (0);						\
>  }

*blink*

What is that supposed to change?

>  /* covers iovec and kvec alike */
> @@ -1611,6 +1613,64 @@ ssize_t extract_iter_to_iter(struct iov_iter *orig,
>  }
>  EXPORT_SYMBOL(extract_iter_to_iter);
>  
> +/**
> + * iov_iter_scan - Scan a source iter
> + * @i: The iterator to scan
> + * @bytes: The amount of buffer/data to scan
> + * @scanner: The function to call to process each segment
> + * @priv: Private data to pass to the scanner function
> + *
> + * Scan an iterator, passing each segment to the scanner function.  If the
> + * scanner returns an error at any time, scanning stops and the error is
> + * returned, otherwise the sum of the scanner results is returned.
> + */
> +ssize_t iov_iter_scan(struct iov_iter *i, size_t bytes,
> +		      ssize_t (*scanner)(struct iov_iter *i, const void *p,
> +					 size_t len, size_t off, void *priv),
> +		      void *priv)
> +{
> +	unsigned int gup_flags = 0;
> +	ssize_t ret = 0, scanned = 0;
> +
> +	if (!bytes)
> +		return 0;
> +	if (WARN_ON(iov_iter_is_discard(i)))
> +		return 0;
> +	if (iter_is_iovec(i))
> +		might_fault();
> +
> +	if (iov_iter_rw(i) != WRITE)
> +		gup_flags |= FOLL_WRITE;
> +	if (i->nofault)
> +		gup_flags |= FOLL_NOFAULT;
> +
> +	iterate_and_advance(
> +		i, bytes, base, len, off, ({
> +				struct page *page;
> +				void *q;
> +
> +				ret = get_user_pages_fast((unsigned long)base, 1,
> +							  gup_flags, &page);
> +				if (ret < 0)
> +					break;
> +				q = kmap_local_page(page);
> +				ret = scanner(i, q, len, off, priv);
> +				kunmap_local(q);
> +				put_page(page);
> +				if (ret < 0)
> +					break;
> +				scanned += ret;
> +			}), ({

Huh?  You do realize that the first ("userland") callback of
iterate_and_advance() is expected to have the amount not processed
as value?  That's what this
	len -= (STEP);
is about.  And anything non-zero means "fucking stop already".

How the hell does that thing manage to work?  And what makes you
think that it'll keep boinking an iovec segment again and again
on short operations?  Is that what that mystery do-while had
been supposed to do?

This makes no sense.  Again, I'm not even talking about the
potential usefulness of the primitive in question - it won't work
as posted, period.
diff mbox series

Patch

diff --git a/include/linux/uio.h b/include/linux/uio.h
index 88fd93508710..76a3aeca8703 100644
--- a/include/linux/uio.h
+++ b/include/linux/uio.h
@@ -259,6 +259,10 @@  int iov_iter_npages(const struct iov_iter *i, int maxpages);
 void iov_iter_restore(struct iov_iter *i, struct iov_iter_state *state);
 
 const void *dup_iter(struct iov_iter *new, struct iov_iter *old, gfp_t flags);
+ssize_t iov_iter_scan(struct iov_iter *i, size_t bytes,
+		      ssize_t (*scanner)(struct iov_iter *i, const void *p,
+					 size_t len, size_t off, void *priv),
+		      void *priv);
 
 static inline size_t iov_iter_count(const struct iov_iter *i)
 {
diff --git a/lib/iov_iter.c b/lib/iov_iter.c
index c07bf978b935..3f22822a946c 100644
--- a/lib/iov_iter.c
+++ b/lib/iov_iter.c
@@ -21,9 +21,11 @@ 
 	size_t __maybe_unused off = 0;				\
 	len = n;						\
 	base = __p + i->iov_offset;				\
-	len -= (STEP);						\
-	i->iov_offset += len;					\
-	n = len;						\
+	do {							\
+		len -= (STEP);					\
+		i->iov_offset += len;				\
+		n = len;					\
+	} while (0);						\
 }
 
 /* covers iovec and kvec alike */
@@ -1611,6 +1613,64 @@  ssize_t extract_iter_to_iter(struct iov_iter *orig,
 }
 EXPORT_SYMBOL(extract_iter_to_iter);
 
+/**
+ * iov_iter_scan - Scan a source iter
+ * @i: The iterator to scan
+ * @bytes: The amount of buffer/data to scan
+ * @scanner: The function to call to process each segment
+ * @priv: Private data to pass to the scanner function
+ *
+ * Scan an iterator, passing each segment to the scanner function.  If the
+ * scanner returns an error at any time, scanning stops and the error is
+ * returned, otherwise the sum of the scanner results is returned.
+ */
+ssize_t iov_iter_scan(struct iov_iter *i, size_t bytes,
+		      ssize_t (*scanner)(struct iov_iter *i, const void *p,
+					 size_t len, size_t off, void *priv),
+		      void *priv)
+{
+	unsigned int gup_flags = 0;
+	ssize_t ret = 0, scanned = 0;
+
+	if (!bytes)
+		return 0;
+	if (WARN_ON(iov_iter_is_discard(i)))
+		return 0;
+	if (iter_is_iovec(i))
+		might_fault();
+
+	if (iov_iter_rw(i) != WRITE)
+		gup_flags |= FOLL_WRITE;
+	if (i->nofault)
+		gup_flags |= FOLL_NOFAULT;
+
+	iterate_and_advance(
+		i, bytes, base, len, off, ({
+				struct page *page;
+				void *q;
+
+				ret = get_user_pages_fast((unsigned long)base, 1,
+							  gup_flags, &page);
+				if (ret < 0)
+					break;
+				q = kmap_local_page(page);
+				ret = scanner(i, q, len, off, priv);
+				kunmap_local(q);
+				put_page(page);
+				if (ret < 0)
+					break;
+				scanned += ret;
+			}), ({
+				ret = scanner(i, base, len, off, priv);
+				if (ret < 0)
+					break;
+				scanned += ret;
+			})
+	);
+	return ret < 0 ? ret : scanned;
+}
+EXPORT_SYMBOL(iov_iter_scan);
+
 size_t csum_and_copy_from_iter(void *addr, size_t bytes, __wsum *csum,
 			       struct iov_iter *i)
 {