[v3,net-next,09/19] tls: Add rx inline crypto offload

Message ID 1531338873-18466-10-git-send-email-borisp@mellanox.com
State Changes Requested
Delegated to: David Miller
Headers show
Series
  • TLS offload rx, netdev & mlx5
Related show

Commit Message

Boris Pismenny July 11, 2018, 7:54 p.m.
This patch completes the generic infrastructure to offload TLS crypto to a
network device. It enables the kernel to skip decryption and
authentication of some skbs marked as decrypted by the NIC. In the fast
path, all packets received are decrypted by the NIC and the performance
is comparable to plain TCP.

This infrastructure doesn't require a TCP offload engine. Instead, the
NIC only decrypts packets that contain the expected TCP sequence number.
Out-Of-Order TCP packets are provided unmodified. As a result, at the
worst case a received TLS record consists of both plaintext and ciphertext
packets. These partially decrypted records must be reencrypted,
only to be decrypted.

The notable differences between SW KTLS Rx and this offload are as
follows:
1. Partial decryption - Software must handle the case of a TLS record
that was only partially decrypted by HW. This can happen due to packet
reordering.
2. Resynchronization - tls_read_size calls the device driver to
resynchronize HW after HW lost track of TLS record framing in
the TCP stream.

Signed-off-by: Boris Pismenny <borisp@mellanox.com>
---
 include/net/tls.h             |  63 +++++++++-
 net/tls/tls_device.c          | 278 ++++++++++++++++++++++++++++++++++++++----
 net/tls/tls_device_fallback.c |   1 +
 net/tls/tls_main.c            |  32 +++--
 net/tls/tls_sw.c              |  24 +++-
 5 files changed, 355 insertions(+), 43 deletions(-)

Patch

diff --git a/include/net/tls.h b/include/net/tls.h
index 7a485de..d8b3b65 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -83,6 +83,16 @@  struct tls_device {
 	void (*unhash)(struct tls_device *device, struct sock *sk);
 };
 
+enum {
+	TLS_BASE,
+	TLS_SW,
+#ifdef CONFIG_TLS_DEVICE
+	TLS_HW,
+#endif
+	TLS_HW_RECORD,
+	TLS_NUM_CONFIG,
+};
+
 struct tls_sw_context_tx {
 	struct crypto_aead *aead_send;
 	struct crypto_wait async_wait;
@@ -197,6 +207,7 @@  struct tls_context {
 	int (*push_pending_record)(struct sock *sk, int flags);
 
 	void (*sk_write_space)(struct sock *sk);
+	void (*sk_destruct)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);
 
 	int  (*setsockopt)(struct sock *sk, int level,
@@ -209,13 +220,27 @@  struct tls_context {
 	void (*unhash)(struct sock *sk);
 };
 
+struct tls_offload_context_rx {
+	/* sw must be the first member of tls_offload_context_rx */
+	struct tls_sw_context_rx sw;
+	atomic64_t resync_req;
+	u8 driver_state[];
+	/* The TLS layer reserves room for driver specific state
+	 * Currently the belief is that there is not enough
+	 * driver specific state to justify another layer of indirection
+	 */
+};
+
+#define TLS_OFFLOAD_CONTEXT_SIZE_RX					\
+	(ALIGN(sizeof(struct tls_offload_context_rx), sizeof(void *)) + \
+	 TLS_DRIVER_STATE_SIZE)
+
 int wait_on_pending_writer(struct sock *sk, long *timeo);
 int tls_sk_query(struct sock *sk, int optname, char __user *optval,
 		int __user *optlen);
 int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 		  unsigned int optlen);
 
-
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
@@ -290,11 +315,19 @@  static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
 	return tls_ctx->pending_open_record_frags;
 }
 
+struct sk_buff *
+tls_validate_xmit_skb(struct sock *sk, struct net_device *dev,
+		      struct sk_buff *skb);
+
 static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
 {
-	return sk_fullsock(sk) &&
-	       /* matches smp_store_release in tls_set_device_offload */
-	       smp_load_acquire(&sk->sk_destruct) == &tls_device_sk_destruct;
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+	return sk_fullsock(sk) &
+	       (smp_load_acquire(&sk->sk_validate_xmit_skb) ==
+	       &tls_validate_xmit_skb);
+#else
+	return false;
+#endif
 }
 
 static inline void tls_err_abort(struct sock *sk, int err)
