diff mbox series

[7/6] cifs: Make the cifs RDMA code use iterators

Message ID 1663522.1652457567@warthog.procyon.org.uk
State New
Headers show
Series cifs: Use iov_iters down to the network transport | expand

Commit Message

David Howells May 13, 2022, 3:59 p.m. UTC
Convert the cifs RDMA code to use iterators rather than page lists and
transcribe the iterators into scatterlists.

NOTE!  This compiles, but is untested (I don't know how to set up RDMA in
Samba and cifs).  It also wants merging into a previous patch to avoid
build errors there.

Signed-off-by: David Howells <dhowells@redhat.com>
---
 fs/cifs/smb2pdu.c   |   23 +---
 fs/cifs/smbdirect.c |  284 +++++++++++++++++-----------------------------------
 fs/cifs/smbdirect.h |    4 
 3 files changed, 106 insertions(+), 205 deletions(-)
diff mbox series

Patch

diff --git a/fs/cifs/smb2pdu.c b/fs/cifs/smb2pdu.c
index 6bb9a90b018f..d9a06704daa8 100644
--- a/fs/cifs/smb2pdu.c
+++ b/fs/cifs/smb2pdu.c
@@ -4054,10 +4054,8 @@  smb2_new_read_req(void **buf, unsigned int *total_len,
 		struct smbd_buffer_descriptor_v1 *v1;
 		bool need_invalidate = server->dialect == SMB30_PROT_ID;
 
-		rdata->mr = smbd_register_mr(
-				server->smbd_conn, rdata->pages,
-				rdata->nr_pages, rdata->page_offset,
-				rdata->tailsz, true, need_invalidate);
+		rdata->mr = smbd_register_mr(server->smbd_conn, &rdata->iter,
+					     true, need_invalidate);
 		if (!rdata->mr)
 			return -EAGAIN;
 
@@ -4477,24 +4475,15 @@  smb2_async_writev(struct cifs_writedata *wdata,
 		struct smbd_buffer_descriptor_v1 *v1;
 		bool need_invalidate = server->dialect == SMB30_PROT_ID;
 
-		wdata->mr = smbd_register_mr(
-				server->smbd_conn, wdata->pages,
-				wdata->nr_pages, wdata->page_offset,
-				wdata->tailsz, false, need_invalidate);
+		wdata->mr = smbd_register_mr(server->smbd_conn, &wdata->iter,
+					     false, need_invalidate);
 		if (!wdata->mr) {
 			rc = -EAGAIN;
 			goto async_writev_out;
 		}
 		req->Length = 0;
 		req->DataOffset = 0;
-		if (wdata->nr_pages > 1)
-			req->RemainingBytes =
-				cpu_to_le32(
-					(wdata->nr_pages - 1) * wdata->pagesz -
-					wdata->page_offset + wdata->tailsz
-				);
-		else
-			req->RemainingBytes = cpu_to_le32(wdata->tailsz);
+		req->RemainingBytes = cpu_to_le32(iov_iter_count(&wdata->iter));
 		req->Channel = SMB2_CHANNEL_RDMA_V1_INVALIDATE;
 		if (need_invalidate)
 			req->Channel = SMB2_CHANNEL_RDMA_V1;
@@ -4517,7 +4506,7 @@  smb2_async_writev(struct cifs_writedata *wdata,
 #ifdef CONFIG_CIFS_SMB_DIRECT
 	if (wdata->mr) {
 		iov[0].iov_len += sizeof(struct smbd_buffer_descriptor_v1);
-		rqst.rq_npages = 0;
+		iov_iter_advance(&wdata->iter, iov_iter_count(&wdata->iter));
 	}
 #endif
 	cifs_dbg(FYI, "async write at %llu %u bytes\n",
diff --git a/fs/cifs/smbdirect.c b/fs/cifs/smbdirect.c
index 31ef64eb7fbb..5c311de5c9ac 100644
--- a/fs/cifs/smbdirect.c
+++ b/fs/cifs/smbdirect.c
@@ -34,12 +34,6 @@  static int smbd_post_recv(
 		struct smbd_response *response);
 
 static int smbd_post_send_empty(struct smbd_connection *info);
-static int smbd_post_send_data(
-		struct smbd_connection *info,
-		struct kvec *iov, int n_vec, int remaining_data_length);
-static int smbd_post_send_page(struct smbd_connection *info,
-		struct page *page, unsigned long offset,
-		size_t size, int remaining_data_length);
 
 static void destroy_mr_list(struct smbd_connection *info);
 static int allocate_mr_list(struct smbd_connection *info);
@@ -975,24 +969,6 @@  static int smbd_post_send_sgl(struct smbd_connection *info,
 	return rc;
 }
 
-/*
- * Send a page
- * page: the page to send
- * offset: offset in the page to send
- * size: length in the page to send
- * remaining_data_length: remaining data to send in this payload
- */
-static int smbd_post_send_page(struct smbd_connection *info, struct page *page,
-		unsigned long offset, size_t size, int remaining_data_length)
-{
-	struct scatterlist sgl;
-
-	sg_init_table(&sgl, 1);
-	sg_set_page(&sgl, page, size, offset);
-
-	return smbd_post_send_sgl(info, &sgl, size, remaining_data_length);
-}
-
 /*
  * Send an empty message
  * Empty message is used to extend credits to peer to for keep live
@@ -1004,35 +980,6 @@  static int smbd_post_send_empty(struct smbd_connection *info)
 	return smbd_post_send_sgl(info, NULL, 0, 0);
 }
 
-/*
- * Send a data buffer
- * iov: the iov array describing the data buffers
- * n_vec: number of iov array
- * remaining_data_length: remaining data to send following this packet
- * in segmented SMBD packet
- */
-static int smbd_post_send_data(
-	struct smbd_connection *info, struct kvec *iov, int n_vec,
-	int remaining_data_length)
-{
-	int i;
-	u32 data_length = 0;
-	struct scatterlist sgl[SMBDIRECT_MAX_SGE];
-
-	if (n_vec > SMBDIRECT_MAX_SGE) {
-		cifs_dbg(VFS, "Can't fit data to SGL, n_vec=%d\n", n_vec);
-		return -EINVAL;
-	}
-
-	sg_init_table(sgl, n_vec);
-	for (i = 0; i < n_vec; i++) {
-		data_length += iov[i].iov_len;
-		sg_set_buf(&sgl[i], iov[i].iov_base, iov[i].iov_len);
-	}
-
-	return smbd_post_send_sgl(info, sgl, data_length, remaining_data_length);
-}
-
 /*
  * Post a receive request to the transport
  * The remote peer can only send data when a receive request is posted
@@ -1976,6 +1923,42 @@  int smbd_recv(struct smbd_connection *info, struct msghdr *msg)
 	return rc;
 }
 
+/*
+ * Send the contents of an iterator
+ * @iter: The iterator to send
+ * @_remaining_data_length: remaining data to send in this payload
+ */
+static int smbd_post_send_iter(struct smbd_connection *info,
+			       struct iov_iter *iter,
+			       int *_remaining_data_length)
+{
+	struct scatterlist sgl;
+	struct page *page;
+	ssize_t len;
+	size_t offset, maxlen;
+	int i = 0, rc;
+
+	do {
+		maxlen = min_t(size_t, *_remaining_data_length, PAGE_SIZE);
+		len = iov_iter_get_pages(iter, &page, maxlen, 1, &offset);
+		if (len <= 0)
+			return len;
+
+		sg_init_table(&sgl, 1);
+		sg_set_page(&sgl, page, len, offset);
+
+		iov_iter_advance(iter, len);
+		*_remaining_data_length -= len;
+
+		log_write(INFO, "sending page i=%d offset=%zu size=%zu remaining_data_length=%d\n",
+			  i, offset, len, *_remaining_data_length);
+		rc = smbd_post_send_sgl(info, &sgl, len, *_remaining_data_length);
+		put_page(page);
+	} while (rc == 0);
+
+	return rc;
+}
+
 /*
  * Send data to transport
  * Each rqst is transported as a SMBDirect payload
@@ -1986,17 +1969,10 @@  int smbd_send(struct TCP_Server_Info *server,
 	int num_rqst, struct smb_rqst *rqst_array)
 {
 	struct smbd_connection *info = server->smbd_conn;
-	struct kvec vec;
-	int nvecs;
-	int size;
-	unsigned int buflen, remaining_data_length;
-	int start, i, j;
-	int max_iov_size =
-		info->max_send_size - sizeof(struct smbd_data_transfer);
-	struct kvec *iov;
-	int rc;
 	struct smb_rqst *rqst;
-	int rqst_idx;
+	struct iov_iter iter;
+	unsigned int remaining_data_length;
+	int rc, i, rqst_idx;
 
 	if (info->transport_status != SMBD_CONNECTED) {
 		rc = -EAGAIN;
@@ -2025,108 +2001,30 @@  int smbd_send(struct TCP_Server_Info *server,
 	rqst_idx = 0;
 next_rqst:
 	rqst = &rqst_array[rqst_idx];
-	iov = rqst->rq_iov;
 
 	cifs_dbg(FYI, "Sending smb (RDMA): idx=%d smb_len=%lu\n",
 		rqst_idx, smb_rqst_len(server, rqst));
 	for (i = 0; i < rqst->rq_nvec; i++)
-		dump_smb(iov[i].iov_base, iov[i].iov_len);
-
-
-	log_write(INFO, "rqst_idx=%d nvec=%d rqst->rq_npages=%d rq_pagesz=%d rq_tailsz=%d buflen=%lu\n",
-		  rqst_idx, rqst->rq_nvec, rqst->rq_npages, rqst->rq_pagesz,
-		  rqst->rq_tailsz, smb_rqst_len(server, rqst));
-
-	start = i = 0;
-	buflen = 0;
-	while (true) {
-		buflen += iov[i].iov_len;
-		if (buflen > max_iov_size) {
-			if (i > start) {
-				remaining_data_length -=
-					(buflen-iov[i].iov_len);
-				log_write(INFO, "sending iov[] from start=%d i=%d nvecs=%d remaining_data_length=%d\n",
-					  start, i, i - start,
-					  remaining_data_length);
-				rc = smbd_post_send_data(
-					info, &iov[start], i-start,
-					remaining_data_length);
-				if (rc)
-					goto done;
-			} else {
-				/* iov[start] is too big, break it */
-				nvecs = (buflen+max_iov_size-1)/max_iov_size;
-				log_write(INFO, "iov[%d] iov_base=%p buflen=%d break to %d vectors\n",
-					  start, iov[start].iov_base,
-					  buflen, nvecs);
-				for (j = 0; j < nvecs; j++) {
-					vec.iov_base =
-						(char *)iov[start].iov_base +
-						j*max_iov_size;
-					vec.iov_len = max_iov_size;
-					if (j == nvecs-1)
-						vec.iov_len =
-							buflen -
-							max_iov_size*(nvecs-1);
-					remaining_data_length -= vec.iov_len;
-					log_write(INFO,
-						"sending vec j=%d iov_base=%p iov_len=%zu remaining_data_length=%d\n",
-						  j, vec.iov_base, vec.iov_len,
-						  remaining_data_length);
-					rc = smbd_post_send_data(
-						info, &vec, 1,
-						remaining_data_length);
-					if (rc)
-						goto done;
-				}
-				i++;
-				if (i == rqst->rq_nvec)
-					break;
-			}
-			start = i;
-			buflen = 0;
-		} else {
-			i++;
-			if (i == rqst->rq_nvec) {
-				/* send out all remaining vecs */
-				remaining_data_length -= buflen;
-				log_write(INFO, "sending iov[] from start=%d i=%d nvecs=%d remaining_data_length=%d\n",
-					  start, i, i - start,
-					  remaining_data_length);
-				rc = smbd_post_send_data(info, &iov[start],
-					i-start, remaining_data_length);
-				if (rc)
-					goto done;
-				break;
-			}
-		}
-		log_write(INFO, "looping i=%d buflen=%d\n", i, buflen);
-	}
-
-	/* now sending pages if there are any */
-	for (i = 0; i < rqst->rq_npages; i++) {
-		unsigned int offset;
-
-		rqst_page_get_length(rqst, i, &buflen, &offset);
-		nvecs = (buflen + max_iov_size - 1) / max_iov_size;
-		log_write(INFO, "sending pages buflen=%d nvecs=%d\n",
-			buflen, nvecs);
-		for (j = 0; j < nvecs; j++) {
-			size = max_iov_size;
-			if (j == nvecs-1)
-				size = buflen - j*max_iov_size;
-			remaining_data_length -= size;
-			log_write(INFO, "sending pages i=%d offset=%d size=%d remaining_data_length=%d\n",
-				  i, j * max_iov_size + offset, size,
-				  remaining_data_length);
-			rc = smbd_post_send_page(
-				info, rqst->rq_pages[i],
-				j*max_iov_size + offset,
-				size, remaining_data_length);
-			if (rc)
-				goto done;
-		}
-	}
+		dump_smb(rqst->rq_iov[i].iov_base, rqst->rq_iov[i].iov_len);
+
+
+	log_write(INFO, "rqst_idx=%d nvec=%d rqst->rq_iter=%zd buflen=%lu\n",
+		  rqst_idx, rqst->rq_nvec, iov_iter_count(&rqst->rq_iter),
+		  smb_rqst_len(server, rqst));
+
+	/* Send the metadata pages. */
+	iov_iter_kvec(&iter, WRITE, rqst->rq_iov, rqst->rq_nvec,
+		      rqst->rq_iov[0].iov_len +
+		      (rqst->rq_nvec > 1 ? rqst->rq_iov[1].iov_len : 0));
+
+	rc = smbd_post_send_iter(info, &iter, &remaining_data_length);
+	if (rc < 0)
+		goto done;
+
+	/* And then the data pages if there are any */
+	rc = smbd_post_send_iter(info, &rqst->rq_iter, &remaining_data_length);
+	if (rc < 0)
+		goto done;
 
 	rqst_idx++;
 	if (rqst_idx < num_rqst)
@@ -2336,6 +2234,35 @@  static struct smbd_mr *get_mr(struct smbd_connection *info)
 	goto again;
 }
 
