diff mbox series

[net-next,1/1] tls: Fix recvmsg() to be able to peek across multiple records

Message ID 20190116064537.9460-1-vakul.garg@nxp.com
State Superseded
Delegated to: David Miller
Headers show
Series [net-next,1/1] tls: Fix recvmsg() to be able to peek across multiple records | expand

Commit Message

Vakul Garg Jan. 16, 2019, 6:48 a.m. UTC
This fixes recvmsg() to be able to peek across multiple tls records.
Without this patch, the tls selftest case
'recv_peek_large_buf_mult_recs' fails. Each tls receive context now
maintains an 'rx_list' to retain incoming skbs carrying tls records. If a
tls record needs to be retained, e.g. for the peek case or when
the buffer passed to recvmsg() is smaller than the decrypted
record, it is added to 'rx_list'. Additionally, records are
added to 'rx_list' if the crypto operation runs in async mode. The
records are dequeued from 'rx_list' after the decrypted data is consumed
by copying it into the buffer passed to recvmsg(). If the MSG_PEEK
flag is used in recvmsg(), records are not consumed or removed
from 'rx_list'.

Signed-off-by: Vakul Garg <vakul.garg@nxp.com>
---
 include/net/tls.h |   3 +-
 net/tls/tls_sw.c  | 261 +++++++++++++++++++++++++++++++++++++++---------------
 2 files changed, 193 insertions(+), 71 deletions(-)
diff mbox series

Patch

diff --git a/include/net/tls.h b/include/net/tls.h
index 2a6ac8d642af..90bf52db573e 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -145,12 +145,13 @@  struct tls_sw_context_tx {
 struct tls_sw_context_rx {
 	struct crypto_aead *aead_recv;
 	struct crypto_wait async_wait;
-
 	struct strparser strp;
+	struct sk_buff_head rx_list;	/* list of decrypted 'data' records */
 	void (*saved_data_ready)(struct sock *sk);
 
 	struct sk_buff *recv_pkt;
 	u8 control;
+	int async_capable;
 	bool decrypted;
 	atomic_t decrypt_pending;
 	bool async_notify;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 11cdc8f7db63..6a6b3e6e2797 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -124,9 +124,11 @@  static void tls_decrypt_done(struct crypto_async_request *req, int err)
 {
 	struct aead_request *aead_req = (struct aead_request *)req;
 	struct scatterlist *sgout = aead_req->dst;
+	struct scatterlist *sgin = aead_req->src;
 	struct tls_sw_context_rx *ctx;
 	struct tls_context *tls_ctx;
 	struct scatterlist *sg;
+	struct strp_msg *rxm;
 	struct sk_buff *skb;
 	unsigned int pages;
 	int pending;
@@ -134,7 +136,6 @@  static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	skb = (struct sk_buff *)req->data;
 	tls_ctx = tls_get_ctx(skb->sk);
 	ctx = tls_sw_ctx_rx(tls_ctx);
-	pending = atomic_dec_return(&ctx->decrypt_pending);
 
 	/* Propagate if there was an err */
 	if (err) {
@@ -142,23 +143,30 @@  static void tls_decrypt_done(struct crypto_async_request *req, int err)
 		tls_err_abort(skb->sk, err);
 	}
 
+	rxm = strp_msg(skb);
+	rxm->offset += tls_ctx->rx.prepend_size;
+	rxm->full_len -= tls_ctx->rx.overhead_size;
+
 	/* After using skb->sk to propagate sk through crypto async callback
 	 * we need to NULL it again.
 	 */
 	skb->sk = NULL;
 
-	/* Release the skb, pages and memory allocated for crypto req */
-	kfree_skb(skb);
 
-	/* Skip the first S/G entry as it points to AAD */
-	for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
-		if (!sg)
-			break;
-		put_page(sg_page(sg));
+	/* Free the destination pages if skb was not decrypted inplace */
+	if (sgout != sgin) {
+		/* Skip the first S/G entry as it points to AAD */
+		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+			if (!sg)
+				break;
+			put_page(sg_page(sg));
+		}
 	}
 
 	kfree(aead_req);
 
+	pending = atomic_dec_return(&ctx->decrypt_pending);
+
 	if (!pending && READ_ONCE(ctx->async_notify))
 		complete(&ctx->async_wait.completion);
 }
@@ -1281,7 +1289,7 @@  static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			    struct iov_iter *out_iov,
 			    struct scatterlist *out_sg,
-			    int *chunk, bool *zc)
+			    int *chunk, bool *zc, bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1381,13 +1389,13 @@  static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 fallback_to_reg_recv:
 		sgout = sgin;
 		pages = 0;