@@ -387,10 +420,27 @@  static inline struct tls_sw_context_tx *tls_sw_ctx_tx(
 	return (struct tls_offload_context_tx *)tls_ctx->priv_ctx_tx;
 }
 
+static inline struct tls_offload_context_rx *
+tls_offload_ctx_rx(const struct tls_context *tls_ctx)
+{
+	return (struct tls_offload_context_rx *)tls_ctx->priv_ctx_rx;
+}
+
+/* The TLS context is valid until sk_destruct is called */
+static inline void tls_offload_rx_resync_request(struct sock *sk, __be32 seq)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
+
+	atomic64_set(&rx_ctx->resync_req, ((((uint64_t)seq) << 32) | 1));
+}
+
+
 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 		      unsigned char *record_type);
 void tls_register_device(struct tls_device *device);
 void tls_unregister_device(struct tls_device *device);
+int tls_device_decrypted(struct sock *sk, struct sk_buff *skb);
 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 		struct scatterlist *sgout);
 
@@ -402,4 +452,9 @@  int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info);
 
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
+
+void tls_device_offload_cleanup_rx(struct sock *sk);
+void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn);
+
 #endif /* _TLS_OFFLOAD_H */
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 332a5d1..4995d84 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -52,7 +52,11 @@ 
 
 static void tls_device_free_ctx(struct tls_context *ctx)
 {
-	kfree(tls_offload_ctx_tx(ctx));
+	if (ctx->tx_conf == TLS_HW)
+		kfree(tls_offload_ctx_tx(ctx));
+
+	if (ctx->rx_conf == TLS_HW)
+		kfree(tls_offload_ctx_rx(ctx));
 
 	kfree(ctx);
 }
@@ -70,10 +74,11 @@  static void tls_device_gc_task(struct work_struct *work)
 	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
 		struct net_device *netdev = ctx->netdev;
 
-		if (netdev) {
+		if (netdev && ctx->tx_conf == TLS_HW) {
 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
 							TLS_OFFLOAD_CTX_DIR_TX);
 			dev_put(netdev);
+			ctx->netdev = NULL;
 		}
 
 		list_del(&ctx->list);
@@ -81,6 +86,22 @@  static void tls_device_gc_task(struct work_struct *work)
 	}
 }
 
+static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
+			      struct net_device *netdev)
+{
+	if (sk->sk_destruct != tls_device_sk_destruct) {
+		refcount_set(&ctx->refcount, 1);
+		dev_hold(netdev);
+		ctx->netdev = netdev;
+		spin_lock_irq(&tls_device_lock);
+		list_add_tail(&ctx->list, &tls_device_list);
+		spin_unlock_irq(&tls_device_lock);
+
+		ctx->sk_destruct = sk->sk_destruct;
+		sk->sk_destruct = tls_device_sk_destruct;
+	}
+}
+
 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
 {
 	unsigned long flags;
@@ -180,13 +201,15 @@  void tls_device_sk_destruct(struct sock *sk)
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
 
-	if (ctx->open_record)
-		destroy_record(ctx->open_record);
+	tls_ctx->sk_destruct(sk);
 
-	delete_all_records(ctx);
-	crypto_free_aead(ctx->aead_send);
-	ctx->sk_destruct(sk);
-	clean_acked_data_disable(inet_csk(sk));
+	if (tls_ctx->tx_conf == TLS_HW) {
+		if (ctx->open_record)
+			destroy_record(ctx->open_record);
+		delete_all_records(ctx);
+		crypto_free_aead(ctx->aead_send);
+		clean_acked_data_disable(inet_csk(sk));
+	}
 
 	if (refcount_dec_and_test(&tls_ctx->refcount))
 		tls_device_queue_ctx_destruction(tls_ctx);
@@ -519,6 +542,118 @@  static int tls_device_push_pending_record(struct sock *sk, int flags)
 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
 }
 