+/*
+ * Transcribe the pages from an iterator into an MR scatterlist.
+ * @iter: The iterator to transcribe
+ * @_remaining_data_length: remaining data to send in this payload
+ */
+static int smbd_iter_to_mr(struct smbd_connection *info,
+			   struct iov_iter *iter,
+			   struct scatterlist *sgl,
+			   unsigned int num_pages)
+{
+	struct page *page;
+	ssize_t len;
+	size_t offset, maxlen;
+
+	sg_init_table(sgl, num_pages);
+
+	for (;;) {
+		maxlen = min_t(size_t, iov_iter_count(iter), PAGE_SIZE);
+		len = iov_iter_get_pages(iter, &page, maxlen, 1, &offset);
+		if (len <= 0)
+			return len;
+
+		sg_set_page(sgl, page, len, offset);
+		sgl++;
+		put_page(page);
+		iov_iter_advance(iter, len);
+	}
+}
+
 /*
  * Register memory for RDMA read/write
  * pages[]: the list of pages to register memory with
@@ -2346,14 +2273,15 @@  static struct smbd_mr *get_mr(struct smbd_connection *info)
  * return value: the MR registered, NULL if failed.
  */
 struct smbd_mr *smbd_register_mr(
-	struct smbd_connection *info, struct page *pages[], int num_pages,
-	int offset, int tailsz, bool writing, bool need_invalidate)
+	struct smbd_connection *info, struct iov_iter *iter,
+	bool writing, bool need_invalidate)
 {
 	struct smbd_mr *smbdirect_mr;
-	int rc, i;
+	int rc, num_pages;
 	enum dma_data_direction dir;
 	struct ib_reg_wr *reg_wr;
 
+	num_pages = iov_iter_npages(iter, info->max_frmr_depth + 1);
 	if (num_pages > info->max_frmr_depth) {
 		log_rdma_mr(ERR, "num_pages=%d max_frmr_depth=%d\n",
 			num_pages, info->max_frmr_depth);
@@ -2365,32 +2293,16 @@  struct smbd_mr *smbd_register_mr(
 		log_rdma_mr(ERR, "get_mr returning NULL\n");
 		return NULL;
 	}
+
+	dir = writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
+	smbdirect_mr->dir = dir;
 	smbdirect_mr->need_invalidate = need_invalidate;
 	smbdirect_mr->sgl_count = num_pages;
-	sg_init_table(smbdirect_mr->sgl, num_pages);
-
-	log_rdma_mr(INFO, "num_pages=0x%x offset=0x%x tailsz=0x%x\n",
-			num_pages, offset, tailsz);
-
-	if (num_pages == 1) {
-		sg_set_page(&smbdirect_mr->sgl[0], pages[0], tailsz, offset);
-		goto skip_multiple_pages;
-	}
 
-	/* We have at least two pages to register */
-	sg_set_page(
-		&smbdirect_mr->sgl[0], pages[0], PAGE_SIZE - offset, offset);
-	i = 1;
-	while (i < num_pages - 1) {
-		sg_set_page(&smbdirect_mr->sgl[i], pages[i], PAGE_SIZE, 0);
-		i++;
-	}
-	sg_set_page(&smbdirect_mr->sgl[i], pages[i],
-		tailsz ? tailsz : PAGE_SIZE, 0);
+	log_rdma_mr(INFO, "num_pages=0x%x count=0x%zx\n",
+		    num_pages, iov_iter_count(iter));
+	smbd_iter_to_mr(info, iter, smbdirect_mr->sgl, num_pages);
 
-skip_multiple_pages:
-	dir = writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
-	smbdirect_mr->dir = dir;
 	rc = ib_dma_map_sg(info->id->device, smbdirect_mr->sgl, num_pages, dir);
 	if (!rc) {
 		log_rdma_mr(ERR, "ib_dma_map_sg num_pages=%x dir=%x rc=%x\n",
diff --git a/fs/cifs/smbdirect.h b/fs/cifs/smbdirect.h
index a87fca82a796..3a0d39e148e8 100644
--- a/fs/cifs/smbdirect.h
+++ b/fs/cifs/smbdirect.h
@@ -298,8 +298,8 @@  struct smbd_mr {
 
 /* Interfaces to register and deregister MR for RDMA read/write */
 struct smbd_mr *smbd_register_mr(
-	struct smbd_connection *info, struct page *pages[], int num_pages,
-	int offset, int tailsz, bool writing, bool need_invalidate);
+	struct smbd_connection *info, struct iov_iter *iter,
+	bool writing, bool need_invalidate);
 int smbd_deregister_mr(struct smbd_mr *mr);
 
 #else