-		*chunk = 0;
+		*chunk = data_len;
 		*zc = false;
 	}
 
 	/* Prepare and submit AEAD request */
 	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-				data_len, aead_req, *zc);
+				data_len, aead_req, async);
 	if (err == -EINPROGRESS)
 		return err;
 
@@ -1400,7 +1408,8 @@  static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct iov_iter *dest, int *chunk, bool *zc)
+			      struct iov_iter *dest, int *chunk, bool *zc,
+			      bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1413,7 +1422,7 @@  static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 		return err;
 #endif
 	if (!ctx->decrypted) {
-		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
+		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
 		if (err < 0) {
 			if (err == -EINPROGRESS)
 				tls_advance_record_sn(sk, &tls_ctx->rx);
@@ -1439,7 +1448,7 @@  int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 	bool zc = true;
 	int chunk;
 
-	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -1466,6 +1475,72 @@  static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 	return true;
 }
 
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+			   struct msghdr *msg,	/* destination handed to skb_copy_datagram_msg() */
+			   size_t skip,		/* bytes to pass over, counted from the head of rx_list */
+			   size_t len,		/* upper bound on the bytes to deliver */
+			   bool zc,		/* true: copy only when a record's remainder exceeds len */
+			   bool is_peek)	/* true: leave records and their offsets untouched */
+{
+	struct sk_buff *skb = skb_peek(&ctx->rx_list);	/* start at the head of rx_list */
+	ssize_t copied = 0;	/* running total of delivered bytes; returned on success */
+
+	while (skip && skb) {	/* step over records lying entirely inside 'skip' */
+		struct strp_msg *rxm = strp_msg(skb);
+
+		if (skip < rxm->full_len)	/* 'skip' ends inside this record */
+			break;
+
+		skip = skip - rxm->full_len;
+		skb = skb_peek_next(skb, &ctx->rx_list);
+	}
+
+	while (len && skb) {	/* walk records until len is used up or the list ends */
+		struct sk_buff *next_skb;
+		struct strp_msg *rxm = strp_msg(skb);
+		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+		if (!zc || (rxm->full_len - skip) > len) {	/* with zc set, a record that fits in len is only accounted, not copied */
+			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+						    msg, chunk);
+			if (err < 0)
+				return err;	/* copy failure is returned as-is */
+		}
+
+		len = len - chunk;
+		copied = copied + chunk;
+
+		/* Non-peek case: advance the record past the bytes just handled */
+		if (!is_peek) {
+			rxm->offset = rxm->offset + chunk;
+			rxm->full_len = rxm->full_len - chunk;
+
+			/* Stop here and keep the record queued if data remains in it */
+			if (rxm->full_len - skip)
+				break;
+		}
+
+		/* Any leftover 'skip' can only apply to the first record visited
+		 * by this loop, so it is cleared before moving to the next one.
+		 */
+		skip = 0;
+
+		if (msg)
+			msg->msg_flags |= MSG_EOR;
+
+		next_skb = skb_peek_next(skb, &ctx->rx_list);	/* fetch the successor before skb may be freed */
+
+		if (!is_peek) {	/* record consumed: take it off rx_list and release it */
+			skb_unlink(skb, &ctx->rx_list);
+			kfree_skb(skb);
+		}
+
+		skb = next_skb;
+	}
+
+	return copied;	/* number of bytes accounted for by the loop above */
+}
+
 int tls_sw_recvmsg(struct sock *sk,
 		   struct msghdr *msg,
 		   size_t len,
@@ -1476,7 +1551,8 @@  int tls_sw_recvmsg(struct sock *sk,
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct sk_psock *psock;
-	unsigned char control;
+	unsigned char control = 0;
+	ssize_t decrypted = 0;
 	struct strp_msg *rxm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
@@ -1484,6 +1560,7 @@  int tls_sw_recvmsg(struct sock *sk,
 	int target, err = 0;
 	long timeo;
 	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool is_peek = flags & MSG_PEEK;
 	int num_async = 0;
 
 	flags |= nonblock;
@@ -1494,11 +1571,28 @@  int tls_sw_recvmsg(struct sock *sk,
 	psock = sk_psock_get(sk);
 	lock_sock(sk);
 
-	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
-	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	/* Process pending decrypted records. It must be non-zero-copy */
+	err = process_rx_list(ctx, msg, 0, len, false, is_peek);
+	if (err < 0) {
+		tls_err_abort(sk, err);
+		goto end;
+	} else {
+		copied = err;
+	}
+
+	len = len - copied;
+	if (len) {
+		target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	} else {
+		goto recv_end;
+	}
+
 	do {
-		bool zc = false;
+		bool retain_skb = false;
 		bool async = false;
+		bool zc = false;
+		int to_decrypt;
 		int chunk = 0;
 
 		skb = tls_wait_data(sk, psock, flags, timeo, &err);
@@ -1508,7 +1602,7 @@  int tls_sw_recvmsg(struct sock *sk,
 							    msg, len, flags);
 
 				if (ret > 0) {
-					copied += ret;
+					decrypted += ret;
 					len -= ret;
 					continue;
 				}
@@ -1535,70 +1629,70 @@  int tls_sw_recvmsg(struct sock *sk,
 			goto recv_end;
 		}
 
-		if (!ctx->decrypted) {
-			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
-
-			if (!is_kvec && to_copy <= len &&
-			    likely(!(flags & MSG_PEEK)))
-				zc = true;
-
-			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-						 &chunk, &zc);
-			if (err < 0 && err != -EINPROGRESS) {
-				tls_err_abort(sk, EBADMSG);
-				goto recv_end;
-			}
+		to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
 
-			if (err == -EINPROGRESS) {
-				async = true;
-				num_async++;
-				goto pick_next_record;
-			}
+		if (to_decrypt <= len && !is_kvec && !is_peek)
+			zc = true;
 
-			ctx->decrypted = true;
+		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+					 &chunk, &zc, ctx->async_capable);
+		if (err < 0 && err != -EINPROGRESS) {
+			tls_err_abort(sk, EBADMSG);
+			goto recv_end;
 		}
 
-		if (!zc) {
-			chunk = min_t(unsigned int, rxm->full_len, len);
+		if (err == -EINPROGRESS) {
+			async = true;
+			num_async++;
+			goto pick_next_record;
+		} else {
+			if (!zc) {
+				if (rxm->full_len > len) {
+					retain_skb = true;
+					chunk = len;
+				} else {
+					chunk = rxm->full_len;
+				}
 
-			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
-						    chunk);
-			if (err < 0)
-				goto recv_end;
+				err = skb_copy_datagram_msg(skb, rxm->offset,
+							    msg, chunk);
+				if (err < 0)
+					goto recv_end;
+
+				if (!is_peek) {
+					rxm->offset = rxm->offset + chunk;
+					rxm->full_len = rxm->full_len - chunk;
+				}
+			}
 		}
 
 pick_next_record:
-		copied += chunk;
+		if (chunk > len)
+			chunk = len;
+
+		decrypted += chunk;
 		len -= chunk;
-		if (likely(!(flags & MSG_PEEK))) {
-			u8 control = ctx->control;
-
-			/* For async, drop current skb reference */
-			if (async)
-				skb = NULL;
-
-			if (tls_sw_advance_skb(sk, skb, chunk)) {
-				/* Return full control message to
-				 * userspace before trying to parse
-				 * another message type
-				 */
-				msg->msg_flags |= MSG_EOR;
-				if (control != TLS_RECORD_TYPE_DATA)
-					goto recv_end;
-			} else {
-				break;
-			}
-		} else {
-			/* MSG_PEEK right now cannot look beyond current skb
-			 * from strparser, meaning we cannot advance skb here
-			 * and thus unpause strparser since we'd loose original
-			 * one.
+
+		/* For async or peek case, queue the current skb */
+		if (async || is_peek || retain_skb) {
+			skb_queue_tail(&ctx->rx_list, skb);
+			skb = NULL;
+		}
+
+		if (tls_sw_advance_skb(sk, skb, chunk)) {
+			/* Return full control message to
+			 * userspace before trying to parse
+			 * another message type
 			 */
+			msg->msg_flags |= MSG_EOR;
+			if (ctx->control != TLS_RECORD_TYPE_DATA)
+				goto recv_end;
+		} else {
 			break;
 		}
 
 		/* If we have a new message from strparser, continue now. */
-		if (copied >= target && !ctx->recv_pkt)
+		if (decrypted >= target && !ctx->recv_pkt)
 			break;
 	} while (len);
 
@@ -1612,13 +1706,33 @@  int tls_sw_recvmsg(struct sock *sk,
 				/* one of async decrypt failed */
 				tls_err_abort(sk, err);
 				copied = 0;
+				decrypted = 0;
+				goto end;
 			}
 		} else {
 			reinit_completion(&ctx->async_wait.completion);
 		}
 		WRITE_ONCE(ctx->async_notify, false);
+
+		/* Drain records from the rx_list & copy if required */
+		if (is_peek || is_kvec)
+			err = process_rx_list(ctx, msg, copied,
+					      decrypted, false, is_peek);
+		else
+			err = process_rx_list(ctx, msg, 0,
+					      decrypted, true, is_peek);
+		if (err < 0) {
+			tls_err_abort(sk, err);
+			copied = 0;
+			goto end;
+		}
+
+		WARN_ON(decrypted != err);
 	}
 
+	copied += decrypted;
+
+end:
 	release_sock(sk);
 	if (psock)
 		sk_psock_put(sk, psock);
@@ -1655,7 +1769,7 @@  ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 	}
 
 	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
+		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
 
 		if (err < 0) {
 			tls_err_abort(sk, EBADMSG);
@@ -1842,6 +1956,7 @@  void tls_sw_release_resources_rx(struct sock *sk)
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
+		skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
 		write_lock_bh(&sk->sk_callback_lock);
@@ -1891,6 +2006,7 @@  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	struct crypto_aead **aead;
 	struct strp_callbacks cb;
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
+	struct crypto_tfm *tfm;
 	char *iv, *rec_seq;
 	int rc = 0;
 
@@ -1937,6 +2053,7 @@  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		crypto_init_wait(&sw_ctx_rx->async_wait);
 		crypto_info = &ctx->crypto_recv.info;
 		cctx = &ctx->rx;
+		skb_queue_head_init(&sw_ctx_rx->rx_list);
 		aead = &sw_ctx_rx->aead_recv;
 	}
 
@@ -2004,6 +2121,10 @@  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		goto free_aead;
 
 	if (sw_ctx_rx) {
+		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+		sw_ctx_rx->async_capable =
+			tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
 		/* Set up strparser */
 		memset(&cb, 0, sizeof(cb));
 		cb.rcv_msg = tls_queue;