+void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct net_device *netdev = tls_ctx->netdev;
+	struct tls_offload_context_rx *rx_ctx;
+	u32 is_req_pending;
+	s64 resync_req;
+	u32 req_seq;
+
+	if (tls_ctx->rx_conf != TLS_HW)
+		return;
+
+	rx_ctx = tls_offload_ctx_rx(tls_ctx);
+	resync_req = atomic64_read(&rx_ctx->resync_req);
+	req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
+	is_req_pending = resync_req;
+
+	if (unlikely(is_req_pending) && req_seq == seq &&
+	    atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
+		netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
+						      seq + TLS_HEADER_SIZE - 1,
+						      rcd_sn);
+}
+
+static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
+{
+	struct strp_msg *rxm = strp_msg(skb);
+	int err = 0, offset = rxm->offset, copy, nsg;
+	struct sk_buff *skb_iter, *unused;
+	struct scatterlist sg[1];
+	char *orig_buf, *buf;
+
+	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
+			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+	if (!orig_buf)
+		return -ENOMEM;
+	buf = orig_buf;
+
+	nsg = skb_cow_data(skb, 0, &unused);
+	if (unlikely(nsg < 0)) {
+		err = nsg;
+		goto free_buf;
+	}
+
+	sg_init_table(sg, 1);
+	sg_set_buf(&sg[0], buf,
+		   rxm->full_len + TLS_HEADER_SIZE +
+		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
+	skb_copy_bits(skb, offset, buf,
+		      TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+	/* We are interested only in the decrypted data not the auth */
+	err = decrypt_skb(sk, skb, sg);
+	if (err != -EBADMSG)
+		goto free_buf;
+	else
+		err = 0;
+
+	copy = min_t(int, skb_pagelen(skb) - offset,
+		     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+
+	if (skb->decrypted)
+		skb_store_bits(skb, offset, buf, copy);
+
+	offset += copy;
+	buf += copy;
+
+	skb_walk_frags(skb, skb_iter) {
+		copy = min_t(int, skb_iter->len,
+			     rxm->full_len - offset + rxm->offset -
+			     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+
+		if (skb_iter->decrypted)
+			skb_store_bits(skb, offset, buf, copy);
+
+		offset += copy;
+		buf += copy;
+	}
+
+free_buf:
+	kfree(orig_buf);
+	return err;
+}
+
+int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
+	int is_decrypted = skb->decrypted;
+	int is_encrypted = !is_decrypted;
+	struct sk_buff *skb_iter;
+
+	/* Skip if it is already decrypted */
+	if (ctx->sw.decrypted)
+		return 0;
+
+	/* Check if all the data is decrypted already */
+	skb_walk_frags(skb, skb_iter) {
+		is_decrypted &= skb_iter->decrypted;
+		is_encrypted &= !skb_iter->decrypted;
+	}
+
+	ctx->sw.decrypted |= is_decrypted;
+
+	/* Return immedeatly if the record is either entirely plaintext or
+	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
+	 * record.
+	 */
+	return (is_encrypted || is_decrypted) ? 0 :
+		tls_device_reencrypt(sk, skb);
+}
+
 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 {
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
@@ -608,7 +743,6 @@  int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 
 	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
 	ctx->push_pending_record = tls_device_push_pending_record;
-	offload_ctx->sk_destruct = sk->sk_destruct;
 
 	/* TLS offload is greatly simplified if we don't send
 	 * SKBs where only part of the payload needs to be encrypted.
@@ -618,8 +752,6 @@  int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	if (skb)
 		TCP_SKB_CB(skb)->eor = 1;
 
-	refcount_set(&ctx->refcount, 1);
-
 	/* We support starting offload on multiple sockets
 	 * concurrently, so we only need a read lock here.
 	 * This lock must precede get_netdev_for_sock to prevent races between
@@ -654,19 +786,14 @@  int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	if (rc)
 		goto release_netdev;
 
-	ctx->netdev = netdev;
+	tls_device_attach(ctx, sk, netdev);
 
-	spin_lock_irq(&tls_device_lock);
-	list_add_tail(&ctx->list, &tls_device_list);
-	spin_unlock_irq(&tls_device_lock);
-
-	sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
 	/* following this assignment tls_is_sk_tx_device_offloaded
 	 * will return true and the context might be accessed
 	 * by the netdev's xmit function.
 	 */
-	smp_store_release(&sk->sk_destruct,
-			  &tls_device_sk_destruct);
+	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
+	dev_put(netdev);
 	up_read(&device_offload_lock);
 	goto out;
 
@@ -689,6 +816,105 @@  int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	return rc;
 }
 
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
+{
+	struct tls_offload_context_rx *context;
+	struct net_device *netdev;
+	int rc = 0;
+
+	/* We support starting offload on multiple sockets
+	 * concurrently, so we only need a read lock here.
+	 * This lock must precede get_netdev_for_sock to prevent races between
+	 * NETDEV_DOWN and setsockopt.
+	 */
+	down_read(&device_offload_lock);
+	netdev = get_netdev_for_sock(sk);
+	if (!netdev) {
+		pr_err_ratelimited("%s: netdev not found\n", __func__);
+		rc = -EINVAL;
+		goto release_lock;
+	}
+
+	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
+		pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
+				   __func__, netdev->name);
+		rc = -ENOTSUPP;
+		goto release_netdev;
+	}
+
+	/* Avoid offloading if the device is down
+	 * We don't want to offload new flows after
+	 * the NETDEV_DOWN event
+	 */
+	if (!(netdev->flags & IFF_UP)) {
+		rc = -EINVAL;
+		goto release_netdev;
+	}
+
+	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
+	if (!context) {
+		rc = -ENOMEM;
+		goto release_netdev;
+	}
+
+	ctx->priv_ctx_rx = context;
+	rc = tls_set_sw_offload(sk, ctx, 0);
+	if (rc)
+		goto release_ctx;
+
+	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
+					     &ctx->crypto_recv,
+					     tcp_sk(sk)->copied_seq);
+	if (rc) {
+		pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
+				   __func__);
+		goto free_sw_resources;
+	}
+
+	tls_device_attach(ctx, sk, netdev);
+	goto release_netdev;
+
+free_sw_resources:
+	tls_sw_free_resources_rx(sk);
+release_ctx:
+	ctx->priv_ctx_rx = NULL;
+release_netdev:
+	dev_put(netdev);
+release_lock:
+	up_read(&device_offload_lock);
+	return rc;
+}
+
+void tls_device_offload_cleanup_rx(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct net_device *netdev;
+
+	down_read(&device_offload_lock);
+	netdev = tls_ctx->netdev;
+	if (!netdev)
+		goto out;
+
+	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
+		pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
+				   __func__);
+		goto out;
+	}
+
+	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
+					TLS_OFFLOAD_CTX_DIR_RX);
+
+	if (tls_ctx->tx_conf != TLS_HW) {
+		dev_put(netdev);
+		tls_ctx->netdev = NULL;
+	}
+out:
+	up_read(&device_offload_lock);
+	kfree(tls_ctx->rx.rec_seq);
+	kfree(tls_ctx->rx.iv);
+	tls_sw_release_resources_rx(sk);
+}
+
 static int tls_device_down(struct net_device *netdev)
 {
 	struct tls_context *ctx, *tmp;
@@ -709,8 +935,12 @@  static int tls_device_down(struct net_device *netdev)
 	spin_unlock_irqrestore(&tls_device_lock, flags);
 
 	list_for_each_entry_safe(ctx, tmp, &list, list)	{
-		netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
-						TLS_OFFLOAD_CTX_DIR_TX);
+		if (ctx->tx_conf == TLS_HW)
+			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+							TLS_OFFLOAD_CTX_DIR_TX);
+		if (ctx->rx_conf == TLS_HW)
+			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+							TLS_OFFLOAD_CTX_DIR_RX);
 		ctx->netdev = NULL;
 		dev_put(netdev);
 		list_del_init(&ctx->list);
@@ -731,12 +961,16 @@  static int tls_dev_event(struct notifier_block *this, unsigned long event,
 {
 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
 
-	if (!(dev->features & NETIF_F_HW_TLS_TX))
+	if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
 		return NOTIFY_DONE;
 
 	switch (event) {
 	case NETDEV_REGISTER:
 	case NETDEV_FEAT_CHANGE:
+		if ((dev->features & NETIF_F_HW_TLS_RX) &&
+		    !dev->tlsdev_ops->tls_dev_resync_rx)
+			return NOTIFY_BAD;
+
 		if  (dev->tlsdev_ops &&
 		     dev->tlsdev_ops->tls_dev_add &&
 		     dev->tlsdev_ops->tls_dev_del)
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index d1d7dce..e3313c4 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -413,6 +413,7 @@  struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
 
 	return tls_sw_fallback(sk, skb);
 }
+EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);
 
 int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 301f224..b09867c 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -51,15 +51,6 @@  enum {
 	TLSV6,
 	TLS_NUM_PROTS,
 };
-enum {
-	TLS_BASE,
-	TLS_SW,
-#ifdef CONFIG_TLS_DEVICE
-	TLS_HW,
-#endif
-	TLS_HW_RECORD,
-	TLS_NUM_CONFIG,
-};
 
 static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
@@ -290,7 +281,10 @@  static void tls_sk_proto_close(struct sock *sk, long timeout)
 	}
 
 #ifdef CONFIG_TLS_DEVICE
-	if (ctx->tx_conf != TLS_HW) {
+	if (ctx->rx_conf == TLS_HW)
+		tls_device_offload_cleanup_rx(sk);
+
+	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
 #else
 	{
 #endif
@@ -470,8 +464,16 @@  static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 			conf = TLS_SW;
 		}
 	} else {
-		rc = tls_set_sw_offload(sk, ctx, 0);
-		conf = TLS_SW;
+#ifdef CONFIG_TLS_DEVICE
+		rc = tls_set_device_offload_rx(sk, ctx);
+		conf = TLS_HW;
+		if (rc) {
+#else
+		{
+#endif
+			rc = tls_set_sw_offload(sk, ctx, 0);
+			conf = TLS_SW;
+		}
 	}
 
 	if (rc)
@@ -629,6 +631,12 @@  static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
 	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
 	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;
+
+	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+
+	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+
+	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif
 
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 5073676..2a6ba0f 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -658,16 +658,25 @@  static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct scatterlist *sgout)
+			      struct scatterlist *sgout, bool *zc)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct strp_msg *rxm = strp_msg(skb);
 	int err = 0;
 
-	err = decrypt_skb(sk, skb, sgout);
+#ifdef CONFIG_TLS_DEVICE
+	err = tls_device_decrypted(sk, skb);
 	if (err < 0)
 		return err;
+#endif
+	if (!ctx->decrypted) {
+		err = decrypt_skb(sk, skb, sgout);
+		if (err < 0)
+			return err;
+	} else {
+		*zc = false;
+	}
 
 	rxm->offset += tls_ctx->rx.prepend_size;
 	rxm->full_len -= tls_ctx->rx.overhead_size;
@@ -829,7 +838,7 @@  int tls_sw_recvmsg(struct sock *sk,
 				if (err < 0)
 					goto fallback_to_reg_recv;
 
-				err = decrypt_skb_update(sk, skb, sgin);
+				err = decrypt_skb_update(sk, skb, sgin, &zc);
 				for (; pages > 0; pages--)
 					put_page(sg_page(&sgin[pages]));
 				if (err < 0) {
@@ -838,7 +847,7 @@  int tls_sw_recvmsg(struct sock *sk,
 				}
 			} else {
 fallback_to_reg_recv:
-				err = decrypt_skb_update(sk, skb, NULL);
+				err = decrypt_skb_update(sk, skb, NULL, &zc);
 				if (err < 0) {
 					tls_err_abort(sk, EBADMSG);
 					goto recv_end;
@@ -893,6 +902,7 @@  ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 	int err = 0;
 	long timeo;
 	int chunk;
+	bool zc;
 
 	lock_sock(sk);
 
@@ -909,7 +919,7 @@  ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 	}
 
 	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL);
+		err = decrypt_skb_update(sk, skb, NULL, &zc);
 
 		if (err < 0) {
 			tls_err_abort(sk, EBADMSG);
@@ -998,6 +1008,10 @@  static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 		goto read_failure;
 	}
 
+#ifdef CONFIG_TLS_DEVICE
+	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
+			     *(u64*)tls_ctx->rx.rec_seq);
+#endif
 	return data_len + TLS_HEADER_SIZE;
 
 read_failure: