diff mbox series

[net-next,06/14] net/tls: Add generic NIC offload infrastructure

Message ID 20180320024510.7408-7-saeedm@mellanox.com
State Superseded, archived
Delegated to: David Miller
Headers show
Series TLS offload, netdev & MLX5 support | expand

Commit Message

Saeed Mahameed March 20, 2018, 2:45 a.m. UTC
From: Ilya Lesokhin <ilyal@mellanox.com>

This patch adds a generic infrastructure to offload TLS crypto to a
network devices. It enables the kernel TLS socket to skip encryption
and authentication operations on the transmit side of the data path.
Leaving those computationally expensive operations to the NIC.

The NIC offload infrastructure builds TLS records and pushes them to
the TCP layer just like the SW KTLS implementation and using the same API.
TCP segmentation is mostly unaffected. Currently the only exception is
that we prevent mixed SKBs where only part of the payload requires
offload. In the future we are likely to add a similar restriction
following a change cipher spec record.

The notable differences between SW KTLS and NIC offloaded TLS
implementations are as follows:
1. The offloaded implementation builds "plaintext TLS record", those
records contain plaintext instead of ciphertext and place holder bytes
instead of authentication tags.
2. The offloaded implementation maintains a mapping from TCP sequence
number to TLS records. Thus given a TCP SKB sent from a NIC offloaded
TLS socket, we can use the tls NIC offload infrastructure to obtain
enough context to encrypt the payload of the SKB.
A TLS record is released when the last byte of the record is ack'ed,
this is done through the new icsk_clean_acked callback.

The infrastructure should be extendable to support various NIC offload
implementations.  However it is currently written with the
implementation below in mind:
The NIC assumes that packets from each offloaded stream are sent as
plaintext and in-order. It keeps track of the TLS records in the TCP
stream. When a packet marked for offload is transmitted, the NIC
encrypts the payload in-place and puts authentication tags in the
relevant place holders.

The responsibility for handling out-of-order packets (i.e. TCP
retransmission, qdisc drops) falls on the netdev driver.

The netdev driver keeps track of the expected TCP SN from the NIC's
perspective.  If the next packet to transmit matches the expected TCP
SN, the driver advances the expected TCP SN, and transmits the packet
with TLS offload indication.

If the next packet to transmit does not match the expected TCP SN. The
driver calls the TLS layer to obtain the TLS record that includes the
TCP of the packet for transmission. Using this TLS record, the driver
posts a work entry on the transmit queue to reconstruct the NIC TLS
state required for the offload of the out-of-order packet. It updates
the expected TCP SN accordingly and transmit the now in-order packet.
The same queue is used for packet transmission and TLS context
reconstruction to avoid the need for flushing the transmit queue before
issuing the context reconstruction request.

Signed-off-by: Ilya Lesokhin <ilyal@mellanox.com>
Signed-off-by: Boris Pismenny <borisp@mellanox.com>
Signed-off-by: Aviad Yehezkel <aviadye@mellanox.com>
Signed-off-by: Saeed Mahameed <saeedm@mellanox.com>
---
 include/net/tls.h             |  70 +++-
 net/tls/Kconfig               |  10 +
 net/tls/Makefile              |   2 +
 net/tls/tls_device.c          | 804 ++++++++++++++++++++++++++++++++++++++++++
 net/tls/tls_device_fallback.c | 419 ++++++++++++++++++++++
 net/tls/tls_main.c            |  33 +-
 6 files changed, 1331 insertions(+), 7 deletions(-)
 create mode 100644 net/tls/tls_device.c
 create mode 100644 net/tls/tls_device_fallback.c

Comments

Kirill Tkhai March 21, 2018, 11:15 a.m. UTC | #1
On 20.03.2018 05:45, Saeed Mahameed wrote:
> From: Ilya Lesokhin <ilyal@mellanox.com>
> 
> This patch adds a generic infrastructure to offload TLS crypto to a
> network devices. It enables the kernel TLS socket to skip encryption
> and authentication operations on the transmit side of the data path.
> Leaving those computationally expensive operations to the NIC.
> 
> The NIC offload infrastructure builds TLS records and pushes them to
> the TCP layer just like the SW KTLS implementation and using the same API.
> TCP segmentation is mostly unaffected. Currently the only exception is
> that we prevent mixed SKBs where only part of the payload requires
> offload. In the future we are likely to add a similar restriction
> following a change cipher spec record.
> 
> The notable differences between SW KTLS and NIC offloaded TLS
> implementations are as follows:
> 1. The offloaded implementation builds "plaintext TLS record", those
> records contain plaintext instead of ciphertext and place holder bytes
> instead of authentication tags.
> 2. The offloaded implementation maintains a mapping from TCP sequence
> number to TLS records. Thus given a TCP SKB sent from a NIC offloaded
> TLS socket, we can use the tls NIC offload infrastructure to obtain
> enough context to encrypt the payload of the SKB.
> A TLS record is released when the last byte of the record is ack'ed,
> this is done through the new icsk_clean_acked callback.
> 
> The infrastructure should be extendable to support various NIC offload
> implementations.  However it is currently written with the
> implementation below in mind:
> The NIC assumes that packets from each offloaded stream are sent as
> plaintext and in-order. It keeps track of the TLS records in the TCP
> stream. When a packet marked for offload is transmitted, the NIC
> encrypts the payload in-place and puts authentication tags in the
> relevant place holders.
> 
> The responsibility for handling out-of-order packets (i.e. TCP
> retransmission, qdisc drops) falls on the netdev driver.
> 
> The netdev driver keeps track of the expected TCP SN from the NIC's
> perspective.  If the next packet to transmit matches the expected TCP
> SN, the driver advances the expected TCP SN, and transmits the packet
> with TLS offload indication.
> 
> If the next packet to transmit does not match the expected TCP SN. The
> driver calls the TLS layer to obtain the TLS record that includes the
> TCP of the packet for transmission. Using this TLS record, the driver
> posts a work entry on the transmit queue to reconstruct the NIC TLS
> state required for the offload of the out-of-order packet. It updates
> the expected TCP SN accordingly and transmit the now in-order packet.
> The same queue is used for packet transmission and TLS context
> reconstruction to avoid the need for flushing the transmit queue before
> issuing the context reconstruction request.
> 
> Signed-off-by: Ilya Lesokhin <ilyal@mellanox.com>
> Signed-off-by: Boris Pismenny <borisp@mellanox.com>
> Signed-off-by: Aviad Yehezkel <aviadye@mellanox.com>
> Signed-off-by: Saeed Mahameed <saeedm@mellanox.com>
> ---
>  include/net/tls.h             |  70 +++-
>  net/tls/Kconfig               |  10 +
>  net/tls/Makefile              |   2 +
>  net/tls/tls_device.c          | 804 ++++++++++++++++++++++++++++++++++++++++++
>  net/tls/tls_device_fallback.c | 419 ++++++++++++++++++++++
>  net/tls/tls_main.c            |  33 +-
>  6 files changed, 1331 insertions(+), 7 deletions(-)
>  create mode 100644 net/tls/tls_device.c
>  create mode 100644 net/tls/tls_device_fallback.c
> 
> diff --git a/include/net/tls.h b/include/net/tls.h
> index 4913430ab807..ab98a6dc4929 100644
> --- a/include/net/tls.h
> +++ b/include/net/tls.h
> @@ -77,6 +77,37 @@ struct tls_sw_context {
>  	struct scatterlist sg_aead_out[2];
>  };
>  
> +struct tls_record_info {
> +	struct list_head list;
> +	u32 end_seq;
> +	int len;
> +	int num_frags;
> +	skb_frag_t frags[MAX_SKB_FRAGS];
> +};
> +
> +struct tls_offload_context {
> +	struct crypto_aead *aead_send;
> +	spinlock_t lock;	/* protects records list */
> +	struct list_head records_list;
> +	struct tls_record_info *open_record;
> +	struct tls_record_info *retransmit_hint;
> +	u64 hint_record_sn;
> +	u64 unacked_record_sn;
> +
> +	struct scatterlist sg_tx_data[MAX_SKB_FRAGS];
> +	void (*sk_destruct)(struct sock *sk);
> +	u8 driver_state[];
> +	/* The TLS layer reserves room for driver specific state
> +	 * Currently the belief is that there is not enough
> +	 * driver specific state to justify another layer of indirection
> +	 */
> +#define TLS_DRIVER_STATE_SIZE (max_t(size_t, 8, sizeof(void *)))
> +};
> +
> +#define TLS_OFFLOAD_CONTEXT_SIZE                                               \
> +	(ALIGN(sizeof(struct tls_offload_context), sizeof(void *)) +           \
> +	 TLS_DRIVER_STATE_SIZE)
> +
>  enum {
>  	TLS_PENDING_CLOSED_RECORD
>  };
> @@ -87,6 +118,10 @@ struct tls_context {
>  		struct tls12_crypto_info_aes_gcm_128 crypto_send_aes_gcm_128;
>  	};
>  
> +	struct list_head list;
> +	struct net_device *netdev;
> +	refcount_t refcount;
> +
>  	void *priv_ctx;
>  
>  	u8 tx_conf:2;
> @@ -131,9 +166,29 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
>  void tls_sw_close(struct sock *sk, long timeout);
>  void tls_sw_free_tx_resources(struct sock *sk);
>  
> -void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
> -void tls_icsk_clean_acked(struct sock *sk);
> +void tls_clear_device_offload(struct sock *sk, struct tls_context *ctx);
> +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
> +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
> +int tls_device_sendpage(struct sock *sk, struct page *page,
> +			int offset, size_t size, int flags);
> +void tls_device_sk_destruct(struct sock *sk);
> +void tls_device_init(void);
> +void tls_device_cleanup(void);
>  
> +struct tls_record_info *tls_get_record(struct tls_offload_context *context,
> +				       u32 seq, u64 *p_record_sn);
> +
> +static inline bool tls_record_is_start_marker(struct tls_record_info *rec)
> +{
> +	return rec->len == 0;
> +}
> +
> +static inline u32 tls_record_start_seq(struct tls_record_info *rec)
> +{
> +	return rec->end_seq - rec->len;
> +}
> +
> +void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
>  int tls_push_sg(struct sock *sk, struct tls_context *ctx,
>  		struct scatterlist *sg, u16 first_offset,
>  		int flags);
> @@ -170,6 +225,13 @@ static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
>  	return tls_ctx->pending_open_record_frags;
>  }
>  
> +static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
> +{
> +	return sk_fullsock(sk) &&
> +	       /* matches smp_store_release in tls_set_device_offload */
> +	       smp_load_acquire(&sk->sk_destruct) == &tls_device_sk_destruct;
> +}
> +
>  static inline void tls_err_abort(struct sock *sk)
>  {
>  	sk->sk_err = EBADMSG;
> @@ -257,4 +319,8 @@ static inline struct tls_offload_context *tls_offload_ctx(
>  int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
>  		      unsigned char *record_type);
>  
> +int tls_sw_fallback_init(struct sock *sk,
> +			 struct tls_offload_context *offload_ctx,
> +			 struct tls_crypto_info *crypto_info);
> +
>  #endif /* _TLS_OFFLOAD_H */
> diff --git a/net/tls/Kconfig b/net/tls/Kconfig
> index eb583038c67e..9d3ef820bb16 100644
> --- a/net/tls/Kconfig
> +++ b/net/tls/Kconfig
> @@ -13,3 +13,13 @@ config TLS
>  	encryption handling of the TLS protocol to be done in-kernel.
>  
>  	If unsure, say N.
> +
> +config TLS_DEVICE
> +	bool "Transport Layer Security HW offload"
> +	depends on TLS
> +	select SOCK_VALIDATE_XMIT
> +	default n
> +	---help---
> +	Enable kernel support for HW offload of the TLS protocol.
> +
> +	If unsure, say N.
> diff --git a/net/tls/Makefile b/net/tls/Makefile
> index a930fd1c4f7b..4d6b728a67d0 100644
> --- a/net/tls/Makefile
> +++ b/net/tls/Makefile
> @@ -5,3 +5,5 @@
>  obj-$(CONFIG_TLS) += tls.o
>  
>  tls-y := tls_main.o tls_sw.o
> +
> +tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o
> diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
> new file mode 100644
> index 000000000000..c0d4e11a4286
> --- /dev/null
> +++ b/net/tls/tls_device.c
> @@ -0,0 +1,804 @@
> +/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
> + *
> + *     Redistribution and use in source and binary forms, with or
> + *     without modification, are permitted provided that the following
> + *     conditions are met:
> + *
> + *      - Redistributions of source code must retain the above
> + *        copyright notice, this list of conditions and the following
> + *        disclaimer.
> + *
> + *      - Redistributions in binary form must reproduce the above
> + *        copyright notice, this list of conditions and the following
> + *        disclaimer in the documentation and/or other materials
> + *        provided with the distribution.
> + *
> + *      - Neither the name of the Mellanox Technologies nor the
> + *        names of its contributors may be used to endorse or promote
> + *        products derived from this software without specific prior written
> + *        permission.
> + *
> + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
> + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
> + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
> + * A PARTICULAR PURPOSE ARE DISCLAIMED.
> + * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
> + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
> + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
> + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
> + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
> + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
> + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
> + * POSSIBILITY OF SUCH DAMAGE
> + */

Other patches have two licenses in header. Can I distribute this file under GPL license terms?

> +#include <linux/module.h>
> +#include <net/tcp.h>
> +#include <net/inet_common.h>
> +#include <linux/highmem.h>
> +#include <linux/netdevice.h>
> +
> +#include <net/tls.h>
> +#include <crypto/aead.h>
> +
> +/* device_offload_lock is used to synchronize tls_dev_add
> + * against NETDEV_DOWN notifications.
> + */
> +DEFINE_STATIC_PERCPU_RWSEM(device_offload_lock);
> +
> +static void tls_device_gc_task(struct work_struct *work);
> +
> +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
> +static LIST_HEAD(tls_device_gc_list);
> +static LIST_HEAD(tls_device_list);
> +static DEFINE_SPINLOCK(tls_device_lock);
> +
> +static void tls_device_free_ctx(struct tls_context *ctx)
> +{
> +	struct tls_offload_context *offlad_ctx = tls_offload_ctx(ctx);
> +
> +	kfree(offlad_ctx);
> +	kfree(ctx);
> +}
> +
> +static void tls_device_gc_task(struct work_struct *work)
> +{
> +	struct tls_context *ctx, *tmp;
> +	struct list_head gc_list;
> +	unsigned long flags;
> +
> +	spin_lock_irqsave(&tls_device_lock, flags);
> +	INIT_LIST_HEAD(&gc_list);

This is stack variable, and it should be initialized outside of global spinlock.
There is LIST_HEAD() primitive for that in kernel.
There is one more similar place below.

> +	list_splice_init(&tls_device_gc_list, &gc_list);
> +	spin_unlock_irqrestore(&tls_device_lock, flags);
> +
> +	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
> +		struct net_device *netdev = ctx->netdev;
> +
> +		if (netdev) {
> +			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
> +							TLS_OFFLOAD_CTX_DIR_TX);
> +			dev_put(netdev);
> +		}

How is possible the situation we meet NULL netdev here?

> +
> +		list_del(&ctx->list);
> +		tls_device_free_ctx(ctx);
> +	}
> +}
> +
> +static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
> +{
> +	unsigned long flags;
> +
> +	spin_lock_irqsave(&tls_device_lock, flags);
> +	list_move_tail(&ctx->list, &tls_device_gc_list);
> +
> +	/* schedule_work inside the spinlock
> +	 * to make sure tls_device_down waits for that work.
> +	 */
> +	schedule_work(&tls_device_gc_work);
> +
> +	spin_unlock_irqrestore(&tls_device_lock, flags);
> +}
> +
> +/* We assume that the socket is already connected */
> +static struct net_device *get_netdev_for_sock(struct sock *sk)
> +{
> +	struct inet_sock *inet = inet_sk(sk);
> +	struct net_device *netdev = NULL;
> +
> +	netdev = dev_get_by_index(sock_net(sk), inet->cork.fl.flowi_oif);
> +
> +	return netdev;
> +}
> +
> +static int attach_sock_to_netdev(struct sock *sk, struct net_device *netdev,
> +				 struct tls_context *ctx)
> +{
> +	int rc;
> +
> +	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
> +					     &ctx->crypto_send,
> +					     tcp_sk(sk)->write_seq);
> +	if (rc) {
> +		pr_err_ratelimited("The netdev has refused to offload this socket\n");
> +		goto out;
> +	}
> +
> +	rc = 0;
> +out:
> +	return rc;
> +}
> +
> +static void destroy_record(struct tls_record_info *record)
> +{
> +	skb_frag_t *frag;
> +	int nr_frags = record->num_frags;
> +
> +	while (nr_frags > 0) {
> +		frag = &record->frags[nr_frags - 1];
> +		__skb_frag_unref(frag);
> +		--nr_frags;
> +	}
> +	kfree(record);
> +}
> +
> +static void delete_all_records(struct tls_offload_context *offload_ctx)
> +{
> +	struct tls_record_info *info, *temp;
> +
> +	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
> +		list_del(&info->list);
> +		destroy_record(info);
> +	}
> +
> +	offload_ctx->retransmit_hint = NULL;
> +}
> +
> +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
> +{
> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
> +	struct tls_offload_context *ctx;
> +	struct tls_record_info *info, *temp;
> +	unsigned long flags;
> +	u64 deleted_records = 0;
> +
> +	if (!tls_ctx)
> +		return;
> +
> +	ctx = tls_offload_ctx(tls_ctx);
> +
> +	spin_lock_irqsave(&ctx->lock, flags);
> +	info = ctx->retransmit_hint;
> +	if (info && !before(acked_seq, info->end_seq)) {
> +		ctx->retransmit_hint = NULL;
> +		list_del(&info->list);
> +		destroy_record(info);
> +		deleted_records++;
> +	}
> +
> +	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
> +		if (before(acked_seq, info->end_seq))
> +			break;
> +		list_del(&info->list);
> +
> +		destroy_record(info);
> +		deleted_records++;
> +	}
> +
> +	ctx->unacked_record_sn += deleted_records;
> +	spin_unlock_irqrestore(&ctx->lock, flags);
> +}
> +
> +/* At this point, there should be no references on this
> + * socket and no in-flight SKBs associated with this
> + * socket, so it is safe to free all the resources.
> + */
> +void tls_device_sk_destruct(struct sock *sk)
> +{
> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
> +	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
> +
> +	if (ctx->open_record)
> +		destroy_record(ctx->open_record);
> +
> +	delete_all_records(ctx);
> +	crypto_free_aead(ctx->aead_send);
> +	ctx->sk_destruct(sk);
> +
> +	if (refcount_dec_and_test(&tls_ctx->refcount))
> +		tls_device_queue_ctx_destruction(tls_ctx);
> +}
> +EXPORT_SYMBOL(tls_device_sk_destruct);
> +
> +static inline void tls_append_frag(struct tls_record_info *record,
> +				   struct page_frag *pfrag,
> +				   int size)
> +{
> +	skb_frag_t *frag;
> +
> +	frag = &record->frags[record->num_frags - 1];
> +	if (frag->page.p == pfrag->page &&
> +	    frag->page_offset + frag->size == pfrag->offset) {
> +		frag->size += size;
> +	} else {
> +		++frag;
> +		frag->page.p = pfrag->page;
> +		frag->page_offset = pfrag->offset;
> +		frag->size = size;
> +		++record->num_frags;
> +		get_page(pfrag->page);
> +	}
> +
> +	pfrag->offset += size;
> +	record->len += size;
> +}
> +
> +static inline int tls_push_record(struct sock *sk,
> +				  struct tls_context *ctx,
> +				  struct tls_offload_context *offload_ctx,
> +				  struct tls_record_info *record,
> +				  struct page_frag *pfrag,
> +				  int flags,
> +				  unsigned char record_type)
> +{
> +	skb_frag_t *frag;
> +	struct tcp_sock *tp = tcp_sk(sk);
> +	struct page_frag fallback_frag;
> +	struct page_frag  *tag_pfrag = pfrag;
> +	int i;
> +
> +	/* fill prepand */
> +	frag = &record->frags[0];
> +	tls_fill_prepend(ctx,
> +			 skb_frag_address(frag),
> +			 record->len - ctx->prepend_size,
> +			 record_type);
> +
> +	if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag, GFP_KERNEL))) {
> +		/* HW doesn't care about the data in the tag
> +		 * so in case pfrag has no room
> +		 * for a tag and we can't allocate a new pfrag
> +		 * just use the page in the first frag
> +		 * rather then write a complicated fall back code.
> +		 */
> +		tag_pfrag = &fallback_frag;
> +		tag_pfrag->page = skb_frag_page(frag);
> +		tag_pfrag->offset = 0;
> +	}
> +
> +	tls_append_frag(record, tag_pfrag, ctx->tag_size);
> +	record->end_seq = tp->write_seq + record->len;
> +	spin_lock_irq(&offload_ctx->lock);
> +	list_add_tail(&record->list, &offload_ctx->records_list);
> +	spin_unlock_irq(&offload_ctx->lock);
> +	offload_ctx->open_record = NULL;
> +	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
> +	tls_advance_record_sn(sk, ctx);
> +
> +	for (i = 0; i < record->num_frags; i++) {
> +		frag = &record->frags[i];
> +		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
> +		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
> +			    frag->size, frag->page_offset);
> +		sk_mem_charge(sk, frag->size);
> +		get_page(skb_frag_page(frag));
> +	}
> +	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
> +
> +	/* all ready, send */
> +	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
> +}
> +
> +static inline int tls_create_new_record(struct tls_offload_context *offload_ctx,
> +					struct page_frag *pfrag,
> +					size_t prepend_size)
> +{
> +	skb_frag_t *frag;
> +	struct tls_record_info *record;
> +
> +	record = kmalloc(sizeof(*record), GFP_KERNEL);
> +	if (!record)
> +		return -ENOMEM;
> +
> +	frag = &record->frags[0];
> +	__skb_frag_set_page(frag, pfrag->page);
> +	frag->page_offset = pfrag->offset;
> +	skb_frag_size_set(frag, prepend_size);
> +
> +	get_page(pfrag->page);
> +	pfrag->offset += prepend_size;
> +
> +	record->num_frags = 1;
> +	record->len = prepend_size;
> +	offload_ctx->open_record = record;
> +	return 0;
> +}
> +
> +static inline int tls_do_allocation(struct sock *sk,
> +				    struct tls_offload_context *offload_ctx,
> +				    struct page_frag *pfrag,
> +				    size_t prepend_size)
> +{
> +	int ret;
> +
> +	if (!offload_ctx->open_record) {
> +		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
> +						   sk->sk_allocation))) {
> +			sk->sk_prot->enter_memory_pressure(sk);
> +			sk_stream_moderate_sndbuf(sk);
> +			return -ENOMEM;
> +		}
> +
> +		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
> +		if (ret)
> +			return ret;
> +
> +		if (pfrag->size > pfrag->offset)
> +			return 0;
> +	}
> +
> +	if (!sk_page_frag_refill(sk, pfrag))
> +		return -ENOMEM;
> +
> +	return 0;
> +}
> +
> +static int tls_push_data(struct sock *sk,
> +			 struct iov_iter *msg_iter,
> +			 size_t size, int flags,
> +			 unsigned char record_type)
> +{
> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
> +	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
> +	struct tls_record_info *record = ctx->open_record;
> +	struct page_frag *pfrag;
> +	int copy, rc = 0;
> +	size_t orig_size = size;
> +	u32 max_open_record_len;
> +	long timeo;
> +	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
> +	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
> +	bool done = false;
> +
> +	if (flags &
> +	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
> +		return -ENOTSUPP;
> +
> +	if (sk->sk_err)
> +		return -sk->sk_err;
> +
> +	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
> +	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
> +	if (rc < 0)
> +		return rc;
> +
> +	pfrag = sk_page_frag(sk);
> +
> +	/* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and
> +	 * we need to leave room for an authentication tag.
> +	 */
> +	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
> +			      tls_ctx->prepend_size;
> +	do {
> +		if (tls_do_allocation(sk, ctx, pfrag,
> +				      tls_ctx->prepend_size)) {
> +			rc = sk_stream_wait_memory(sk, &timeo);
> +			if (!rc)
> +				continue;
> +
> +			record = ctx->open_record;
> +			if (!record)
> +				break;
> +handle_error:
> +			if (record_type != TLS_RECORD_TYPE_DATA) {
> +				/* avoid sending partial
> +				 * record with type !=
> +				 * application_data
> +				 */
> +				size = orig_size;
> +				destroy_record(record);
> +				ctx->open_record = NULL;
> +			} else if (record->len > tls_ctx->prepend_size) {
> +				goto last_record;
> +			}
> +
> +			break;
> +		}
> +
> +		record = ctx->open_record;
> +		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
> +		copy = min_t(size_t, copy, (max_open_record_len - record->len));
> +
> +		if (copy_from_iter_nocache(page_address(pfrag->page) +
> +					       pfrag->offset,
> +					   copy, msg_iter) != copy) {
> +			rc = -EFAULT;
> +			goto handle_error;
> +		}
> +		tls_append_frag(record, pfrag, copy);
> +
> +		size -= copy;
> +		if (!size) {
> +last_record:
> +			tls_push_record_flags = flags;
> +			if (more) {
> +				tls_ctx->pending_open_record_frags =
> +						record->num_frags;
> +				break;
> +			}
> +
> +			done = true;
> +		}
> +
> +		if ((done) || record->len >= max_open_record_len ||
> +		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
> +			rc = tls_push_record(sk,
> +					     tls_ctx,
> +					     ctx,
> +					     record,
> +					     pfrag,
> +					     tls_push_record_flags,
> +					     record_type);
> +			if (rc < 0)
> +				break;
> +		}
> +	} while (!done);
> +
> +	if (orig_size - size > 0)
> +		rc = orig_size - size;
> +
> +	return rc;
> +}
> +
> +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
> +{
> +	unsigned char record_type = TLS_RECORD_TYPE_DATA;
> +	int rc = 0;
> +
> +	lock_sock(sk);
> +
> +	if (unlikely(msg->msg_controllen)) {
> +		rc = tls_proccess_cmsg(sk, msg, &record_type);
> +		if (rc)
> +			goto out;
> +	}
> +
> +	rc = tls_push_data(sk, &msg->msg_iter, size,
> +			   msg->msg_flags, record_type);
> +
> +out:
> +	release_sock(sk);
> +	return rc;
> +}
> +
> +int tls_device_sendpage(struct sock *sk, struct page *page,
> +			int offset, size_t size, int flags)
> +{
> +	struct iov_iter	msg_iter;
> +	struct kvec iov;
> +	char *kaddr = kmap(page);
> +	int rc = 0;
> +
> +	if (flags & MSG_SENDPAGE_NOTLAST)
> +		flags |= MSG_MORE;
> +
> +	lock_sock(sk);
> +
> +	if (flags & MSG_OOB) {
> +		rc = -ENOTSUPP;
> +		goto out;
> +	}
> +
> +	iov.iov_base = kaddr + offset;
> +	iov.iov_len = size;
> +	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
> +	rc = tls_push_data(sk, &msg_iter, size,
> +			   flags, TLS_RECORD_TYPE_DATA);
> +	kunmap(page);
> +
> +out:
> +	release_sock(sk);
> +	return rc;
> +}
> +
> +struct tls_record_info *tls_get_record(struct tls_offload_context *context,
> +				       u32 seq, u64 *p_record_sn)
> +{
> +	struct tls_record_info *info;
> +	u64 record_sn = context->hint_record_sn;
> +
> +	info = context->retransmit_hint;
> +	if (!info ||
> +	    before(seq, info->end_seq - info->len)) {
> +		/* if retransmit_hint is irrelevant start
> +		 * from the begging of the list
> +		 */
> +		info = list_first_entry(&context->records_list,
> +					struct tls_record_info, list);
> +		record_sn = context->unacked_record_sn;
> +	}
> +
> +	list_for_each_entry_from(info, &context->records_list, list) {
> +		if (before(seq, info->end_seq)) {
> +			if (!context->retransmit_hint ||
> +			    after(info->end_seq,
> +				  context->retransmit_hint->end_seq)) {
> +				context->hint_record_sn = record_sn;
> +				context->retransmit_hint = info;
> +			}
> +			*p_record_sn = record_sn;
> +			return info;
> +		}
> +		record_sn++;
> +	}
> +
> +	return NULL;
> +}
> +EXPORT_SYMBOL(tls_get_record);
> +
> +static int tls_device_push_pending_record(struct sock *sk, int flags)
> +{
> +	struct iov_iter	msg_iter;
> +
> +	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
> +	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
> +}
> +
> +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
> +{
> +	u16 nonece_size, tag_size, iv_size, rec_seq_size;
> +	struct tls_record_info *start_marker_record;
> +	struct tls_offload_context *offload_ctx;
> +	struct tls_crypto_info *crypto_info;
> +	struct net_device *netdev;
> +	char *iv, *rec_seq;
> +	struct sk_buff *skb;
> +	int rc = -EINVAL;
> +	__be64 rcd_sn;
> +
> +	if (!ctx)
> +		goto out;
> +
> +	if (ctx->priv_ctx) {
> +		rc = -EEXIST;
> +		goto out;
> +	}
> +
> +	/* We support starting offload on multiple sockets
> +	 * concurrently, So we only need a read lock here.
> +	 */
> +	percpu_down_read(&device_offload_lock);
> +	netdev = get_netdev_for_sock(sk);
> +	if (!netdev) {
> +		pr_err_ratelimited("%s: netdev not found\n", __func__);
> +		rc = -EINVAL;
> +		goto release_lock;
> +	}
> +
> +	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
> +		rc = -ENOTSUPP;
> +		goto release_netdev;
> +	}
> +
> +	/* Avoid offloading if the device is down
> +	 * We don't want to offload new flows after
> +	 * the NETDEV_DOWN event
> +	 */
> +	if (!(netdev->flags & IFF_UP)) {
> +		rc = -EINVAL;
> +		goto release_lock;
> +	}
> +
> +	crypto_info = &ctx->crypto_send;
> +	switch (crypto_info->cipher_type) {
> +	case TLS_CIPHER_AES_GCM_128: {
> +		nonece_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
> +		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
> +		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
> +		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
> +		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
> +		rec_seq =
> +		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
> +		break;
> +	}
> +	default:
> +		rc = -EINVAL;
> +		goto release_netdev;
> +	}
> +
> +	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);

Can we memory allocations and simple memory initializations ouside the global rwsem?

> +	if (!start_marker_record) {
> +		rc = -ENOMEM;
> +		goto release_netdev;
> +	}
> +
> +	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
> +	if (!offload_ctx)
> +		goto free_marker_record;
> +
> +	ctx->priv_ctx = offload_ctx;
> +	rc = attach_sock_to_netdev(sk, netdev, ctx);
> +	if (rc)
> +		goto free_offload_context;
> +
> +	ctx->netdev = netdev;
> +	ctx->prepend_size = TLS_HEADER_SIZE + nonece_size;
> +	ctx->tag_size = tag_size;
> +	ctx->iv_size = iv_size;
> +	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
> +			  GFP_KERNEL);
> +	if (!ctx->iv) {
> +		rc = -ENOMEM;
> +		goto detach_sock;
> +	}
> +
> +	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
> +
> +	ctx->rec_seq_size = rec_seq_size;
> +	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
> +	if (!ctx->rec_seq) {
> +		rc = -ENOMEM;
> +		goto free_iv;
> +	}
> +	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
> +
> +	/* start at rec_seq - 1 to account for the start marker record */
> +	memcpy(&rcd_sn, ctx->rec_seq, sizeof(rcd_sn));
> +	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
> +
> +	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
> +	if (rc)
> +		goto free_rec_seq;
> +
> +	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
> +	start_marker_record->len = 0;
> +	start_marker_record->num_frags = 0;
> +
> +	INIT_LIST_HEAD(&offload_ctx->records_list);
> +	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
> +	spin_lock_init(&offload_ctx->lock);
> +
> +	inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked;
> +	ctx->push_pending_record = tls_device_push_pending_record;
> +	offload_ctx->sk_destruct = sk->sk_destruct;
> +
> +	/* TLS offload is greatly simplified if we don't send
> +	 * SKBs where only part of the payload needs to be encrypted.
> +	 * So mark the last skb in the write queue as end of record.
> +	 */
> +	skb = tcp_write_queue_tail(sk);
> +	if (skb)
> +		TCP_SKB_CB(skb)->eor = 1;
> +
> +	refcount_set(&ctx->refcount, 1);
> +	spin_lock_irq(&tls_device_lock);
> +	list_add_tail(&ctx->list, &tls_device_list);
> +	spin_unlock_irq(&tls_device_lock);
> +
> +	/* following this assignment tls_is_sk_tx_device_offloaded
> +	 * will return true and the context might be accessed
> +	 * by the netdev's xmit function.
> +	 */
> +	smp_store_release(&sk->sk_destruct,
> +			  &tls_device_sk_destruct);
> +	goto release_lock;
> +
> +free_rec_seq:
> +	kfree(ctx->rec_seq);
> +free_iv:
> +	kfree(ctx->iv);
> +detach_sock:
> +	netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
> +free_offload_context:
> +	kfree(offload_ctx);
> +	ctx->priv_ctx = NULL;
> +free_marker_record:
> +	kfree(start_marker_record);
> +release_netdev:
> +	dev_put(netdev);
> +release_lock:
> +	percpu_up_read(&device_offload_lock);
> +out:
> +	return rc;
> +}
> +
> +static int tls_device_register(struct net_device *dev)
> +{
> +	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
> +		return NOTIFY_BAD;
> +
> +	return NOTIFY_DONE;
> +}

This function is the same as tls_device_feat_change(). Can't we merge
them together and avoid duplicating of code?

> +static int tls_device_unregister(struct net_device *dev)
> +{
> +	return NOTIFY_DONE;
> +}

This function does nothing, and next patches do not change it.
Can't we remove it since so?

> +static int tls_device_feat_change(struct net_device *dev)
> +{
> +	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
> +		return NOTIFY_BAD;
> +
> +	return NOTIFY_DONE;
> +}
> +
> +static int tls_device_down(struct net_device *netdev)
> +{
> +	struct tls_context *ctx, *tmp;
> +	struct list_head list;
> +	unsigned long flags;
> +
> +	if (!(netdev->features & NETIF_F_HW_TLS_TX))
> +		return NOTIFY_DONE;

Can't we move this check in tls_dev_event() and use it for all types of events?
Then we avoid duplicate code.

> +
> +	/* Request a write lock to block new offload attempts
> +	 */
> +	percpu_down_write(&device_offload_lock);

What is the reason percpu_rwsem is chosen here? It looks like this primitive
gives more advantages readers, then plain rwsem does. But it also gives
disadvantages to writers. It would be good, unless tls_device_down() is called
with rtnl_lock() held from netdevice notifier. But since netdevice notifier
are called with rtnl_lock() held, percpu_rwsem will increase the time rtnl_lock()
is locked.

Can't we use plain rwsem here instead?

> +
> +	spin_lock_irqsave(&tls_device_lock, flags);
> +	INIT_LIST_HEAD(&list);

This may go outside the global spinlock.

> +	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
> +		if (ctx->netdev != netdev ||
> +		    !refcount_inc_not_zero(&ctx->refcount))
> +			continue;
> +
> +		list_move(&ctx->list, &list);
> +	}
> +	spin_unlock_irqrestore(&tls_device_lock, flags);
> +
> +	list_for_each_entry_safe(ctx, tmp, &list, list)	{
> +		netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
> +						TLS_OFFLOAD_CTX_DIR_TX);
> +		ctx->netdev = NULL;
> +		dev_put(netdev);
> +		list_del_init(&ctx->list);
> +
> +		if (refcount_dec_and_test(&ctx->refcount))
> +			tls_device_free_ctx(ctx);
> +	}
> +
> +	percpu_up_write(&device_offload_lock);
> +
> +	flush_work(&tls_device_gc_work);
> +
> +	return NOTIFY_DONE;
> +}
> +
> +static int tls_dev_event(struct notifier_block *this, unsigned long event,
> +			 void *ptr)
> +{
> +	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
> +
> +	switch (event) {
> +	case NETDEV_REGISTER:
> +		return tls_device_register(dev);
> +
> +	case NETDEV_UNREGISTER:
> +		return tls_device_unregister(dev);
> +
> +	case NETDEV_FEAT_CHANGE:
> +		return tls_device_feat_change(dev);
> +
> +	case NETDEV_DOWN:
> +		return tls_device_down(dev);
> +	}
> +	return NOTIFY_DONE;
> +}
> +
> +static struct notifier_block tls_dev_notifier = {
> +	.notifier_call	= tls_dev_event,
> +};
> +
> +void __init tls_device_init(void)
> +{
> +	register_netdevice_notifier(&tls_dev_notifier);
> +}
> +
> +void __exit tls_device_cleanup(void)
> +{
> +	unregister_netdevice_notifier(&tls_dev_notifier);
> +	flush_work(&tls_device_gc_work);
> +}
> diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
> new file mode 100644
> index 000000000000..14d31a36885c
> --- /dev/null
> +++ b/net/tls/tls_device_fallback.c
> @@ -0,0 +1,419 @@
> +/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
> + *
> + *     Redistribution and use in source and binary forms, with or
> + *     without modification, are permitted provided that the following
> + *     conditions are met:
> + *
> + *      - Redistributions of source code must retain the above
> + *        copyright notice, this list of conditions and the following
> + *        disclaimer.
> + *
> + *      - Redistributions in binary form must reproduce the above
> + *        copyright notice, this list of conditions and the following
> + *        disclaimer in the documentation and/or other materials
> + *        provided with the distribution.
> + *
> + *      - Neither the name of the Mellanox Technologies nor the
> + *        names of its contributors may be used to endorse or promote
> + *        products derived from this software without specific prior written
> + *        permission.
> + *
> + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
> + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
> + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
> + * A PARTICULAR PURPOSE ARE DISCLAIMED.
> + * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
> + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
> + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
> + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
> + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
> + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
> + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
> + * POSSIBILITY OF SUCH DAMAGE
> + */
> +
> +#include <net/tls.h>
> +#include <crypto/aead.h>
> +#include <crypto/scatterwalk.h>
> +#include <net/ip6_checksum.h>
> +
> +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
> +{
> +	struct scatterlist *src = walk->sg;
> +	int diff = walk->offset - src->offset;
> +
> +	sg_set_page(sg, sg_page(src),
> +		    src->length - diff, walk->offset);
> +
> +	scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
> +}
> +
> +static int tls_enc_record(struct aead_request *aead_req,
> +			  struct crypto_aead *aead, char *aad, char *iv,
> +			  __be64 rcd_sn, struct scatter_walk *in,
> +			  struct scatter_walk *out, int *in_len)
> +{
> +	struct scatterlist sg_in[3];
> +	struct scatterlist sg_out[3];
> +	unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
> +	u16 len;
> +	int rc;
> +
> +	len = min_t(int, *in_len, ARRAY_SIZE(buf));
> +
> +	scatterwalk_copychunks(buf, in, len, 0);
> +	scatterwalk_copychunks(buf, out, len, 1);
> +
> +	*in_len -= len;
> +	if (!*in_len)
> +		return 0;
> +
> +	scatterwalk_pagedone(in, 0, 1);
> +	scatterwalk_pagedone(out, 1, 1);
> +
> +	len = buf[4] | (buf[3] << 8);
> +	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
> +
> +	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
> +		     (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
> +
> +	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
> +	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
> +
> +	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
> +	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
> +	sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
> +	sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
> +	chain_to_walk(sg_in + 1, in);
> +	chain_to_walk(sg_out + 1, out);
> +
> +	*in_len -= len;
> +	if (*in_len < 0) {
> +		*in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
> +		if (*in_len < 0)
> +		/* the input buffer doesn't contain the entire record.
> +		 * trim len accordingly. The resulting authentication tag
> +		 * will contain garbage. but we don't care as we won't
> +		 * include any of it in the output skb
> +		 * Note that we assume the output buffer length
> +		 * is larger then input buffer length + tag size
> +		 */
> +			len += *in_len;
> +
> +		*in_len = 0;
> +	}
> +
> +	if (*in_len) {
> +		scatterwalk_copychunks(NULL, in, len, 2);
> +		scatterwalk_pagedone(in, 0, 1);
> +		scatterwalk_copychunks(NULL, out, len, 2);
> +		scatterwalk_pagedone(out, 1, 1);
> +	}
> +
> +	len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
> +	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
> +
> +	rc = crypto_aead_encrypt(aead_req);
> +
> +	return rc;
> +}
> +
> +static void tls_init_aead_request(struct aead_request *aead_req,
> +				  struct crypto_aead *aead)
> +{
> +	aead_request_set_tfm(aead_req, aead);
> +	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
> +}
> +
> +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
> +						   gfp_t flags)
> +{
> +	unsigned int req_size = sizeof(struct aead_request) +
> +		crypto_aead_reqsize(aead);
> +	struct aead_request *aead_req;
> +
> +	aead_req = kzalloc(req_size, flags);
> +	if (!aead_req)
> +		return NULL;
> +
> +	tls_init_aead_request(aead_req, aead);
> +	return aead_req;
> +}
> +
> +static int tls_enc_records(struct aead_request *aead_req,
> +			   struct crypto_aead *aead, struct scatterlist *sg_in,
> +			   struct scatterlist *sg_out, char *aad, char *iv,
> +			   u64 rcd_sn, int len)
> +{
> +	struct scatter_walk in;
> +	struct scatter_walk out;
> +	int rc;
> +
> +	scatterwalk_start(&in, sg_in);
> +	scatterwalk_start(&out, sg_out);
> +
> +	do {
> +		rc = tls_enc_record(aead_req, aead, aad, iv,
> +				    cpu_to_be64(rcd_sn), &in, &out, &len);
> +		rcd_sn++;
> +
> +	} while (rc == 0 && len);
> +
> +	scatterwalk_done(&in, 0, 0);
> +	scatterwalk_done(&out, 1, 0);
> +
> +	return rc;
> +}
> +
> +static inline void update_chksum(struct sk_buff *skb, int headln)
> +{
> +	/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
> +	 * might have been changed by NAT.
> +	 */
> +
> +	const struct ipv6hdr *ipv6h;
> +	const struct iphdr *iph;
> +	struct tcphdr *th = tcp_hdr(skb);
> +	int datalen = skb->len - headln;
> +
> +	/* We only changed the payload so if we are using partial we don't
> +	 * need to update anything.
> +	 */
> +	if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
> +		return;
> +
> +	skb->ip_summed = CHECKSUM_PARTIAL;
> +	skb->csum_start = skb_transport_header(skb) - skb->head;
> +	skb->csum_offset = offsetof(struct tcphdr, check);
> +
> +	if (skb->sk->sk_family == AF_INET6) {
> +		ipv6h = ipv6_hdr(skb);
> +		th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
> +					     datalen, IPPROTO_TCP, 0);
> +	} else {
> +		iph = ip_hdr(skb);
> +		th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
> +					       IPPROTO_TCP, 0);
> +	}
> +}
> +
> +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
> +{
> +	skb_copy_header(nskb, skb);
> +
> +	skb_put(nskb, skb->len);
> +	memcpy(nskb->data, skb->data, headln);
> +	update_chksum(nskb, headln);
> +
> +	nskb->destructor = skb->destructor;
> +	nskb->sk = skb->sk;
> +	skb->destructor = NULL;
> +	skb->sk = NULL;
> +	refcount_add(nskb->truesize - skb->truesize,
> +		     &nskb->sk->sk_wmem_alloc);
> +}
> +
> +/* This function may be called after the user socket is already
> + * closed so make sure we don't use anything freed during
> + * tls_sk_proto_close here
> + */
> +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
> +{
> +	int tcp_header_size = tcp_hdrlen(skb);
> +	int tcp_payload_offset = skb_transport_offset(skb) + tcp_header_size;
> +	int payload_len = skb->len - tcp_payload_offset;
> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
> +	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
> +	int remaining, buf_len, resync_sgs, rc, i = 0;
> +	void *buf, *dummy_buf, *iv, *aad;
> +	struct scatterlist *sg_in;
> +	struct scatterlist sg_out[3];
> +	u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
> +	struct aead_request *aead_req;
> +	struct sk_buff *nskb = NULL;
> +	struct tls_record_info *record;
> +	unsigned long flags;
> +	s32 sync_size;
> +	u64 rcd_sn;
> +
> +	/* worst case is:
> +	 * MAX_SKB_FRAGS in tls_record_info
> +	 * MAX_SKB_FRAGS + 1 in SKB head an frags.
> +	 */
> +	int sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;
> +
> +	if (!payload_len)
> +		return skb;
> +
> +	sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
> +	if (!sg_in)
> +		goto free_orig;
> +
> +	sg_init_table(sg_in, sg_in_max_elements);
> +	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
> +
> +	spin_lock_irqsave(&ctx->lock, flags);
> +	record = tls_get_record(ctx, tcp_seq, &rcd_sn);
> +	if (!record) {
> +		spin_unlock_irqrestore(&ctx->lock, flags);
> +		WARN(1, "Record not found for seq %u\n", tcp_seq);
> +		goto free_sg;
> +	}
> +
> +	sync_size = tcp_seq - tls_record_start_seq(record);
> +	if (sync_size < 0) {
> +		int is_start_marker = tls_record_is_start_marker(record);
> +
> +		spin_unlock_irqrestore(&ctx->lock, flags);
> +		if (!is_start_marker)
> +		/* This should only occur if the relevant record was
> +		 * already acked. In that case it should be ok
> +		 * to drop the packet and avoid retransmission.
> +		 *
> +		 * There is a corner case where the packet contains
> +		 * both an acked and a non-acked record.
> +		 * We currently don't handle that case and rely
> +		 * on TCP to retranmit a packet that doesn't contain
> +		 * already acked payload.
> +		 */
> +			goto free_orig;
> +
> +		if (payload_len > -sync_size) {
> +			WARN(1, "Fallback of partially offloaded packets is not supported\n");
> +			goto free_sg;
> +		} else {
> +			return skb;
> +		}
> +	}
> +
> +	remaining = sync_size;
> +	while (remaining > 0) {
> +		skb_frag_t *frag = &record->frags[i];
> +
> +		__skb_frag_ref(frag);
> +		sg_set_page(sg_in + i, skb_frag_page(frag),
> +			    skb_frag_size(frag), frag->page_offset);
> +
> +		remaining -= skb_frag_size(frag);
> +
> +		if (remaining < 0)
> +			sg_in[i].length += remaining;
> +
> +		i++;
> +	}
> +	spin_unlock_irqrestore(&ctx->lock, flags);
> +	resync_sgs = i;
> +
> +	aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
> +	if (!aead_req)
> +		goto put_sg;
> +
> +	buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
> +		  TLS_CIPHER_AES_GCM_128_IV_SIZE +
> +		  TLS_AAD_SPACE_SIZE +
> +		  sync_size +
> +		  tls_ctx->tag_size;
> +	buf = kmalloc(buf_len, GFP_ATOMIC);
> +	if (!buf)
> +		goto free_req;
> +
> +	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
> +	if (!nskb)
> +		goto free_buf;
> +
> +	skb_reserve(nskb, skb_headroom(skb));
> +
> +	iv = buf;
> +
> +	memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
> +	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
> +	aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
> +	      TLS_CIPHER_AES_GCM_128_IV_SIZE;
> +	dummy_buf = aad + TLS_AAD_SPACE_SIZE;
> +
> +	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
> +	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset,
> +		   payload_len);
> +	/* Add room for authentication tag produced by crypto */
> +	dummy_buf += sync_size;
> +	sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size);
> +	rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset,
> +			  payload_len);
> +	if (rc < 0)
> +		goto free_nskb;
> +
> +	rc = tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
> +			     rcd_sn, sync_size + payload_len);
> +	if (rc < 0)
> +		goto free_nskb;
> +
> +	complete_skb(nskb, skb, tcp_payload_offset);
> +
> +	/* validate_xmit_skb_list assumes that if the skb wasn't segmented
> +	 * nskb->prev will point to the skb itself
> +	 */
> +	nskb->prev = nskb;
> +free_buf:
> +	kfree(buf);
> +free_req:
> +	kfree(aead_req);
> +put_sg:
> +	for (i = 0; i < resync_sgs; i++)
> +		put_page(sg_page(&sg_in[i]));
> +free_sg:
> +	kfree(sg_in);
> +free_orig:
> +	kfree_skb(skb);
> +	return nskb;
> +
> +free_nskb:
> +	kfree_skb(nskb);
> +	nskb = NULL;
> +	goto free_buf;
> +}
> +
> +static struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
> +					     struct net_device *dev,
> +					     struct sk_buff *skb)
> +{
> +	if (dev == tls_get_ctx(sk)->netdev)
> +		return skb;
> +
> +	return tls_sw_fallback(sk, skb);
> +}
> +
> +int tls_sw_fallback_init(struct sock *sk,
> +			 struct tls_offload_context *offload_ctx,
> +			 struct tls_crypto_info *crypto_info)
> +{
> +	int rc;
> +	const u8 *key;
> +
> +	offload_ctx->aead_send =
> +	    crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
> +	if (IS_ERR(offload_ctx->aead_send)) {
> +		rc = PTR_ERR(offload_ctx->aead_send);
> +		pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
> +		offload_ctx->aead_send = NULL;
> +		goto err_out;
> +	}
> +
> +	key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
> +
> +	rc = crypto_aead_setkey(offload_ctx->aead_send, key,
> +				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
> +	if (rc)
> +		goto free_aead;
> +
> +	rc = crypto_aead_setauthsize(offload_ctx->aead_send,
> +				     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
> +	if (rc)
> +		goto free_aead;
> +
> +	sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
> +	return 0;
> +free_aead:
> +	crypto_free_aead(offload_ctx->aead_send);
> +err_out:
> +	return rc;
> +}
> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
> index d824d548447e..e0dface33017 100644
> --- a/net/tls/tls_main.c
> +++ b/net/tls/tls_main.c
> @@ -54,6 +54,9 @@ enum {
>  enum {
>  	TLS_BASE_TX,
>  	TLS_SW_TX,
> +#ifdef CONFIG_TLS_DEVICE
> +	TLS_HW_TX,
> +#endif
>  	TLS_NUM_CONFIG,
>  };
>  
> @@ -416,11 +419,19 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
>  		goto err_crypto_info;
>  	}
>  
> -	/* currently SW is default, we will have ethtool in future */
> -	rc = tls_set_sw_offload(sk, ctx);
> -	tx_conf = TLS_SW_TX;
> -	if (rc)
> -		goto err_crypto_info;
> +#ifdef CONFIG_TLS_DEVICE
> +	rc = tls_set_device_offload(sk, ctx);
> +	tx_conf = TLS_HW_TX;
> +	if (rc) {
> +#else
> +	{
> +#endif
> +		/* if HW offload fails fallback to SW */
> +		rc = tls_set_sw_offload(sk, ctx);
> +		tx_conf = TLS_SW_TX;
> +		if (rc)
> +			goto err_crypto_info;
> +	}
>  
>  	ctx->tx_conf = tx_conf;
>  	update_sk_prot(sk, ctx);
> @@ -473,6 +484,12 @@ static void build_protos(struct proto *prot, struct proto *base)
>  	prot[TLS_SW_TX] = prot[TLS_BASE_TX];
>  	prot[TLS_SW_TX].sendmsg		= tls_sw_sendmsg;
>  	prot[TLS_SW_TX].sendpage	= tls_sw_sendpage;
> +
> +#ifdef CONFIG_TLS_DEVICE
> +	prot[TLS_HW_TX] = prot[TLS_SW_TX];
> +	prot[TLS_HW_TX].sendmsg		= tls_device_sendmsg;
> +	prot[TLS_HW_TX].sendpage	= tls_device_sendpage;
> +#endif
>  }
>  
>  static int tls_init(struct sock *sk)
> @@ -531,6 +548,9 @@ static int __init tls_register(void)
>  {
>  	build_protos(tls_prots[TLSV4], &tcp_prot);
>  
> +#ifdef CONFIG_TLS_DEVICE
> +	tls_device_init();
> +#endif
>  	tcp_register_ulp(&tcp_tls_ulp_ops);
>  
>  	return 0;
> @@ -539,6 +559,9 @@ static int __init tls_register(void)
>  static void __exit tls_unregister(void)
>  {
>  	tcp_unregister_ulp(&tcp_tls_ulp_ops);
> +#ifdef CONFIG_TLS_DEVICE
> +	tls_device_cleanup();
> +#endif
>  }
>  
>  module_init(tls_register);

Thanks,
Kirill
Dave Watson March 21, 2018, 3:08 p.m. UTC | #2
On 03/19/18 07:45 PM, Saeed Mahameed wrote:
> +#define TLS_OFFLOAD_CONTEXT_SIZE                                               \
> +	(ALIGN(sizeof(struct tls_offload_context), sizeof(void *)) +           \
> +	 TLS_DRIVER_STATE_SIZE)
> +
> +	pfrag = sk_page_frag(sk);
> +
> +	/* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and

I think the define is actually TLS_HEADER_SIZE, no KTLS_ prefix

> +	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
> +
> +	ctx->rec_seq_size = rec_seq_size;
> +	/* worst case is:
> +	 * MAX_SKB_FRAGS in tls_record_info
> +	 * MAX_SKB_FRAGS + 1 in SKB head an frags.

spelling

> +int tls_sw_fallback_init(struct sock *sk,
> +			 struct tls_offload_context *offload_ctx,
> +			 struct tls_crypto_info *crypto_info)
> +{
> +	int rc;
> +	const u8 *key;
> +
> +	offload_ctx->aead_send =
> +	    crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);

in tls_sw we went with async + crypto_wait_req, any reason to not do
that here?  Otherwise I think you still get the software gcm on x86
instead of aesni without additional changes.

> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
> index d824d548447e..e0dface33017 100644
> --- a/net/tls/tls_main.c
> +++ b/net/tls/tls_main.c
> @@ -54,6 +54,9 @@ enum {
>  enum {
>  	TLS_BASE_TX,
>  	TLS_SW_TX,
> +#ifdef CONFIG_TLS_DEVICE
> +	TLS_HW_TX,
> +#endif
>  	TLS_NUM_CONFIG,
>  };

I have posted SW_RX patches, do you forsee any issues with SW_RX + HW_TX?

Thanks
Boris Pismenny March 21, 2018, 3:38 p.m. UTC | #3
On 3/21/2018 5:08 PM, Dave Watson wrote:
> On 03/19/18 07:45 PM, Saeed Mahameed wrote:
>> +#define TLS_OFFLOAD_CONTEXT_SIZE                                               \
>> +	(ALIGN(sizeof(struct tls_offload_context), sizeof(void *)) +           \
>> +	 TLS_DRIVER_STATE_SIZE)
>> +
>> +	pfrag = sk_page_frag(sk);
>> +
>> +	/* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and
> 
> I think the define is actually TLS_HEADER_SIZE, no KTLS_ prefix
> 

Fixed. Thanks.

>> +	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
>> +
>> +	ctx->rec_seq_size = rec_seq_size;
>> +	/* worst case is:
>> +	 * MAX_SKB_FRAGS in tls_record_info
>> +	 * MAX_SKB_FRAGS + 1 in SKB head an frags.
> 
> spelling
> 

Fixed. Thanks.

>> +int tls_sw_fallback_init(struct sock *sk,
>> +			 struct tls_offload_context *offload_ctx,
>> +			 struct tls_crypto_info *crypto_info)
>> +{
>> +	int rc;
>> +	const u8 *key;
>> +
>> +	offload_ctx->aead_send =
>> +	    crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
> 
> in tls_sw we went with async + crypto_wait_req, any reason to not do
> that here?  Otherwise I think you still get the software gcm on x86
> instead of aesni without additional changes.
> 

Yes, synchronous crypto code runs to handle a software fallback in 
validate_xmit_skb, where waiting is not possible. I know Steffen 
recently added support for calling async crypto from validate_xmit_skb, 
but it wasn't available when we were writing these patches.

I think we could implemented async support in the future based on the 
infrastructure introduced by Steffen.

>> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
>> index d824d548447e..e0dface33017 100644
>> --- a/net/tls/tls_main.c
>> +++ b/net/tls/tls_main.c
>> @@ -54,6 +54,9 @@ enum {
>>   enum {
>>   	TLS_BASE_TX,
>>   	TLS_SW_TX,
>> +#ifdef CONFIG_TLS_DEVICE
>> +	TLS_HW_TX,
>> +#endif
>>   	TLS_NUM_CONFIG,
>>   };
> 
> I have posted SW_RX patches, do you forsee any issues with SW_RX + HW_TX?
> 

No, but I haven't tested these patches with the SW_RX patches.
I'll try to rebase your V2 SW_RX patches over this series tomorrow and 
run some tests.

> Thanks
>
Boris Pismenny March 21, 2018, 3:53 p.m. UTC | #4
...
> 
> Other patches have two licenses in header. Can I distribute this file under GPL license terms?
> 

Sure, I'll update the license to match other files under net/tls.

>> +#include <linux/module.h>
>> +#include <net/tcp.h>
>> +#include <net/inet_common.h>
>> +#include <linux/highmem.h>
>> +#include <linux/netdevice.h>
>> +
>> +#include <net/tls.h>
>> +#include <crypto/aead.h>
>> +
>> +/* device_offload_lock is used to synchronize tls_dev_add
>> + * against NETDEV_DOWN notifications.
>> + */
>> +DEFINE_STATIC_PERCPU_RWSEM(device_offload_lock);
>> +
>> +static void tls_device_gc_task(struct work_struct *work);
>> +
>> +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
>> +static LIST_HEAD(tls_device_gc_list);
>> +static LIST_HEAD(tls_device_list);
>> +static DEFINE_SPINLOCK(tls_device_lock);
>> +
>> +static void tls_device_free_ctx(struct tls_context *ctx)
>> +{
>> +	struct tls_offload_context *offlad_ctx = tls_offload_ctx(ctx);
>> +
>> +	kfree(offlad_ctx);
>> +	kfree(ctx);
>> +}
>> +
>> +static void tls_device_gc_task(struct work_struct *work)
>> +{
>> +	struct tls_context *ctx, *tmp;
>> +	struct list_head gc_list;
>> +	unsigned long flags;
>> +
>> +	spin_lock_irqsave(&tls_device_lock, flags);
>> +	INIT_LIST_HEAD(&gc_list);
> 
> This is stack variable, and it should be initialized outside of global spinlock.
> There is LIST_HEAD() primitive for that in kernel.
> There is one more similar place below.
> 

Sure.

>> +	list_splice_init(&tls_device_gc_list, &gc_list);
>> +	spin_unlock_irqrestore(&tls_device_lock, flags);
>> +
>> +	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
>> +		struct net_device *netdev = ctx->netdev;
>> +
>> +		if (netdev) {
>> +			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
>> +							TLS_OFFLOAD_CTX_DIR_TX);
>> +			dev_put(netdev);
>> +		}
> 
> How is possible the situation we meet NULL netdev here >

This can happen in tls_device_down. tls_deviec_down is called whenever a 
netdev that is used for TLS inline crypto offload goes down. It gets 
called via the NETDEV_DOWN event of the netdevice notifier.

This flow is somewhat similar to the xfrm_device netdev notifier. 
However, we do not destroy the socket (as in destroying the xfrm_state 
in xfrm_device). Instead, we cleanup the netdev state and allow software 
fallback to handle the rest of the traffic.

>> +
>> +		list_del(&ctx->list);
>> +		tls_device_free_ctx(ctx);
>> +	}
>> +}
>> +
>> +static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
>> +{
>> +	unsigned long flags;
>> +
>> +	spin_lock_irqsave(&tls_device_lock, flags);
>> +	list_move_tail(&ctx->list, &tls_device_gc_list);
>> +
>> +	/* schedule_work inside the spinlock
>> +	 * to make sure tls_device_down waits for that work.
>> +	 */
>> +	schedule_work(&tls_device_gc_work);
>> +
>> +	spin_unlock_irqrestore(&tls_device_lock, flags);
>> +}
>> +
>> +/* We assume that the socket is already connected */
>> +static struct net_device *get_netdev_for_sock(struct sock *sk)
>> +{
>> +	struct inet_sock *inet = inet_sk(sk);
>> +	struct net_device *netdev = NULL;
>> +
>> +	netdev = dev_get_by_index(sock_net(sk), inet->cork.fl.flowi_oif);
>> +
>> +	return netdev;
>> +}
>> +
>> +static int attach_sock_to_netdev(struct sock *sk, struct net_device *netdev,
>> +				 struct tls_context *ctx)
>> +{
>> +	int rc;
>> +
>> +	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
>> +					     &ctx->crypto_send,
>> +					     tcp_sk(sk)->write_seq);
>> +	if (rc) {
>> +		pr_err_ratelimited("The netdev has refused to offload this socket\n");
>> +		goto out;
>> +	}
>> +
>> +	rc = 0;
>> +out:
>> +	return rc;
>> +}
>> +
>> +static void destroy_record(struct tls_record_info *record)
>> +{
>> +	skb_frag_t *frag;
>> +	int nr_frags = record->num_frags;
>> +
>> +	while (nr_frags > 0) {
>> +		frag = &record->frags[nr_frags - 1];
>> +		__skb_frag_unref(frag);
>> +		--nr_frags;
>> +	}
>> +	kfree(record);
>> +}
>> +
>> +static void delete_all_records(struct tls_offload_context *offload_ctx)
>> +{
>> +	struct tls_record_info *info, *temp;
>> +
>> +	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
>> +		list_del(&info->list);
>> +		destroy_record(info);
>> +	}
>> +
>> +	offload_ctx->retransmit_hint = NULL;
>> +}
>> +
>> +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
>> +{
>> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
>> +	struct tls_offload_context *ctx;
>> +	struct tls_record_info *info, *temp;
>> +	unsigned long flags;
>> +	u64 deleted_records = 0;
>> +
>> +	if (!tls_ctx)
>> +		return;
>> +
>> +	ctx = tls_offload_ctx(tls_ctx);
>> +
>> +	spin_lock_irqsave(&ctx->lock, flags);
>> +	info = ctx->retransmit_hint;
>> +	if (info && !before(acked_seq, info->end_seq)) {
>> +		ctx->retransmit_hint = NULL;
>> +		list_del(&info->list);
>> +		destroy_record(info);
>> +		deleted_records++;
>> +	}
>> +
>> +	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
>> +		if (before(acked_seq, info->end_seq))
>> +			break;
>> +		list_del(&info->list);
>> +
>> +		destroy_record(info);
>> +		deleted_records++;
>> +	}
>> +
>> +	ctx->unacked_record_sn += deleted_records;
>> +	spin_unlock_irqrestore(&ctx->lock, flags);
>> +}
>> +
>> +/* At this point, there should be no references on this
>> + * socket and no in-flight SKBs associated with this
>> + * socket, so it is safe to free all the resources.
>> + */
>> +void tls_device_sk_destruct(struct sock *sk)
>> +{
>> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
>> +	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
>> +
>> +	if (ctx->open_record)
>> +		destroy_record(ctx->open_record);
>> +
>> +	delete_all_records(ctx);
>> +	crypto_free_aead(ctx->aead_send);
>> +	ctx->sk_destruct(sk);
>> +
>> +	if (refcount_dec_and_test(&tls_ctx->refcount))
>> +		tls_device_queue_ctx_destruction(tls_ctx);
>> +}
>> +EXPORT_SYMBOL(tls_device_sk_destruct);
>> +
>> +static inline void tls_append_frag(struct tls_record_info *record,
>> +				   struct page_frag *pfrag,
>> +				   int size)
>> +{
>> +	skb_frag_t *frag;
>> +
>> +	frag = &record->frags[record->num_frags - 1];
>> +	if (frag->page.p == pfrag->page &&
>> +	    frag->page_offset + frag->size == pfrag->offset) {
>> +		frag->size += size;
>> +	} else {
>> +		++frag;
>> +		frag->page.p = pfrag->page;
>> +		frag->page_offset = pfrag->offset;
>> +		frag->size = size;
>> +		++record->num_frags;
>> +		get_page(pfrag->page);
>> +	}
>> +
>> +	pfrag->offset += size;
>> +	record->len += size;
>> +}
>> +
>> +static inline int tls_push_record(struct sock *sk,
>> +				  struct tls_context *ctx,
>> +				  struct tls_offload_context *offload_ctx,
>> +				  struct tls_record_info *record,
>> +				  struct page_frag *pfrag,
>> +				  int flags,
>> +				  unsigned char record_type)
>> +{
>> +	skb_frag_t *frag;
>> +	struct tcp_sock *tp = tcp_sk(sk);
>> +	struct page_frag fallback_frag;
>> +	struct page_frag  *tag_pfrag = pfrag;
>> +	int i;
>> +
>> +	/* fill prepand */
>> +	frag = &record->frags[0];
>> +	tls_fill_prepend(ctx,
>> +			 skb_frag_address(frag),
>> +			 record->len - ctx->prepend_size,
>> +			 record_type);
>> +
>> +	if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag, GFP_KERNEL))) {
>> +		/* HW doesn't care about the data in the tag
>> +		 * so in case pfrag has no room
>> +		 * for a tag and we can't allocate a new pfrag
>> +		 * just use the page in the first frag
>> +		 * rather then write a complicated fall back code.
>> +		 */
>> +		tag_pfrag = &fallback_frag;
>> +		tag_pfrag->page = skb_frag_page(frag);
>> +		tag_pfrag->offset = 0;
>> +	}
>> +
>> +	tls_append_frag(record, tag_pfrag, ctx->tag_size);
>> +	record->end_seq = tp->write_seq + record->len;
>> +	spin_lock_irq(&offload_ctx->lock);
>> +	list_add_tail(&record->list, &offload_ctx->records_list);
>> +	spin_unlock_irq(&offload_ctx->lock);
>> +	offload_ctx->open_record = NULL;
>> +	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
>> +	tls_advance_record_sn(sk, ctx);
>> +
>> +	for (i = 0; i < record->num_frags; i++) {
>> +		frag = &record->frags[i];
>> +		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
>> +		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
>> +			    frag->size, frag->page_offset);
>> +		sk_mem_charge(sk, frag->size);
>> +		get_page(skb_frag_page(frag));
>> +	}
>> +	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
>> +
>> +	/* all ready, send */
>> +	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
>> +}
>> +
>> +static inline int tls_create_new_record(struct tls_offload_context *offload_ctx,
>> +					struct page_frag *pfrag,
>> +					size_t prepend_size)
>> +{
>> +	skb_frag_t *frag;
>> +	struct tls_record_info *record;
>> +
>> +	record = kmalloc(sizeof(*record), GFP_KERNEL);
>> +	if (!record)
>> +		return -ENOMEM;
>> +
>> +	frag = &record->frags[0];
>> +	__skb_frag_set_page(frag, pfrag->page);
>> +	frag->page_offset = pfrag->offset;
>> +	skb_frag_size_set(frag, prepend_size);
>> +
>> +	get_page(pfrag->page);
>> +	pfrag->offset += prepend_size;
>> +
>> +	record->num_frags = 1;
>> +	record->len = prepend_size;
>> +	offload_ctx->open_record = record;
>> +	return 0;
>> +}
>> +
>> +static inline int tls_do_allocation(struct sock *sk,
>> +				    struct tls_offload_context *offload_ctx,
>> +				    struct page_frag *pfrag,
>> +				    size_t prepend_size)
>> +{
>> +	int ret;
>> +
>> +	if (!offload_ctx->open_record) {
>> +		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
>> +						   sk->sk_allocation))) {
>> +			sk->sk_prot->enter_memory_pressure(sk);
>> +			sk_stream_moderate_sndbuf(sk);
>> +			return -ENOMEM;
>> +		}
>> +
>> +		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
>> +		if (ret)
>> +			return ret;
>> +
>> +		if (pfrag->size > pfrag->offset)
>> +			return 0;
>> +	}
>> +
>> +	if (!sk_page_frag_refill(sk, pfrag))
>> +		return -ENOMEM;
>> +
>> +	return 0;
>> +}
>> +
>> +static int tls_push_data(struct sock *sk,
>> +			 struct iov_iter *msg_iter,
>> +			 size_t size, int flags,
>> +			 unsigned char record_type)
>> +{
>> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
>> +	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
>> +	struct tls_record_info *record = ctx->open_record;
>> +	struct page_frag *pfrag;
>> +	int copy, rc = 0;
>> +	size_t orig_size = size;
>> +	u32 max_open_record_len;
>> +	long timeo;
>> +	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
>> +	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
>> +	bool done = false;
>> +
>> +	if (flags &
>> +	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
>> +		return -ENOTSUPP;
>> +
>> +	if (sk->sk_err)
>> +		return -sk->sk_err;
>> +
>> +	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
>> +	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
>> +	if (rc < 0)
>> +		return rc;
>> +
>> +	pfrag = sk_page_frag(sk);
>> +
>> +	/* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and
>> +	 * we need to leave room for an authentication tag.
>> +	 */
>> +	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
>> +			      tls_ctx->prepend_size;
>> +	do {
>> +		if (tls_do_allocation(sk, ctx, pfrag,
>> +				      tls_ctx->prepend_size)) {
>> +			rc = sk_stream_wait_memory(sk, &timeo);
>> +			if (!rc)
>> +				continue;
>> +
>> +			record = ctx->open_record;
>> +			if (!record)
>> +				break;
>> +handle_error:
>> +			if (record_type != TLS_RECORD_TYPE_DATA) {
>> +				/* avoid sending partial
>> +				 * record with type !=
>> +				 * application_data
>> +				 */
>> +				size = orig_size;
>> +				destroy_record(record);
>> +				ctx->open_record = NULL;
>> +			} else if (record->len > tls_ctx->prepend_size) {
>> +				goto last_record;
>> +			}
>> +
>> +			break;
>> +		}
>> +
>> +		record = ctx->open_record;
>> +		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
>> +		copy = min_t(size_t, copy, (max_open_record_len - record->len));
>> +
>> +		if (copy_from_iter_nocache(page_address(pfrag->page) +
>> +					       pfrag->offset,
>> +					   copy, msg_iter) != copy) {
>> +			rc = -EFAULT;
>> +			goto handle_error;
>> +		}
>> +		tls_append_frag(record, pfrag, copy);
>> +
>> +		size -= copy;
>> +		if (!size) {
>> +last_record:
>> +			tls_push_record_flags = flags;
>> +			if (more) {
>> +				tls_ctx->pending_open_record_frags =
>> +						record->num_frags;
>> +				break;
>> +			}
>> +
>> +			done = true;
>> +		}
>> +
>> +		if ((done) || record->len >= max_open_record_len ||
>> +		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
>> +			rc = tls_push_record(sk,
>> +					     tls_ctx,
>> +					     ctx,
>> +					     record,
>> +					     pfrag,
>> +					     tls_push_record_flags,
>> +					     record_type);
>> +			if (rc < 0)
>> +				break;
>> +		}
>> +	} while (!done);
>> +
>> +	if (orig_size - size > 0)
>> +		rc = orig_size - size;
>> +
>> +	return rc;
>> +}
>> +
>> +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
>> +{
>> +	unsigned char record_type = TLS_RECORD_TYPE_DATA;
>> +	int rc = 0;
>> +
>> +	lock_sock(sk);
>> +
>> +	if (unlikely(msg->msg_controllen)) {
>> +		rc = tls_proccess_cmsg(sk, msg, &record_type);
>> +		if (rc)
>> +			goto out;
>> +	}
>> +
>> +	rc = tls_push_data(sk, &msg->msg_iter, size,
>> +			   msg->msg_flags, record_type);
>> +
>> +out:
>> +	release_sock(sk);
>> +	return rc;
>> +}
>> +
>> +int tls_device_sendpage(struct sock *sk, struct page *page,
>> +			int offset, size_t size, int flags)
>> +{
>> +	struct iov_iter	msg_iter;
>> +	struct kvec iov;
>> +	char *kaddr = kmap(page);
>> +	int rc = 0;
>> +
>> +	if (flags & MSG_SENDPAGE_NOTLAST)
>> +		flags |= MSG_MORE;
>> +
>> +	lock_sock(sk);
>> +
>> +	if (flags & MSG_OOB) {
>> +		rc = -ENOTSUPP;
>> +		goto out;
>> +	}
>> +
>> +	iov.iov_base = kaddr + offset;
>> +	iov.iov_len = size;
>> +	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
>> +	rc = tls_push_data(sk, &msg_iter, size,
>> +			   flags, TLS_RECORD_TYPE_DATA);
>> +	kunmap(page);
>> +
>> +out:
>> +	release_sock(sk);
>> +	return rc;
>> +}
>> +
>> +struct tls_record_info *tls_get_record(struct tls_offload_context *context,
>> +				       u32 seq, u64 *p_record_sn)
>> +{
>> +	struct tls_record_info *info;
>> +	u64 record_sn = context->hint_record_sn;
>> +
>> +	info = context->retransmit_hint;
>> +	if (!info ||
>> +	    before(seq, info->end_seq - info->len)) {
>> +		/* if retransmit_hint is irrelevant start
>> +		 * from the begging of the list
>> +		 */
>> +		info = list_first_entry(&context->records_list,
>> +					struct tls_record_info, list);
>> +		record_sn = context->unacked_record_sn;
>> +	}
>> +
>> +	list_for_each_entry_from(info, &context->records_list, list) {
>> +		if (before(seq, info->end_seq)) {
>> +			if (!context->retransmit_hint ||
>> +			    after(info->end_seq,
>> +				  context->retransmit_hint->end_seq)) {
>> +				context->hint_record_sn = record_sn;
>> +				context->retransmit_hint = info;
>> +			}
>> +			*p_record_sn = record_sn;
>> +			return info;
>> +		}
>> +		record_sn++;
>> +	}
>> +
>> +	return NULL;
>> +}
>> +EXPORT_SYMBOL(tls_get_record);
>> +
>> +static int tls_device_push_pending_record(struct sock *sk, int flags)
>> +{
>> +	struct iov_iter	msg_iter;
>> +
>> +	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
>> +	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
>> +}
>> +
>> +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
>> +{
>> +	u16 nonece_size, tag_size, iv_size, rec_seq_size;
>> +	struct tls_record_info *start_marker_record;
>> +	struct tls_offload_context *offload_ctx;
>> +	struct tls_crypto_info *crypto_info;
>> +	struct net_device *netdev;
>> +	char *iv, *rec_seq;
>> +	struct sk_buff *skb;
>> +	int rc = -EINVAL;
>> +	__be64 rcd_sn;
>> +
>> +	if (!ctx)
>> +		goto out;
>> +
>> +	if (ctx->priv_ctx) {
>> +		rc = -EEXIST;
>> +		goto out;
>> +	}
>> +
>> +	/* We support starting offload on multiple sockets
>> +	 * concurrently, So we only need a read lock here.
>> +	 */
>> +	percpu_down_read(&device_offload_lock);
>> +	netdev = get_netdev_for_sock(sk);
>> +	if (!netdev) {
>> +		pr_err_ratelimited("%s: netdev not found\n", __func__);
>> +		rc = -EINVAL;
>> +		goto release_lock;
>> +	}
>> +
>> +	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
>> +		rc = -ENOTSUPP;
>> +		goto release_netdev;
>> +	}
>> +
>> +	/* Avoid offloading if the device is down
>> +	 * We don't want to offload new flows after
>> +	 * the NETDEV_DOWN event
>> +	 */
>> +	if (!(netdev->flags & IFF_UP)) {
>> +		rc = -EINVAL;
>> +		goto release_lock;
>> +	}
>> +
>> +	crypto_info = &ctx->crypto_send;
>> +	switch (crypto_info->cipher_type) {
>> +	case TLS_CIPHER_AES_GCM_128: {
>> +		nonece_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
>> +		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
>> +		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
>> +		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
>> +		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
>> +		rec_seq =
>> +		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
>> +		break;
>> +	}
>> +	default:
>> +		rc = -EINVAL;
>> +		goto release_netdev;
>> +	}
>> +
>> +	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
> 
> Can we memory allocations and simple memory initializations ouside the global rwsem?
> 

Sure, we can move all memory allocations outside the lock.

>> +	if (!start_marker_record) {
>> +		rc = -ENOMEM;
>> +		goto release_netdev;
>> +	}
>> +
>> +	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
>> +	if (!offload_ctx)
>> +		goto free_marker_record;
>> +
>> +	ctx->priv_ctx = offload_ctx;
>> +	rc = attach_sock_to_netdev(sk, netdev, ctx);
>> +	if (rc)
>> +		goto free_offload_context;
>> +
>> +	ctx->netdev = netdev;
>> +	ctx->prepend_size = TLS_HEADER_SIZE + nonece_size;
>> +	ctx->tag_size = tag_size;
>> +	ctx->iv_size = iv_size;
>> +	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
>> +			  GFP_KERNEL);
>> +	if (!ctx->iv) {
>> +		rc = -ENOMEM;
>> +		goto detach_sock;
>> +	}
>> +
>> +	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
>> +
>> +	ctx->rec_seq_size = rec_seq_size;
>> +	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
>> +	if (!ctx->rec_seq) {
>> +		rc = -ENOMEM;
>> +		goto free_iv;
>> +	}
>> +	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
>> +
>> +	/* start at rec_seq - 1 to account for the start marker record */
>> +	memcpy(&rcd_sn, ctx->rec_seq, sizeof(rcd_sn));
>> +	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
>> +
>> +	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
>> +	if (rc)
>> +		goto free_rec_seq;
>> +
>> +	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
>> +	start_marker_record->len = 0;
>> +	start_marker_record->num_frags = 0;
>> +
>> +	INIT_LIST_HEAD(&offload_ctx->records_list);
>> +	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
>> +	spin_lock_init(&offload_ctx->lock);
>> +
>> +	inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked;
>> +	ctx->push_pending_record = tls_device_push_pending_record;
>> +	offload_ctx->sk_destruct = sk->sk_destruct;
>> +
>> +	/* TLS offload is greatly simplified if we don't send
>> +	 * SKBs where only part of the payload needs to be encrypted.
>> +	 * So mark the last skb in the write queue as end of record.
>> +	 */
>> +	skb = tcp_write_queue_tail(sk);
>> +	if (skb)
>> +		TCP_SKB_CB(skb)->eor = 1;
>> +
>> +	refcount_set(&ctx->refcount, 1);
>> +	spin_lock_irq(&tls_device_lock);
>> +	list_add_tail(&ctx->list, &tls_device_list);
>> +	spin_unlock_irq(&tls_device_lock);
>> +
>> +	/* following this assignment tls_is_sk_tx_device_offloaded
>> +	 * will return true and the context might be accessed
>> +	 * by the netdev's xmit function.
>> +	 */
>> +	smp_store_release(&sk->sk_destruct,
>> +			  &tls_device_sk_destruct);
>> +	goto release_lock;
>> +
>> +free_rec_seq:
>> +	kfree(ctx->rec_seq);
>> +free_iv:
>> +	kfree(ctx->iv);
>> +detach_sock:
>> +	netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
>> +free_offload_context:
>> +	kfree(offload_ctx);
>> +	ctx->priv_ctx = NULL;
>> +free_marker_record:
>> +	kfree(start_marker_record);
>> +release_netdev:
>> +	dev_put(netdev);
>> +release_lock:
>> +	percpu_up_read(&device_offload_lock);
>> +out:
>> +	return rc;
>> +}
>> +
>> +static int tls_device_register(struct net_device *dev)
>> +{
>> +	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
>> +		return NOTIFY_BAD;
>> +
>> +	return NOTIFY_DONE;
>> +}
> 
> This function is the same as tls_device_feat_change(). Can't we merge
> them together and avoid duplicating of code?
> 

Sure.

>> +static int tls_device_unregister(struct net_device *dev)
>> +{
>> +	return NOTIFY_DONE;
>> +}
> 
> This function does nothing, and next patches do not change it.
> Can't we remove it since so?
> 

Sure.

>> +static int tls_device_feat_change(struct net_device *dev)
>> +{
>> +	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
>> +		return NOTIFY_BAD;
>> +
>> +	return NOTIFY_DONE;
>> +}
>> +
>> +static int tls_device_down(struct net_device *netdev)
>> +{
>> +	struct tls_context *ctx, *tmp;
>> +	struct list_head list;
>> +	unsigned long flags;
>> +
>> +	if (!(netdev->features & NETIF_F_HW_TLS_TX))
>> +		return NOTIFY_DONE;
> 
> Can't we move this check in tls_dev_event() and use it for all types of events?
> Then we avoid duplicate code.
> 

No. Not all events require this check. Also, the result is different for 
different events.

>> +
>> +	/* Request a write lock to block new offload attempts
>> +	 */
>> +	percpu_down_write(&device_offload_lock);
> 
> What is the reason percpu_rwsem is chosen here? It looks like this primitive
> gives more advantages readers, then plain rwsem does. But it also gives
> disadvantages to writers. It would be good, unless tls_device_down() is called
> with rtnl_lock() held from netdevice notifier. But since netdevice notifier
> are called with rtnl_lock() held, percpu_rwsem will increase the time rtnl_lock()
> is locked.
We use the a rwsem to allow multiple (readers) invocations of 
tls_set_device_offload, which is triggered by the user (persumably) 
during the TLS handshake. This might be considered a fast-path.

However, we must block all calls to tls_set_device_offload while we are 
processing NETDEV_DOWN events (writer).

As you've mentioned, the percpu rwsem is more efficient for readers, 
especially on NUMA systems, where cache-line bouncing occurs during 
reader acquire and reduces performance.

> 
> Can't we use plain rwsem here instead?
> 

Its a performance tradeoff. I'm not certain that the percpu rwsem write 
side acquire is significantly worse than using the global rwsem.

For now, while all of this is experimental, can we agree to focus on the 
performance of readers? We can change it later if it becomes a problem.

>> +
>> +	spin_lock_irqsave(&tls_device_lock, flags);
>> +	INIT_LIST_HEAD(&list);
> 
> This may go outside the global spinlock.
> 

Sure.

>> +	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
>> +		if (ctx->netdev != netdev ||
>> +		    !refcount_inc_not_zero(&ctx->refcount))
>> +			continue;
>> +
>> +		list_move(&ctx->list, &list);
>> +	}
>> +	spin_unlock_irqrestore(&tls_device_lock, flags);
>> +
>> +	list_for_each_entry_safe(ctx, tmp, &list, list)	{
>> +		netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
>> +						TLS_OFFLOAD_CTX_DIR_TX);
>> +		ctx->netdev = NULL;
>> +		dev_put(netdev);
>> +		list_del_init(&ctx->list);
>> +
>> +		if (refcount_dec_and_test(&ctx->refcount))
>> +			tls_device_free_ctx(ctx);
>> +	}
>> +
>> +	percpu_up_write(&device_offload_lock);
>> +
>> +	flush_work(&tls_device_gc_work);
>> +
>> +	return NOTIFY_DONE;
>> +}
>> +
>> +static int tls_dev_event(struct notifier_block *this, unsigned long event,
>> +			 void *ptr)
>> +{
>> +	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
>> +
>> +	switch (event) {
>> +	case NETDEV_REGISTER:
>> +		return tls_device_register(dev);
>> +
>> +	case NETDEV_UNREGISTER:
>> +		return tls_device_unregister(dev);
>> +
>> +	case NETDEV_FEAT_CHANGE:
>> +		return tls_device_feat_change(dev);
>> +
>> +	case NETDEV_DOWN:
>> +		return tls_device_down(dev);
>> +	}
>> +	return NOTIFY_DONE;
>> +}
>> +
>> +static struct notifier_block tls_dev_notifier = {
>> +	.notifier_call	= tls_dev_event,
>> +};
>> +
>> +void __init tls_device_init(void)
>> +{
>> +	register_netdevice_notifier(&tls_dev_notifier);
>> +}
>> +
>> +void __exit tls_device_cleanup(void)
>> +{
>> +	unregister_netdevice_notifier(&tls_dev_notifier);
>> +	flush_work(&tls_device_gc_work);
>> +}
>> diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
>> new file mode 100644
>> index 000000000000..14d31a36885c
>> --- /dev/null
>> +++ b/net/tls/tls_device_fallback.c
>> @@ -0,0 +1,419 @@
>> +/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
>> + *
>> + *     Redistribution and use in source and binary forms, with or
>> + *     without modification, are permitted provided that the following
>> + *     conditions are met:
>> + *
>> + *      - Redistributions of source code must retain the above
>> + *        copyright notice, this list of conditions and the following
>> + *        disclaimer.
>> + *
>> + *      - Redistributions in binary form must reproduce the above
>> + *        copyright notice, this list of conditions and the following
>> + *        disclaimer in the documentation and/or other materials
>> + *        provided with the distribution.
>> + *
>> + *      - Neither the name of the Mellanox Technologies nor the
>> + *        names of its contributors may be used to endorse or promote
>> + *        products derived from this software without specific prior written
>> + *        permission.
>> + *
>> + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
>> + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
>> + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
>> + * A PARTICULAR PURPOSE ARE DISCLAIMED.
>> + * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
>> + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
>> + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
>> + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
>> + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
>> + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
>> + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
>> + * POSSIBILITY OF SUCH DAMAGE
>> + */
>> +
>> +#include <net/tls.h>
>> +#include <crypto/aead.h>
>> +#include <crypto/scatterwalk.h>
>> +#include <net/ip6_checksum.h>
>> +
>> +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
>> +{
>> +	struct scatterlist *src = walk->sg;
>> +	int diff = walk->offset - src->offset;
>> +
>> +	sg_set_page(sg, sg_page(src),
>> +		    src->length - diff, walk->offset);
>> +
>> +	scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
>> +}
>> +
>> +static int tls_enc_record(struct aead_request *aead_req,
>> +			  struct crypto_aead *aead, char *aad, char *iv,
>> +			  __be64 rcd_sn, struct scatter_walk *in,
>> +			  struct scatter_walk *out, int *in_len)
>> +{
>> +	struct scatterlist sg_in[3];
>> +	struct scatterlist sg_out[3];
>> +	unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
>> +	u16 len;
>> +	int rc;
>> +
>> +	len = min_t(int, *in_len, ARRAY_SIZE(buf));
>> +
>> +	scatterwalk_copychunks(buf, in, len, 0);
>> +	scatterwalk_copychunks(buf, out, len, 1);
>> +
>> +	*in_len -= len;
>> +	if (!*in_len)
>> +		return 0;
>> +
>> +	scatterwalk_pagedone(in, 0, 1);
>> +	scatterwalk_pagedone(out, 1, 1);
>> +
>> +	len = buf[4] | (buf[3] << 8);
>> +	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
>> +
>> +	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
>> +		     (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
>> +
>> +	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
>> +	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
>> +
>> +	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
>> +	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
>> +	sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
>> +	sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
>> +	chain_to_walk(sg_in + 1, in);
>> +	chain_to_walk(sg_out + 1, out);
>> +
>> +	*in_len -= len;
>> +	if (*in_len < 0) {
>> +		*in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
>> +		if (*in_len < 0)
>> +		/* the input buffer doesn't contain the entire record.
>> +		 * trim len accordingly. The resulting authentication tag
>> +		 * will contain garbage. but we don't care as we won't
>> +		 * include any of it in the output skb
>> +		 * Note that we assume the output buffer length
>> +		 * is larger then input buffer length + tag size
>> +		 */
>> +			len += *in_len;
>> +
>> +		*in_len = 0;
>> +	}
>> +
>> +	if (*in_len) {
>> +		scatterwalk_copychunks(NULL, in, len, 2);
>> +		scatterwalk_pagedone(in, 0, 1);
>> +		scatterwalk_copychunks(NULL, out, len, 2);
>> +		scatterwalk_pagedone(out, 1, 1);
>> +	}
>> +
>> +	len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
>> +	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
>> +
>> +	rc = crypto_aead_encrypt(aead_req);
>> +
>> +	return rc;
>> +}
>> +
>> +static void tls_init_aead_request(struct aead_request *aead_req,
>> +				  struct crypto_aead *aead)
>> +{
>> +	aead_request_set_tfm(aead_req, aead);
>> +	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
>> +}
>> +
>> +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
>> +						   gfp_t flags)
>> +{
>> +	unsigned int req_size = sizeof(struct aead_request) +
>> +		crypto_aead_reqsize(aead);
>> +	struct aead_request *aead_req;
>> +
>> +	aead_req = kzalloc(req_size, flags);
>> +	if (!aead_req)
>> +		return NULL;
>> +
>> +	tls_init_aead_request(aead_req, aead);
>> +	return aead_req;
>> +}
>> +
>> +static int tls_enc_records(struct aead_request *aead_req,
>> +			   struct crypto_aead *aead, struct scatterlist *sg_in,
>> +			   struct scatterlist *sg_out, char *aad, char *iv,
>> +			   u64 rcd_sn, int len)
>> +{
>> +	struct scatter_walk in;
>> +	struct scatter_walk out;
>> +	int rc;
>> +
>> +	scatterwalk_start(&in, sg_in);
>> +	scatterwalk_start(&out, sg_out);
>> +
>> +	do {
>> +		rc = tls_enc_record(aead_req, aead, aad, iv,
>> +				    cpu_to_be64(rcd_sn), &in, &out, &len);
>> +		rcd_sn++;
>> +
>> +	} while (rc == 0 && len);
>> +
>> +	scatterwalk_done(&in, 0, 0);
>> +	scatterwalk_done(&out, 1, 0);
>> +
>> +	return rc;
>> +}
>> +
>> +static inline void update_chksum(struct sk_buff *skb, int headln)
>> +{
>> +	/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
>> +	 * might have been changed by NAT.
>> +	 */
>> +
>> +	const struct ipv6hdr *ipv6h;
>> +	const struct iphdr *iph;
>> +	struct tcphdr *th = tcp_hdr(skb);
>> +	int datalen = skb->len - headln;
>> +
>> +	/* We only changed the payload so if we are using partial we don't
>> +	 * need to update anything.
>> +	 */
>> +	if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
>> +		return;
>> +
>> +	skb->ip_summed = CHECKSUM_PARTIAL;
>> +	skb->csum_start = skb_transport_header(skb) - skb->head;
>> +	skb->csum_offset = offsetof(struct tcphdr, check);
>> +
>> +	if (skb->sk->sk_family == AF_INET6) {
>> +		ipv6h = ipv6_hdr(skb);
>> +		th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
>> +					     datalen, IPPROTO_TCP, 0);
>> +	} else {
>> +		iph = ip_hdr(skb);
>> +		th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
>> +					       IPPROTO_TCP, 0);
>> +	}
>> +}
>> +
>> +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
>> +{
>> +	skb_copy_header(nskb, skb);
>> +
>> +	skb_put(nskb, skb->len);
>> +	memcpy(nskb->data, skb->data, headln);
>> +	update_chksum(nskb, headln);
>> +
>> +	nskb->destructor = skb->destructor;
>> +	nskb->sk = skb->sk;
>> +	skb->destructor = NULL;
>> +	skb->sk = NULL;
>> +	refcount_add(nskb->truesize - skb->truesize,
>> +		     &nskb->sk->sk_wmem_alloc);
>> +}
>> +
>> +/* This function may be called after the user socket is already
>> + * closed so make sure we don't use anything freed during
>> + * tls_sk_proto_close here
>> + */
>> +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
>> +{
>> +	int tcp_header_size = tcp_hdrlen(skb);
>> +	int tcp_payload_offset = skb_transport_offset(skb) + tcp_header_size;
>> +	int payload_len = skb->len - tcp_payload_offset;
>> +	struct tls_context *tls_ctx = tls_get_ctx(sk);
>> +	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
>> +	int remaining, buf_len, resync_sgs, rc, i = 0;
>> +	void *buf, *dummy_buf, *iv, *aad;
>> +	struct scatterlist *sg_in;
>> +	struct scatterlist sg_out[3];
>> +	u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
>> +	struct aead_request *aead_req;
>> +	struct sk_buff *nskb = NULL;
>> +	struct tls_record_info *record;
>> +	unsigned long flags;
>> +	s32 sync_size;
>> +	u64 rcd_sn;
>> +
>> +	/* worst case is:
>> +	 * MAX_SKB_FRAGS in tls_record_info
>> +	 * MAX_SKB_FRAGS + 1 in SKB head an frags.
>> +	 */
>> +	int sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;
>> +
>> +	if (!payload_len)
>> +		return skb;
>> +
>> +	sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
>> +	if (!sg_in)
>> +		goto free_orig;
>> +
>> +	sg_init_table(sg_in, sg_in_max_elements);
>> +	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
>> +
>> +	spin_lock_irqsave(&ctx->lock, flags);
>> +	record = tls_get_record(ctx, tcp_seq, &rcd_sn);
>> +	if (!record) {
>> +		spin_unlock_irqrestore(&ctx->lock, flags);
>> +		WARN(1, "Record not found for seq %u\n", tcp_seq);
>> +		goto free_sg;
>> +	}
>> +
>> +	sync_size = tcp_seq - tls_record_start_seq(record);
>> +	if (sync_size < 0) {
>> +		int is_start_marker = tls_record_is_start_marker(record);
>> +
>> +		spin_unlock_irqrestore(&ctx->lock, flags);
>> +		if (!is_start_marker)
>> +		/* This should only occur if the relevant record was
>> +		 * already acked. In that case it should be ok
>> +		 * to drop the packet and avoid retransmission.
>> +		 *
>> +		 * There is a corner case where the packet contains
>> +		 * both an acked and a non-acked record.
>> +		 * We currently don't handle that case and rely
>> +		 * on TCP to retranmit a packet that doesn't contain
>> +		 * already acked payload.
>> +		 */
>> +			goto free_orig;
>> +
>> +		if (payload_len > -sync_size) {
>> +			WARN(1, "Fallback of partially offloaded packets is not supported\n");
>> +			goto free_sg;
>> +		} else {
>> +			return skb;
>> +		}
>> +	}
>> +
>> +	remaining = sync_size;
>> +	while (remaining > 0) {
>> +		skb_frag_t *frag = &record->frags[i];
>> +
>> +		__skb_frag_ref(frag);
>> +		sg_set_page(sg_in + i, skb_frag_page(frag),
>> +			    skb_frag_size(frag), frag->page_offset);
>> +
>> +		remaining -= skb_frag_size(frag);
>> +
>> +		if (remaining < 0)
>> +			sg_in[i].length += remaining;
>> +
>> +		i++;
>> +	}
>> +	spin_unlock_irqrestore(&ctx->lock, flags);
>> +	resync_sgs = i;
>> +
>> +	aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
>> +	if (!aead_req)
>> +		goto put_sg;
>> +
>> +	buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
>> +		  TLS_CIPHER_AES_GCM_128_IV_SIZE +
>> +		  TLS_AAD_SPACE_SIZE +
>> +		  sync_size +
>> +		  tls_ctx->tag_size;
>> +	buf = kmalloc(buf_len, GFP_ATOMIC);
>> +	if (!buf)
>> +		goto free_req;
>> +
>> +	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
>> +	if (!nskb)
>> +		goto free_buf;
>> +
>> +	skb_reserve(nskb, skb_headroom(skb));
>> +
>> +	iv = buf;
>> +
>> +	memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
>> +	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
>> +	aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
>> +	      TLS_CIPHER_AES_GCM_128_IV_SIZE;
>> +	dummy_buf = aad + TLS_AAD_SPACE_SIZE;
>> +
>> +	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
>> +	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset,
>> +		   payload_len);
>> +	/* Add room for authentication tag produced by crypto */
>> +	dummy_buf += sync_size;
>> +	sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size);
>> +	rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset,
>> +			  payload_len);
>> +	if (rc < 0)
>> +		goto free_nskb;
>> +
>> +	rc = tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
>> +			     rcd_sn, sync_size + payload_len);
>> +	if (rc < 0)
>> +		goto free_nskb;
>> +
>> +	complete_skb(nskb, skb, tcp_payload_offset);
>> +
>> +	/* validate_xmit_skb_list assumes that if the skb wasn't segmented
>> +	 * nskb->prev will point to the skb itself
>> +	 */
>> +	nskb->prev = nskb;
>> +free_buf:
>> +	kfree(buf);
>> +free_req:
>> +	kfree(aead_req);
>> +put_sg:
>> +	for (i = 0; i < resync_sgs; i++)
>> +		put_page(sg_page(&sg_in[i]));
>> +free_sg:
>> +	kfree(sg_in);
>> +free_orig:
>> +	kfree_skb(skb);
>> +	return nskb;
>> +
>> +free_nskb:
>> +	kfree_skb(nskb);
>> +	nskb = NULL;
>> +	goto free_buf;
>> +}
>> +
>> +static struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
>> +					     struct net_device *dev,
>> +					     struct sk_buff *skb)
>> +{
>> +	if (dev == tls_get_ctx(sk)->netdev)
>> +		return skb;
>> +
>> +	return tls_sw_fallback(sk, skb);
>> +}
>> +
>> +int tls_sw_fallback_init(struct sock *sk,
>> +			 struct tls_offload_context *offload_ctx,
>> +			 struct tls_crypto_info *crypto_info)
>> +{
>> +	int rc;
>> +	const u8 *key;
>> +
>> +	offload_ctx->aead_send =
>> +	    crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
>> +	if (IS_ERR(offload_ctx->aead_send)) {
>> +		rc = PTR_ERR(offload_ctx->aead_send);
>> +		pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
>> +		offload_ctx->aead_send = NULL;
>> +		goto err_out;
>> +	}
>> +
>> +	key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
>> +
>> +	rc = crypto_aead_setkey(offload_ctx->aead_send, key,
>> +				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
>> +	if (rc)
>> +		goto free_aead;
>> +
>> +	rc = crypto_aead_setauthsize(offload_ctx->aead_send,
>> +				     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
>> +	if (rc)
>> +		goto free_aead;
>> +
>> +	sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
>> +	return 0;
>> +free_aead:
>> +	crypto_free_aead(offload_ctx->aead_send);
>> +err_out:
>> +	return rc;
>> +}
>> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
>> index d824d548447e..e0dface33017 100644
>> --- a/net/tls/tls_main.c
>> +++ b/net/tls/tls_main.c
>> @@ -54,6 +54,9 @@ enum {
>>   enum {
>>   	TLS_BASE_TX,
>>   	TLS_SW_TX,
>> +#ifdef CONFIG_TLS_DEVICE
>> +	TLS_HW_TX,
>> +#endif
>>   	TLS_NUM_CONFIG,
>>   };
>>   
>> @@ -416,11 +419,19 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
>>   		goto err_crypto_info;
>>   	}
>>   
>> -	/* currently SW is default, we will have ethtool in future */
>> -	rc = tls_set_sw_offload(sk, ctx);
>> -	tx_conf = TLS_SW_TX;
>> -	if (rc)
>> -		goto err_crypto_info;
>> +#ifdef CONFIG_TLS_DEVICE
>> +	rc = tls_set_device_offload(sk, ctx);
>> +	tx_conf = TLS_HW_TX;
>> +	if (rc) {
>> +#else
>> +	{
>> +#endif
>> +		/* if HW offload fails fallback to SW */
>> +		rc = tls_set_sw_offload(sk, ctx);
>> +		tx_conf = TLS_SW_TX;
>> +		if (rc)
>> +			goto err_crypto_info;
>> +	}
>>   
>>   	ctx->tx_conf = tx_conf;
>>   	update_sk_prot(sk, ctx);
>> @@ -473,6 +484,12 @@ static void build_protos(struct proto *prot, struct proto *base)
>>   	prot[TLS_SW_TX] = prot[TLS_BASE_TX];
>>   	prot[TLS_SW_TX].sendmsg		= tls_sw_sendmsg;
>>   	prot[TLS_SW_TX].sendpage	= tls_sw_sendpage;
>> +
>> +#ifdef CONFIG_TLS_DEVICE
>> +	prot[TLS_HW_TX] = prot[TLS_SW_TX];
>> +	prot[TLS_HW_TX].sendmsg		= tls_device_sendmsg;
>> +	prot[TLS_HW_TX].sendpage	= tls_device_sendpage;
>> +#endif
>>   }
>>   
>>   static int tls_init(struct sock *sk)
>> @@ -531,6 +548,9 @@ static int __init tls_register(void)
>>   {
>>   	build_protos(tls_prots[TLSV4], &tcp_prot);
>>   
>> +#ifdef CONFIG_TLS_DEVICE
>> +	tls_device_init();
>> +#endif
>>   	tcp_register_ulp(&tcp_tls_ulp_ops);
>>   
>>   	return 0;
>> @@ -539,6 +559,9 @@ static int __init tls_register(void)
>>   static void __exit tls_unregister(void)
>>   {
>>   	tcp_unregister_ulp(&tcp_tls_ulp_ops);
>> +#ifdef CONFIG_TLS_DEVICE
>> +	tls_device_cleanup();
>> +#endif
>>   }
>>   
>>   module_init(tls_register);
> 
> Thanks,
> Kirill
> 

Best,
Boris.
Kirill Tkhai March 21, 2018, 4:31 p.m. UTC | #5
On 21.03.2018 18:53, Boris Pismenny wrote:
> ...
>>
>> Other patches have two licenses in header. Can I distribute this file under GPL license terms?
>>
> 
> Sure, I'll update the license to match other files under net/tls.
> 
>>> +#include <linux/module.h>
>>> +#include <net/tcp.h>
>>> +#include <net/inet_common.h>
>>> +#include <linux/highmem.h>
>>> +#include <linux/netdevice.h>
>>> +
>>> +#include <net/tls.h>
>>> +#include <crypto/aead.h>
>>> +
>>> +/* device_offload_lock is used to synchronize tls_dev_add
>>> + * against NETDEV_DOWN notifications.
>>> + */
>>> +DEFINE_STATIC_PERCPU_RWSEM(device_offload_lock);
>>> +
>>> +static void tls_device_gc_task(struct work_struct *work);
>>> +
>>> +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
>>> +static LIST_HEAD(tls_device_gc_list);
>>> +static LIST_HEAD(tls_device_list);
>>> +static DEFINE_SPINLOCK(tls_device_lock);
>>> +
>>> +static void tls_device_free_ctx(struct tls_context *ctx)
>>> +{
>>> +    struct tls_offload_context *offlad_ctx = tls_offload_ctx(ctx);
>>> +
>>> +    kfree(offlad_ctx);
>>> +    kfree(ctx);
>>> +}
>>> +
>>> +static void tls_device_gc_task(struct work_struct *work)
>>> +{
>>> +    struct tls_context *ctx, *tmp;
>>> +    struct list_head gc_list;
>>> +    unsigned long flags;
>>> +
>>> +    spin_lock_irqsave(&tls_device_lock, flags);
>>> +    INIT_LIST_HEAD(&gc_list);
>>
>> This is stack variable, and it should be initialized outside of global spinlock.
>> There is LIST_HEAD() primitive for that in kernel.
>> There is one more similar place below.
>>
> 
> Sure.
> 
>>> +    list_splice_init(&tls_device_gc_list, &gc_list);
>>> +    spin_unlock_irqrestore(&tls_device_lock, flags);
>>> +
>>> +    list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
>>> +        struct net_device *netdev = ctx->netdev;
>>> +
>>> +        if (netdev) {
>>> +            netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
>>> +                            TLS_OFFLOAD_CTX_DIR_TX);
>>> +            dev_put(netdev);
>>> +        }
>>
>> How is possible the situation we meet NULL netdev here >
> 
> This can happen in tls_device_down. tls_deviec_down is called whenever a netdev that is used for TLS inline crypto offload goes down. It gets called via the NETDEV_DOWN event of the netdevice notifier.
> 
> This flow is somewhat similar to the xfrm_device netdev notifier. However, we do not destroy the socket (as in destroying the xfrm_state in xfrm_device). Instead, we cleanup the netdev state and allow software fallback to handle the rest of the traffic.
> 
>>> +
>>> +        list_del(&ctx->list);
>>> +        tls_device_free_ctx(ctx);
>>> +    }
>>> +}
>>> +
>>> +static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
>>> +{
>>> +    unsigned long flags;
>>> +
>>> +    spin_lock_irqsave(&tls_device_lock, flags);
>>> +    list_move_tail(&ctx->list, &tls_device_gc_list);
>>> +
>>> +    /* schedule_work inside the spinlock
>>> +     * to make sure tls_device_down waits for that work.
>>> +     */
>>> +    schedule_work(&tls_device_gc_work);
>>> +
>>> +    spin_unlock_irqrestore(&tls_device_lock, flags);
>>> +}
>>> +
>>> +/* We assume that the socket is already connected */
>>> +static struct net_device *get_netdev_for_sock(struct sock *sk)
>>> +{
>>> +    struct inet_sock *inet = inet_sk(sk);
>>> +    struct net_device *netdev = NULL;
>>> +
>>> +    netdev = dev_get_by_index(sock_net(sk), inet->cork.fl.flowi_oif);
>>> +
>>> +    return netdev;
>>> +}
>>> +
>>> +static int attach_sock_to_netdev(struct sock *sk, struct net_device *netdev,
>>> +                 struct tls_context *ctx)
>>> +{
>>> +    int rc;
>>> +
>>> +    rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
>>> +                         &ctx->crypto_send,
>>> +                         tcp_sk(sk)->write_seq);
>>> +    if (rc) {
>>> +        pr_err_ratelimited("The netdev has refused to offload this socket\n");
>>> +        goto out;
>>> +    }
>>> +
>>> +    rc = 0;
>>> +out:
>>> +    return rc;
>>> +}
>>> +
>>> +static void destroy_record(struct tls_record_info *record)
>>> +{
>>> +    skb_frag_t *frag;
>>> +    int nr_frags = record->num_frags;
>>> +
>>> +    while (nr_frags > 0) {
>>> +        frag = &record->frags[nr_frags - 1];
>>> +        __skb_frag_unref(frag);
>>> +        --nr_frags;
>>> +    }
>>> +    kfree(record);
>>> +}
>>> +
>>> +static void delete_all_records(struct tls_offload_context *offload_ctx)
>>> +{
>>> +    struct tls_record_info *info, *temp;
>>> +
>>> +    list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
>>> +        list_del(&info->list);
>>> +        destroy_record(info);
>>> +    }
>>> +
>>> +    offload_ctx->retransmit_hint = NULL;
>>> +}
>>> +
>>> +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
>>> +{
>>> +    struct tls_context *tls_ctx = tls_get_ctx(sk);
>>> +    struct tls_offload_context *ctx;
>>> +    struct tls_record_info *info, *temp;
>>> +    unsigned long flags;
>>> +    u64 deleted_records = 0;
>>> +
>>> +    if (!tls_ctx)
>>> +        return;
>>> +
>>> +    ctx = tls_offload_ctx(tls_ctx);
>>> +
>>> +    spin_lock_irqsave(&ctx->lock, flags);
>>> +    info = ctx->retransmit_hint;
>>> +    if (info && !before(acked_seq, info->end_seq)) {
>>> +        ctx->retransmit_hint = NULL;
>>> +        list_del(&info->list);
>>> +        destroy_record(info);
>>> +        deleted_records++;
>>> +    }
>>> +
>>> +    list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
>>> +        if (before(acked_seq, info->end_seq))
>>> +            break;
>>> +        list_del(&info->list);
>>> +
>>> +        destroy_record(info);
>>> +        deleted_records++;
>>> +    }
>>> +
>>> +    ctx->unacked_record_sn += deleted_records;
>>> +    spin_unlock_irqrestore(&ctx->lock, flags);
>>> +}
>>> +
>>> +/* At this point, there should be no references on this
>>> + * socket and no in-flight SKBs associated with this
>>> + * socket, so it is safe to free all the resources.
>>> + */
>>> +void tls_device_sk_destruct(struct sock *sk)
>>> +{
>>> +    struct tls_context *tls_ctx = tls_get_ctx(sk);
>>> +    struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
>>> +
>>> +    if (ctx->open_record)
>>> +        destroy_record(ctx->open_record);
>>> +
>>> +    delete_all_records(ctx);
>>> +    crypto_free_aead(ctx->aead_send);
>>> +    ctx->sk_destruct(sk);
>>> +
>>> +    if (refcount_dec_and_test(&tls_ctx->refcount))
>>> +        tls_device_queue_ctx_destruction(tls_ctx);
>>> +}
>>> +EXPORT_SYMBOL(tls_device_sk_destruct);
>>> +
>>> +static inline void tls_append_frag(struct tls_record_info *record,
>>> +                   struct page_frag *pfrag,
>>> +                   int size)
>>> +{
>>> +    skb_frag_t *frag;
>>> +
>>> +    frag = &record->frags[record->num_frags - 1];
>>> +    if (frag->page.p == pfrag->page &&
>>> +        frag->page_offset + frag->size == pfrag->offset) {
>>> +        frag->size += size;
>>> +    } else {
>>> +        ++frag;
>>> +        frag->page.p = pfrag->page;
>>> +        frag->page_offset = pfrag->offset;
>>> +        frag->size = size;
>>> +        ++record->num_frags;
>>> +        get_page(pfrag->page);
>>> +    }
>>> +
>>> +    pfrag->offset += size;
>>> +    record->len += size;
>>> +}
>>> +
>>> +static inline int tls_push_record(struct sock *sk,
>>> +                  struct tls_context *ctx,
>>> +                  struct tls_offload_context *offload_ctx,
>>> +                  struct tls_record_info *record,
>>> +                  struct page_frag *pfrag,
>>> +                  int flags,
>>> +                  unsigned char record_type)
>>> +{
>>> +    skb_frag_t *frag;
>>> +    struct tcp_sock *tp = tcp_sk(sk);
>>> +    struct page_frag fallback_frag;
>>> +    struct page_frag  *tag_pfrag = pfrag;
>>> +    int i;
>>> +
>>> +    /* fill prepand */
>>> +    frag = &record->frags[0];
>>> +    tls_fill_prepend(ctx,
>>> +             skb_frag_address(frag),
>>> +             record->len - ctx->prepend_size,
>>> +             record_type);
>>> +
>>> +    if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag, GFP_KERNEL))) {
>>> +        /* HW doesn't care about the data in the tag
>>> +         * so in case pfrag has no room
>>> +         * for a tag and we can't allocate a new pfrag
>>> +         * just use the page in the first frag
>>> +         * rather then write a complicated fall back code.
>>> +         */
>>> +        tag_pfrag = &fallback_frag;
>>> +        tag_pfrag->page = skb_frag_page(frag);
>>> +        tag_pfrag->offset = 0;
>>> +    }
>>> +
>>> +    tls_append_frag(record, tag_pfrag, ctx->tag_size);
>>> +    record->end_seq = tp->write_seq + record->len;
>>> +    spin_lock_irq(&offload_ctx->lock);
>>> +    list_add_tail(&record->list, &offload_ctx->records_list);
>>> +    spin_unlock_irq(&offload_ctx->lock);
>>> +    offload_ctx->open_record = NULL;
>>> +    set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
>>> +    tls_advance_record_sn(sk, ctx);
>>> +
>>> +    for (i = 0; i < record->num_frags; i++) {
>>> +        frag = &record->frags[i];
>>> +        sg_unmark_end(&offload_ctx->sg_tx_data[i]);
>>> +        sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
>>> +                frag->size, frag->page_offset);
>>> +        sk_mem_charge(sk, frag->size);
>>> +        get_page(skb_frag_page(frag));
>>> +    }
>>> +    sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
>>> +
>>> +    /* all ready, send */
>>> +    return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
>>> +}
>>> +
>>> +static inline int tls_create_new_record(struct tls_offload_context *offload_ctx,
>>> +                    struct page_frag *pfrag,
>>> +                    size_t prepend_size)
>>> +{
>>> +    skb_frag_t *frag;
>>> +    struct tls_record_info *record;
>>> +
>>> +    record = kmalloc(sizeof(*record), GFP_KERNEL);
>>> +    if (!record)
>>> +        return -ENOMEM;
>>> +
>>> +    frag = &record->frags[0];
>>> +    __skb_frag_set_page(frag, pfrag->page);
>>> +    frag->page_offset = pfrag->offset;
>>> +    skb_frag_size_set(frag, prepend_size);
>>> +
>>> +    get_page(pfrag->page);
>>> +    pfrag->offset += prepend_size;
>>> +
>>> +    record->num_frags = 1;
>>> +    record->len = prepend_size;
>>> +    offload_ctx->open_record = record;
>>> +    return 0;
>>> +}
>>> +
>>> +static inline int tls_do_allocation(struct sock *sk,
>>> +                    struct tls_offload_context *offload_ctx,
>>> +                    struct page_frag *pfrag,
>>> +                    size_t prepend_size)
>>> +{
>>> +    int ret;
>>> +
>>> +    if (!offload_ctx->open_record) {
>>> +        if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
>>> +                           sk->sk_allocation))) {
>>> +            sk->sk_prot->enter_memory_pressure(sk);
>>> +            sk_stream_moderate_sndbuf(sk);
>>> +            return -ENOMEM;
>>> +        }
>>> +
>>> +        ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
>>> +        if (ret)
>>> +            return ret;
>>> +
>>> +        if (pfrag->size > pfrag->offset)
>>> +            return 0;
>>> +    }
>>> +
>>> +    if (!sk_page_frag_refill(sk, pfrag))
>>> +        return -ENOMEM;
>>> +
>>> +    return 0;
>>> +}
>>> +
>>> +static int tls_push_data(struct sock *sk,
>>> +             struct iov_iter *msg_iter,
>>> +             size_t size, int flags,
>>> +             unsigned char record_type)
>>> +{
>>> +    struct tls_context *tls_ctx = tls_get_ctx(sk);
>>> +    struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
>>> +    struct tls_record_info *record = ctx->open_record;
>>> +    struct page_frag *pfrag;
>>> +    int copy, rc = 0;
>>> +    size_t orig_size = size;
>>> +    u32 max_open_record_len;
>>> +    long timeo;
>>> +    int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
>>> +    int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
>>> +    bool done = false;
>>> +
>>> +    if (flags &
>>> +        ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
>>> +        return -ENOTSUPP;
>>> +
>>> +    if (sk->sk_err)
>>> +        return -sk->sk_err;
>>> +
>>> +    timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
>>> +    rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
>>> +    if (rc < 0)
>>> +        return rc;
>>> +
>>> +    pfrag = sk_page_frag(sk);
>>> +
>>> +    /* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and
>>> +     * we need to leave room for an authentication tag.
>>> +     */
>>> +    max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
>>> +                  tls_ctx->prepend_size;
>>> +    do {
>>> +        if (tls_do_allocation(sk, ctx, pfrag,
>>> +                      tls_ctx->prepend_size)) {
>>> +            rc = sk_stream_wait_memory(sk, &timeo);
>>> +            if (!rc)
>>> +                continue;
>>> +
>>> +            record = ctx->open_record;
>>> +            if (!record)
>>> +                break;
>>> +handle_error:
>>> +            if (record_type != TLS_RECORD_TYPE_DATA) {
>>> +                /* avoid sending partial
>>> +                 * record with type !=
>>> +                 * application_data
>>> +                 */
>>> +                size = orig_size;
>>> +                destroy_record(record);
>>> +                ctx->open_record = NULL;
>>> +            } else if (record->len > tls_ctx->prepend_size) {
>>> +                goto last_record;
>>> +            }
>>> +
>>> +            break;
>>> +        }
>>> +
>>> +        record = ctx->open_record;
>>> +        copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
>>> +        copy = min_t(size_t, copy, (max_open_record_len - record->len));
>>> +
>>> +        if (copy_from_iter_nocache(page_address(pfrag->page) +
>>> +                           pfrag->offset,
>>> +                       copy, msg_iter) != copy) {
>>> +            rc = -EFAULT;
>>> +            goto handle_error;
>>> +        }
>>> +        tls_append_frag(record, pfrag, copy);
>>> +
>>> +        size -= copy;
>>> +        if (!size) {
>>> +last_record:
>>> +            tls_push_record_flags = flags;
>>> +            if (more) {
>>> +                tls_ctx->pending_open_record_frags =
>>> +                        record->num_frags;
>>> +                break;
>>> +            }
>>> +
>>> +            done = true;
>>> +        }
>>> +
>>> +        if ((done) || record->len >= max_open_record_len ||
>>> +            (record->num_frags >= MAX_SKB_FRAGS - 1)) {
>>> +            rc = tls_push_record(sk,
>>> +                         tls_ctx,
>>> +                         ctx,
>>> +                         record,
>>> +                         pfrag,
>>> +                         tls_push_record_flags,
>>> +                         record_type);
>>> +            if (rc < 0)
>>> +                break;
>>> +        }
>>> +    } while (!done);
>>> +
>>> +    if (orig_size - size > 0)
>>> +        rc = orig_size - size;
>>> +
>>> +    return rc;
>>> +}
>>> +
>>> +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
>>> +{
>>> +    unsigned char record_type = TLS_RECORD_TYPE_DATA;
>>> +    int rc = 0;
>>> +
>>> +    lock_sock(sk);
>>> +
>>> +    if (unlikely(msg->msg_controllen)) {
>>> +        rc = tls_proccess_cmsg(sk, msg, &record_type);
>>> +        if (rc)
>>> +            goto out;
>>> +    }
>>> +
>>> +    rc = tls_push_data(sk, &msg->msg_iter, size,
>>> +               msg->msg_flags, record_type);
>>> +
>>> +out:
>>> +    release_sock(sk);
>>> +    return rc;
>>> +}
>>> +
>>> +int tls_device_sendpage(struct sock *sk, struct page *page,
>>> +            int offset, size_t size, int flags)
>>> +{
>>> +    struct iov_iter    msg_iter;
>>> +    struct kvec iov;
>>> +    char *kaddr = kmap(page);
>>> +    int rc = 0;
>>> +
>>> +    if (flags & MSG_SENDPAGE_NOTLAST)
>>> +        flags |= MSG_MORE;
>>> +
>>> +    lock_sock(sk);
>>> +
>>> +    if (flags & MSG_OOB) {
>>> +        rc = -ENOTSUPP;
>>> +        goto out;
>>> +    }
>>> +
>>> +    iov.iov_base = kaddr + offset;
>>> +    iov.iov_len = size;
>>> +    iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
>>> +    rc = tls_push_data(sk, &msg_iter, size,
>>> +               flags, TLS_RECORD_TYPE_DATA);
>>> +    kunmap(page);
>>> +
>>> +out:
>>> +    release_sock(sk);
>>> +    return rc;
>>> +}
>>> +
>>> +struct tls_record_info *tls_get_record(struct tls_offload_context *context,
>>> +                       u32 seq, u64 *p_record_sn)
>>> +{
>>> +    struct tls_record_info *info;
>>> +    u64 record_sn = context->hint_record_sn;
>>> +
>>> +    info = context->retransmit_hint;
>>> +    if (!info ||
>>> +        before(seq, info->end_seq - info->len)) {
>>> +        /* if retransmit_hint is irrelevant start
>>> +         * from the begging of the list
>>> +         */
>>> +        info = list_first_entry(&context->records_list,
>>> +                    struct tls_record_info, list);
>>> +        record_sn = context->unacked_record_sn;
>>> +    }
>>> +
>>> +    list_for_each_entry_from(info, &context->records_list, list) {
>>> +        if (before(seq, info->end_seq)) {
>>> +            if (!context->retransmit_hint ||
>>> +                after(info->end_seq,
>>> +                  context->retransmit_hint->end_seq)) {
>>> +                context->hint_record_sn = record_sn;
>>> +                context->retransmit_hint = info;
>>> +            }
>>> +            *p_record_sn = record_sn;
>>> +            return info;
>>> +        }
>>> +        record_sn++;
>>> +    }
>>> +
>>> +    return NULL;
>>> +}
>>> +EXPORT_SYMBOL(tls_get_record);
>>> +
>>> +static int tls_device_push_pending_record(struct sock *sk, int flags)
>>> +{
>>> +    struct iov_iter    msg_iter;
>>> +
>>> +    iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
>>> +    return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
>>> +}
>>> +
>>> +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
>>> +{
>>> +    u16 nonece_size, tag_size, iv_size, rec_seq_size;
>>> +    struct tls_record_info *start_marker_record;
>>> +    struct tls_offload_context *offload_ctx;
>>> +    struct tls_crypto_info *crypto_info;
>>> +    struct net_device *netdev;
>>> +    char *iv, *rec_seq;
>>> +    struct sk_buff *skb;
>>> +    int rc = -EINVAL;
>>> +    __be64 rcd_sn;
>>> +
>>> +    if (!ctx)
>>> +        goto out;
>>> +
>>> +    if (ctx->priv_ctx) {
>>> +        rc = -EEXIST;
>>> +        goto out;
>>> +    }
>>> +
>>> +    /* We support starting offload on multiple sockets
>>> +     * concurrently, So we only need a read lock here.
>>> +     */
>>> +    percpu_down_read(&device_offload_lock);
>>> +    netdev = get_netdev_for_sock(sk);
>>> +    if (!netdev) {
>>> +        pr_err_ratelimited("%s: netdev not found\n", __func__);
>>> +        rc = -EINVAL;
>>> +        goto release_lock;
>>> +    }
>>> +
>>> +    if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
>>> +        rc = -ENOTSUPP;
>>> +        goto release_netdev;
>>> +    }
>>> +
>>> +    /* Avoid offloading if the device is down
>>> +     * We don't want to offload new flows after
>>> +     * the NETDEV_DOWN event
>>> +     */
>>> +    if (!(netdev->flags & IFF_UP)) {
>>> +        rc = -EINVAL;
>>> +        goto release_lock;
>>> +    }
>>> +
>>> +    crypto_info = &ctx->crypto_send;
>>> +    switch (crypto_info->cipher_type) {
>>> +    case TLS_CIPHER_AES_GCM_128: {
>>> +        nonece_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
>>> +        tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
>>> +        iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
>>> +        iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
>>> +        rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
>>> +        rec_seq =
>>> +         ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
>>> +        break;
>>> +    }
>>> +    default:
>>> +        rc = -EINVAL;
>>> +        goto release_netdev;
>>> +    }
>>> +
>>> +    start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
>>
>> Can we memory allocations and simple memory initializations ouside the global rwsem?
>>
> 
> Sure, we can move all memory allocations outside the lock.
> 
>>> +    if (!start_marker_record) {
>>> +        rc = -ENOMEM;
>>> +        goto release_netdev;
>>> +    }
>>> +
>>> +    offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
>>> +    if (!offload_ctx)
>>> +        goto free_marker_record;
>>> +
>>> +    ctx->priv_ctx = offload_ctx;
>>> +    rc = attach_sock_to_netdev(sk, netdev, ctx);
>>> +    if (rc)
>>> +        goto free_offload_context;
>>> +
>>> +    ctx->netdev = netdev;
>>> +    ctx->prepend_size = TLS_HEADER_SIZE + nonece_size;
>>> +    ctx->tag_size = tag_size;
>>> +    ctx->iv_size = iv_size;
>>> +    ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
>>> +              GFP_KERNEL);
>>> +    if (!ctx->iv) {
>>> +        rc = -ENOMEM;
>>> +        goto detach_sock;
>>> +    }
>>> +
>>> +    memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
>>> +
>>> +    ctx->rec_seq_size = rec_seq_size;
>>> +    ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
>>> +    if (!ctx->rec_seq) {
>>> +        rc = -ENOMEM;
>>> +        goto free_iv;
>>> +    }
>>> +    memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
>>> +
>>> +    /* start at rec_seq - 1 to account for the start marker record */
>>> +    memcpy(&rcd_sn, ctx->rec_seq, sizeof(rcd_sn));
>>> +    offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
>>> +
>>> +    rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
>>> +    if (rc)
>>> +        goto free_rec_seq;
>>> +
>>> +    start_marker_record->end_seq = tcp_sk(sk)->write_seq;
>>> +    start_marker_record->len = 0;
>>> +    start_marker_record->num_frags = 0;
>>> +
>>> +    INIT_LIST_HEAD(&offload_ctx->records_list);
>>> +    list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
>>> +    spin_lock_init(&offload_ctx->lock);
>>> +
>>> +    inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked;
>>> +    ctx->push_pending_record = tls_device_push_pending_record;
>>> +    offload_ctx->sk_destruct = sk->sk_destruct;
>>> +
>>> +    /* TLS offload is greatly simplified if we don't send
>>> +     * SKBs where only part of the payload needs to be encrypted.
>>> +     * So mark the last skb in the write queue as end of record.
>>> +     */
>>> +    skb = tcp_write_queue_tail(sk);
>>> +    if (skb)
>>> +        TCP_SKB_CB(skb)->eor = 1;
>>> +
>>> +    refcount_set(&ctx->refcount, 1);
>>> +    spin_lock_irq(&tls_device_lock);
>>> +    list_add_tail(&ctx->list, &tls_device_list);
>>> +    spin_unlock_irq(&tls_device_lock);
>>> +
>>> +    /* following this assignment tls_is_sk_tx_device_offloaded
>>> +     * will return true and the context might be accessed
>>> +     * by the netdev's xmit function.
>>> +     */
>>> +    smp_store_release(&sk->sk_destruct,
>>> +              &tls_device_sk_destruct);
>>> +    goto release_lock;
>>> +
>>> +free_rec_seq:
>>> +    kfree(ctx->rec_seq);
>>> +free_iv:
>>> +    kfree(ctx->iv);
>>> +detach_sock:
>>> +    netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
>>> +free_offload_context:
>>> +    kfree(offload_ctx);
>>> +    ctx->priv_ctx = NULL;
>>> +free_marker_record:
>>> +    kfree(start_marker_record);
>>> +release_netdev:
>>> +    dev_put(netdev);
>>> +release_lock:
>>> +    percpu_up_read(&device_offload_lock);
>>> +out:
>>> +    return rc;
>>> +}
>>> +
>>> +static int tls_device_register(struct net_device *dev)
>>> +{
>>> +    if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
>>> +        return NOTIFY_BAD;
>>> +
>>> +    return NOTIFY_DONE;
>>> +}
>>
>> This function is the same as tls_device_feat_change(). Can't we merge
>> them together and avoid duplicating of code?
>>
> 
> Sure.
> 
>>> +static int tls_device_unregister(struct net_device *dev)
>>> +{
>>> +    return NOTIFY_DONE;
>>> +}
>>
>> This function does nothing, and next patches do not change it.
>> Can't we remove it since so?
>>
> 
> Sure.
> 
>>> +static int tls_device_feat_change(struct net_device *dev)
>>> +{
>>> +    if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
>>> +        return NOTIFY_BAD;
>>> +
>>> +    return NOTIFY_DONE;
>>> +}
>>> +
>>> +static int tls_device_down(struct net_device *netdev)
>>> +{
>>> +    struct tls_context *ctx, *tmp;
>>> +    struct list_head list;
>>> +    unsigned long flags;
>>> +
>>> +    if (!(netdev->features & NETIF_F_HW_TLS_TX))
>>> +        return NOTIFY_DONE;
>>
>> Can't we move this check in tls_dev_event() and use it for all types of events?
>> Then we avoid duplicate code.
>>
> 
> No. Not all events require this check. Also, the result is different for different events.

No. You always return NOTIFY_DONE, in case of !(netdev->features & NETIF_F_HW_TLS_TX).
See below:

static int tls_check_dev_ops(struct net_device *dev) 
{
	if (!dev->tlsdev_ops)
		return NOTIFY_BAD; 

	return NOTIFY_DONE; 
}

static int tls_device_down(struct net_device *netdev) 
{
	struct tls_context *ctx, *tmp; 
	struct list_head list; 
	unsigned long flags; 

	...
	return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event, 
        		 void *ptr) 
{ 
        struct net_device *dev = netdev_notifier_info_to_dev(ptr); 

	if (!(netdev->features & NETIF_F_HW_TLS_TX)) 
		return NOTIFY_DONE; 
 
        switch (event) { 
        case NETDEV_REGISTER:
        case NETDEV_FEAT_CHANGE: 
        	return tls_check_dev_ops(dev); 
 
        case NETDEV_DOWN: 
        	return tls_device_down(dev); 
        } 
        return NOTIFY_DONE; 
} 
 
>>> +
>>> +    /* Request a write lock to block new offload attempts
>>> +     */
>>> +    percpu_down_write(&device_offload_lock);
>>
>> What is the reason percpu_rwsem is chosen here? It looks like this primitive
>> gives more advantages readers, then plain rwsem does. But it also gives
>> disadvantages to writers. It would be good, unless tls_device_down() is called
>> with rtnl_lock() held from netdevice notifier. But since netdevice notifier
>> are called with rtnl_lock() held, percpu_rwsem will increase the time rtnl_lock()
>> is locked.
> We use the a rwsem to allow multiple (readers) invocations of tls_set_device_offload, which is triggered by the user (persumably) during the TLS handshake. This might be considered a fast-path.
> 
> However, we must block all calls to tls_set_device_offload while we are processing NETDEV_DOWN events (writer).
> 
> As you've mentioned, the percpu rwsem is more efficient for readers, especially on NUMA systems, where cache-line bouncing occurs during reader acquire and reduces performance.

Hm, and who are the readers? It's used from do_tls_setsockopt_tx(), but it doesn't
seem to be performance critical. Who else?

>>
>> Can't we use plain rwsem here instead?
>>
> 
> Its a performance tradeoff. I'm not certain that the percpu rwsem write side acquire is significantly worse than using the global rwsem.
> 
> For now, while all of this is experimental, can we agree to focus on the performance of readers? We can change it later if it becomes a problem.

Same as above.
 
>>> +
>>> +    spin_lock_irqsave(&tls_device_lock, flags);
>>> +    INIT_LIST_HEAD(&list);
>>
>> This may go outside the global spinlock.
>>
> 
> Sure.
> 
>>> +    list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
>>> +        if (ctx->netdev != netdev ||
>>> +            !refcount_inc_not_zero(&ctx->refcount))
>>> +            continue;
>>> +
>>> +        list_move(&ctx->list, &list);
>>> +    }
>>> +    spin_unlock_irqrestore(&tls_device_lock, flags);
>>> +
>>> +    list_for_each_entry_safe(ctx, tmp, &list, list)    {
>>> +        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
>>> +                        TLS_OFFLOAD_CTX_DIR_TX);
>>> +        ctx->netdev = NULL;
>>> +        dev_put(netdev);
>>> +        list_del_init(&ctx->list);
>>> +
>>> +        if (refcount_dec_and_test(&ctx->refcount))
>>> +            tls_device_free_ctx(ctx);
>>> +    }
>>> +
>>> +    percpu_up_write(&device_offload_lock);
>>> +
>>> +    flush_work(&tls_device_gc_work);
>>> +
>>> +    return NOTIFY_DONE;
>>> +}
>>> +
>>> +static int tls_dev_event(struct notifier_block *this, unsigned long event,
>>> +             void *ptr)
>>> +{
>>> +    struct net_device *dev = netdev_notifier_info_to_dev(ptr);
>>> +
>>> +    switch (event) {
>>> +    case NETDEV_REGISTER:
>>> +        return tls_device_register(dev);
>>> +
>>> +    case NETDEV_UNREGISTER:
>>> +        return tls_device_unregister(dev);
>>> +
>>> +    case NETDEV_FEAT_CHANGE:
>>> +        return tls_device_feat_change(dev);
>>> +
>>> +    case NETDEV_DOWN:
>>> +        return tls_device_down(dev);
>>> +    }
>>> +    return NOTIFY_DONE;
>>> +}
>>> +
>>> +static struct notifier_block tls_dev_notifier = {
>>> +    .notifier_call    = tls_dev_event,
>>> +};
>>> +
>>> +void __init tls_device_init(void)
>>> +{
>>> +    register_netdevice_notifier(&tls_dev_notifier);
>>> +}
>>> +
>>> +void __exit tls_device_cleanup(void)
>>> +{
>>> +    unregister_netdevice_notifier(&tls_dev_notifier);
>>> +    flush_work(&tls_device_gc_work);
>>> +}
>>> diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
>>> new file mode 100644
>>> index 000000000000..14d31a36885c
>>> --- /dev/null
>>> +++ b/net/tls/tls_device_fallback.c
>>> @@ -0,0 +1,419 @@
>>> +/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
>>> + *
>>> + *     Redistribution and use in source and binary forms, with or
>>> + *     without modification, are permitted provided that the following
>>> + *     conditions are met:
>>> + *
>>> + *      - Redistributions of source code must retain the above
>>> + *        copyright notice, this list of conditions and the following
>>> + *        disclaimer.
>>> + *
>>> + *      - Redistributions in binary form must reproduce the above
>>> + *        copyright notice, this list of conditions and the following
>>> + *        disclaimer in the documentation and/or other materials
>>> + *        provided with the distribution.
>>> + *
>>> + *      - Neither the name of the Mellanox Technologies nor the
>>> + *        names of its contributors may be used to endorse or promote
>>> + *        products derived from this software without specific prior written
>>> + *        permission.
>>> + *
>>> + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
>>> + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
>>> + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
>>> + * A PARTICULAR PURPOSE ARE DISCLAIMED.
>>> + * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
>>> + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
>>> + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
>>> + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
>>> + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
>>> + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
>>> + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
>>> + * POSSIBILITY OF SUCH DAMAGE
>>> + */
>>> +
>>> +#include <net/tls.h>
>>> +#include <crypto/aead.h>
>>> +#include <crypto/scatterwalk.h>
>>> +#include <net/ip6_checksum.h>
>>> +
>>> +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
>>> +{
>>> +    struct scatterlist *src = walk->sg;
>>> +    int diff = walk->offset - src->offset;
>>> +
>>> +    sg_set_page(sg, sg_page(src),
>>> +            src->length - diff, walk->offset);
>>> +
>>> +    scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
>>> +}
>>> +
>>> +static int tls_enc_record(struct aead_request *aead_req,
>>> +              struct crypto_aead *aead, char *aad, char *iv,
>>> +              __be64 rcd_sn, struct scatter_walk *in,
>>> +              struct scatter_walk *out, int *in_len)
>>> +{
>>> +    struct scatterlist sg_in[3];
>>> +    struct scatterlist sg_out[3];
>>> +    unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
>>> +    u16 len;
>>> +    int rc;
>>> +
>>> +    len = min_t(int, *in_len, ARRAY_SIZE(buf));
>>> +
>>> +    scatterwalk_copychunks(buf, in, len, 0);
>>> +    scatterwalk_copychunks(buf, out, len, 1);
>>> +
>>> +    *in_len -= len;
>>> +    if (!*in_len)
>>> +        return 0;
>>> +
>>> +    scatterwalk_pagedone(in, 0, 1);
>>> +    scatterwalk_pagedone(out, 1, 1);
>>> +
>>> +    len = buf[4] | (buf[3] << 8);
>>> +    len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
>>> +
>>> +    tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
>>> +             (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
>>> +
>>> +    memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
>>> +           TLS_CIPHER_AES_GCM_128_IV_SIZE);
>>> +
>>> +    sg_init_table(sg_in, ARRAY_SIZE(sg_in));
>>> +    sg_init_table(sg_out, ARRAY_SIZE(sg_out));
>>> +    sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
>>> +    sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
>>> +    chain_to_walk(sg_in + 1, in);
>>> +    chain_to_walk(sg_out + 1, out);
>>> +
>>> +    *in_len -= len;
>>> +    if (*in_len < 0) {
>>> +        *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
>>> +        if (*in_len < 0)
>>> +        /* the input buffer doesn't contain the entire record.
>>> +         * trim len accordingly. The resulting authentication tag
>>> +         * will contain garbage. but we don't care as we won't
>>> +         * include any of it in the output skb
>>> +         * Note that we assume the output buffer length
>>> +         * is larger then input buffer length + tag size
>>> +         */
>>> +            len += *in_len;
>>> +
>>> +        *in_len = 0;
>>> +    }
>>> +
>>> +    if (*in_len) {
>>> +        scatterwalk_copychunks(NULL, in, len, 2);
>>> +        scatterwalk_pagedone(in, 0, 1);
>>> +        scatterwalk_copychunks(NULL, out, len, 2);
>>> +        scatterwalk_pagedone(out, 1, 1);
>>> +    }
>>> +
>>> +    len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
>>> +    aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
>>> +
>>> +    rc = crypto_aead_encrypt(aead_req);
>>> +
>>> +    return rc;
>>> +}
>>> +
>>> +static void tls_init_aead_request(struct aead_request *aead_req,
>>> +                  struct crypto_aead *aead)
>>> +{
>>> +    aead_request_set_tfm(aead_req, aead);
>>> +    aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
>>> +}
>>> +
>>> +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
>>> +                           gfp_t flags)
>>> +{
>>> +    unsigned int req_size = sizeof(struct aead_request) +
>>> +        crypto_aead_reqsize(aead);
>>> +    struct aead_request *aead_req;
>>> +
>>> +    aead_req = kzalloc(req_size, flags);
>>> +    if (!aead_req)
>>> +        return NULL;
>>> +
>>> +    tls_init_aead_request(aead_req, aead);
>>> +    return aead_req;
>>> +}
>>> +
>>> +static int tls_enc_records(struct aead_request *aead_req,
>>> +               struct crypto_aead *aead, struct scatterlist *sg_in,
>>> +               struct scatterlist *sg_out, char *aad, char *iv,
>>> +               u64 rcd_sn, int len)
>>> +{
>>> +    struct scatter_walk in;
>>> +    struct scatter_walk out;
>>> +    int rc;
>>> +
>>> +    scatterwalk_start(&in, sg_in);
>>> +    scatterwalk_start(&out, sg_out);
>>> +
>>> +    do {
>>> +        rc = tls_enc_record(aead_req, aead, aad, iv,
>>> +                    cpu_to_be64(rcd_sn), &in, &out, &len);
>>> +        rcd_sn++;
>>> +
>>> +    } while (rc == 0 && len);
>>> +
>>> +    scatterwalk_done(&in, 0, 0);
>>> +    scatterwalk_done(&out, 1, 0);
>>> +
>>> +    return rc;
>>> +}
>>> +
>>> +static inline void update_chksum(struct sk_buff *skb, int headln)
>>> +{
>>> +    /* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
>>> +     * might have been changed by NAT.
>>> +     */
>>> +
>>> +    const struct ipv6hdr *ipv6h;
>>> +    const struct iphdr *iph;
>>> +    struct tcphdr *th = tcp_hdr(skb);
>>> +    int datalen = skb->len - headln;
>>> +
>>> +    /* We only changed the payload so if we are using partial we don't
>>> +     * need to update anything.
>>> +     */
>>> +    if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
>>> +        return;
>>> +
>>> +    skb->ip_summed = CHECKSUM_PARTIAL;
>>> +    skb->csum_start = skb_transport_header(skb) - skb->head;
>>> +    skb->csum_offset = offsetof(struct tcphdr, check);
>>> +
>>> +    if (skb->sk->sk_family == AF_INET6) {
>>> +        ipv6h = ipv6_hdr(skb);
>>> +        th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
>>> +                         datalen, IPPROTO_TCP, 0);
>>> +    } else {
>>> +        iph = ip_hdr(skb);
>>> +        th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
>>> +                           IPPROTO_TCP, 0);
>>> +    }
>>> +}
>>> +
>>> +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
>>> +{
>>> +    skb_copy_header(nskb, skb);
>>> +
>>> +    skb_put(nskb, skb->len);
>>> +    memcpy(nskb->data, skb->data, headln);
>>> +    update_chksum(nskb, headln);
>>> +
>>> +    nskb->destructor = skb->destructor;
>>> +    nskb->sk = skb->sk;
>>> +    skb->destructor = NULL;
>>> +    skb->sk = NULL;
>>> +    refcount_add(nskb->truesize - skb->truesize,
>>> +             &nskb->sk->sk_wmem_alloc);
>>> +}
>>> +
>>> +/* This function may be called after the user socket is already
>>> + * closed so make sure we don't use anything freed during
>>> + * tls_sk_proto_close here
>>> + */
>>> +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
>>> +{
>>> +    int tcp_header_size = tcp_hdrlen(skb);
>>> +    int tcp_payload_offset = skb_transport_offset(skb) + tcp_header_size;
>>> +    int payload_len = skb->len - tcp_payload_offset;
>>> +    struct tls_context *tls_ctx = tls_get_ctx(sk);
>>> +    struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
>>> +    int remaining, buf_len, resync_sgs, rc, i = 0;
>>> +    void *buf, *dummy_buf, *iv, *aad;
>>> +    struct scatterlist *sg_in;
>>> +    struct scatterlist sg_out[3];
>>> +    u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
>>> +    struct aead_request *aead_req;
>>> +    struct sk_buff *nskb = NULL;
>>> +    struct tls_record_info *record;
>>> +    unsigned long flags;
>>> +    s32 sync_size;
>>> +    u64 rcd_sn;
>>> +
>>> +    /* worst case is:
>>> +     * MAX_SKB_FRAGS in tls_record_info
>>> +     * MAX_SKB_FRAGS + 1 in SKB head an frags.
>>> +     */
>>> +    int sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;
>>> +
>>> +    if (!payload_len)
>>> +        return skb;
>>> +
>>> +    sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
>>> +    if (!sg_in)
>>> +        goto free_orig;
>>> +
>>> +    sg_init_table(sg_in, sg_in_max_elements);
>>> +    sg_init_table(sg_out, ARRAY_SIZE(sg_out));
>>> +
>>> +    spin_lock_irqsave(&ctx->lock, flags);
>>> +    record = tls_get_record(ctx, tcp_seq, &rcd_sn);
>>> +    if (!record) {
>>> +        spin_unlock_irqrestore(&ctx->lock, flags);
>>> +        WARN(1, "Record not found for seq %u\n", tcp_seq);
>>> +        goto free_sg;
>>> +    }
>>> +
>>> +    sync_size = tcp_seq - tls_record_start_seq(record);
>>> +    if (sync_size < 0) {
>>> +        int is_start_marker = tls_record_is_start_marker(record);
>>> +
>>> +        spin_unlock_irqrestore(&ctx->lock, flags);
>>> +        if (!is_start_marker)
>>> +        /* This should only occur if the relevant record was
>>> +         * already acked. In that case it should be ok
>>> +         * to drop the packet and avoid retransmission.
>>> +         *
>>> +         * There is a corner case where the packet contains
>>> +         * both an acked and a non-acked record.
>>> +         * We currently don't handle that case and rely
>>> +         * on TCP to retranmit a packet that doesn't contain
>>> +         * already acked payload.
>>> +         */
>>> +            goto free_orig;
>>> +
>>> +        if (payload_len > -sync_size) {
>>> +            WARN(1, "Fallback of partially offloaded packets is not supported\n");
>>> +            goto free_sg;
>>> +        } else {
>>> +            return skb;
>>> +        }
>>> +    }
>>> +
>>> +    remaining = sync_size;
>>> +    while (remaining > 0) {
>>> +        skb_frag_t *frag = &record->frags[i];
>>> +
>>> +        __skb_frag_ref(frag);
>>> +        sg_set_page(sg_in + i, skb_frag_page(frag),
>>> +                skb_frag_size(frag), frag->page_offset);
>>> +
>>> +        remaining -= skb_frag_size(frag);
>>> +
>>> +        if (remaining < 0)
>>> +            sg_in[i].length += remaining;
>>> +
>>> +        i++;
>>> +    }
>>> +    spin_unlock_irqrestore(&ctx->lock, flags);
>>> +    resync_sgs = i;
>>> +
>>> +    aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
>>> +    if (!aead_req)
>>> +        goto put_sg;
>>> +
>>> +    buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
>>> +          TLS_CIPHER_AES_GCM_128_IV_SIZE +
>>> +          TLS_AAD_SPACE_SIZE +
>>> +          sync_size +
>>> +          tls_ctx->tag_size;
>>> +    buf = kmalloc(buf_len, GFP_ATOMIC);
>>> +    if (!buf)
>>> +        goto free_req;
>>> +
>>> +    nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
>>> +    if (!nskb)
>>> +        goto free_buf;
>>> +
>>> +    skb_reserve(nskb, skb_headroom(skb));
>>> +
>>> +    iv = buf;
>>> +
>>> +    memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
>>> +           TLS_CIPHER_AES_GCM_128_SALT_SIZE);
>>> +    aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
>>> +          TLS_CIPHER_AES_GCM_128_IV_SIZE;
>>> +    dummy_buf = aad + TLS_AAD_SPACE_SIZE;
>>> +
>>> +    sg_set_buf(&sg_out[0], dummy_buf, sync_size);
>>> +    sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset,
>>> +           payload_len);
>>> +    /* Add room for authentication tag produced by crypto */
>>> +    dummy_buf += sync_size;
>>> +    sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size);
>>> +    rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset,
>>> +              payload_len);
>>> +    if (rc < 0)
>>> +        goto free_nskb;
>>> +
>>> +    rc = tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
>>> +                 rcd_sn, sync_size + payload_len);
>>> +    if (rc < 0)
>>> +        goto free_nskb;
>>> +
>>> +    complete_skb(nskb, skb, tcp_payload_offset);
>>> +
>>> +    /* validate_xmit_skb_list assumes that if the skb wasn't segmented
>>> +     * nskb->prev will point to the skb itself
>>> +     */
>>> +    nskb->prev = nskb;
>>> +free_buf:
>>> +    kfree(buf);
>>> +free_req:
>>> +    kfree(aead_req);
>>> +put_sg:
>>> +    for (i = 0; i < resync_sgs; i++)
>>> +        put_page(sg_page(&sg_in[i]));
>>> +free_sg:
>>> +    kfree(sg_in);
>>> +free_orig:
>>> +    kfree_skb(skb);
>>> +    return nskb;
>>> +
>>> +free_nskb:
>>> +    kfree_skb(nskb);
>>> +    nskb = NULL;
>>> +    goto free_buf;
>>> +}
>>> +
>>> +static struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
>>> +                         struct net_device *dev,
>>> +                         struct sk_buff *skb)
>>> +{
>>> +    if (dev == tls_get_ctx(sk)->netdev)
>>> +        return skb;
>>> +
>>> +    return tls_sw_fallback(sk, skb);
>>> +}
>>> +
>>> +int tls_sw_fallback_init(struct sock *sk,
>>> +             struct tls_offload_context *offload_ctx,
>>> +             struct tls_crypto_info *crypto_info)
>>> +{
>>> +    int rc;
>>> +    const u8 *key;
>>> +
>>> +    offload_ctx->aead_send =
>>> +        crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
>>> +    if (IS_ERR(offload_ctx->aead_send)) {
>>> +        rc = PTR_ERR(offload_ctx->aead_send);
>>> +        pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
>>> +        offload_ctx->aead_send = NULL;
>>> +        goto err_out;
>>> +    }
>>> +
>>> +    key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
>>> +
>>> +    rc = crypto_aead_setkey(offload_ctx->aead_send, key,
>>> +                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
>>> +    if (rc)
>>> +        goto free_aead;
>>> +
>>> +    rc = crypto_aead_setauthsize(offload_ctx->aead_send,
>>> +                     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
>>> +    if (rc)
>>> +        goto free_aead;
>>> +
>>> +    sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
>>> +    return 0;
>>> +free_aead:
>>> +    crypto_free_aead(offload_ctx->aead_send);
>>> +err_out:
>>> +    return rc;
>>> +}
>>> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
>>> index d824d548447e..e0dface33017 100644
>>> --- a/net/tls/tls_main.c
>>> +++ b/net/tls/tls_main.c
>>> @@ -54,6 +54,9 @@ enum {
>>>   enum {
>>>       TLS_BASE_TX,
>>>       TLS_SW_TX,
>>> +#ifdef CONFIG_TLS_DEVICE
>>> +    TLS_HW_TX,
>>> +#endif
>>>       TLS_NUM_CONFIG,
>>>   };
>>>   @@ -416,11 +419,19 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
>>>           goto err_crypto_info;
>>>       }
>>>   -    /* currently SW is default, we will have ethtool in future */
>>> -    rc = tls_set_sw_offload(sk, ctx);
>>> -    tx_conf = TLS_SW_TX;
>>> -    if (rc)
>>> -        goto err_crypto_info;
>>> +#ifdef CONFIG_TLS_DEVICE
>>> +    rc = tls_set_device_offload(sk, ctx);
>>> +    tx_conf = TLS_HW_TX;
>>> +    if (rc) {
>>> +#else
>>> +    {
>>> +#endif
>>> +        /* if HW offload fails fallback to SW */
>>> +        rc = tls_set_sw_offload(sk, ctx);
>>> +        tx_conf = TLS_SW_TX;
>>> +        if (rc)
>>> +            goto err_crypto_info;
>>> +    }
>>>         ctx->tx_conf = tx_conf;
>>>       update_sk_prot(sk, ctx);
>>> @@ -473,6 +484,12 @@ static void build_protos(struct proto *prot, struct proto *base)
>>>       prot[TLS_SW_TX] = prot[TLS_BASE_TX];
>>>       prot[TLS_SW_TX].sendmsg        = tls_sw_sendmsg;
>>>       prot[TLS_SW_TX].sendpage    = tls_sw_sendpage;
>>> +
>>> +#ifdef CONFIG_TLS_DEVICE
>>> +    prot[TLS_HW_TX] = prot[TLS_SW_TX];
>>> +    prot[TLS_HW_TX].sendmsg        = tls_device_sendmsg;
>>> +    prot[TLS_HW_TX].sendpage    = tls_device_sendpage;
>>> +#endif
>>>   }
>>>     static int tls_init(struct sock *sk)
>>> @@ -531,6 +548,9 @@ static int __init tls_register(void)
>>>   {
>>>       build_protos(tls_prots[TLSV4], &tcp_prot);
>>>   +#ifdef CONFIG_TLS_DEVICE
>>> +    tls_device_init();
>>> +#endif
>>>       tcp_register_ulp(&tcp_tls_ulp_ops);
>>>         return 0;
>>> @@ -539,6 +559,9 @@ static int __init tls_register(void)
>>>   static void __exit tls_unregister(void)
>>>   {
>>>       tcp_unregister_ulp(&tcp_tls_ulp_ops);
>>> +#ifdef CONFIG_TLS_DEVICE
>>> +    tls_device_cleanup();
>>> +#endif
>>>   }
>>>     module_init(tls_register);

Thanks,
Kirill
Saeed Mahameed March 21, 2018, 8:50 p.m. UTC | #6
On Wed, 2018-03-21 at 19:31 +0300, Kirill Tkhai wrote:
> On 21.03.2018 18:53, Boris Pismenny wrote:
> > ...
> > > 
> > > Other patches have two licenses in header. Can I distribute this
> > > file under GPL license terms?
> > > 
> > 
> > Sure, I'll update the license to match other files under net/tls.
> > 
> > > > +#include <linux/module.h>
> > > > +#include <net/tcp.h>
> > > > +#include <net/inet_common.h>
> > > > +#include <linux/highmem.h>
> > > > +#include <linux/netdevice.h>
> > > > +
> > > > +#include <net/tls.h>
> > > > +#include <crypto/aead.h>
> > > > +
> > > > +/* device_offload_lock is used to synchronize tls_dev_add
> > > > + * against NETDEV_DOWN notifications.
> > > > + */
> > > > +DEFINE_STATIC_PERCPU_RWSEM(device_offload_lock);
> > > > +
> > > > +static void tls_device_gc_task(struct work_struct *work);
> > > > +
> > > > +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
> > > > +static LIST_HEAD(tls_device_gc_list);
> > > > +static LIST_HEAD(tls_device_list);
> > > > +static DEFINE_SPINLOCK(tls_device_lock);
> > > > +
> > > > +static void tls_device_free_ctx(struct tls_context *ctx)
> > > > +{
> > > > +    struct tls_offload_context *offlad_ctx =
> > > > tls_offload_ctx(ctx);
> > > > +
> > > > +    kfree(offlad_ctx);
> > > > +    kfree(ctx);
> > > > +}
> > > > +
> > > > +static void tls_device_gc_task(struct work_struct *work)
> > > > +{
> > > > +    struct tls_context *ctx, *tmp;
> > > > +    struct list_head gc_list;
> > > > +    unsigned long flags;
> > > > +
> > > > +    spin_lock_irqsave(&tls_device_lock, flags);
> > > > +    INIT_LIST_HEAD(&gc_list);
> > > 
> > > This is stack variable, and it should be initialized outside of
> > > global spinlock.
> > > There is LIST_HEAD() primitive for that in kernel.
> > > There is one more similar place below.
> > > 
> > 
> > Sure.
> > 
> > > > +    list_splice_init(&tls_device_gc_list, &gc_list);
> > > > +    spin_unlock_irqrestore(&tls_device_lock, flags);
> > > > +
> > > > +    list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
> > > > +        struct net_device *netdev = ctx->netdev;
> > > > +
> > > > +        if (netdev) {
> > > > +            netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
> > > > +                            TLS_OFFLOAD_CTX_DIR_TX);
> > > > +            dev_put(netdev);
> > > > +        }
> > > 
> > > How is possible the situation we meet NULL netdev here >
> > 
> > This can happen in tls_device_down. tls_deviec_down is called
> > whenever a netdev that is used for TLS inline crypto offload goes
> > down. It gets called via the NETDEV_DOWN event of the netdevice
> > notifier.
> > 
> > This flow is somewhat similar to the xfrm_device netdev notifier.
> > However, we do not destroy the socket (as in destroying the
> > xfrm_state in xfrm_device). Instead, we cleanup the netdev state
> > and allow software fallback to handle the rest of the traffic.
> > 
> > > > +
> > > > +        list_del(&ctx->list);
> > > > +        tls_device_free_ctx(ctx);
> > > > +    }
> > > > +}
> > > > +
> > > > +static void tls_device_queue_ctx_destruction(struct
> > > > tls_context *ctx)
> > > > +{
> > > > +    unsigned long flags;
> > > > +
> > > > +    spin_lock_irqsave(&tls_device_lock, flags);
> > > > +    list_move_tail(&ctx->list, &tls_device_gc_list);
> > > > +
> > > > +    /* schedule_work inside the spinlock
> > > > +     * to make sure tls_device_down waits for that work.
> > > > +     */
> > > > +    schedule_work(&tls_device_gc_work);
> > > > +
> > > > +    spin_unlock_irqrestore(&tls_device_lock, flags);
> > > > +}
> > > > +
> > > > +/* We assume that the socket is already connected */
> > > > +static struct net_device *get_netdev_for_sock(struct sock *sk)
> > > > +{
> > > > +    struct inet_sock *inet = inet_sk(sk);
> > > > +    struct net_device *netdev = NULL;
> > > > +
> > > > +    netdev = dev_get_by_index(sock_net(sk), inet-
> > > > >cork.fl.flowi_oif);
> > > > +
> > > > +    return netdev;
> > > > +}
> > > > +
> > > > +static int attach_sock_to_netdev(struct sock *sk, struct
> > > > net_device *netdev,
> > > > +                 struct tls_context *ctx)
> > > > +{
> > > > +    int rc;
> > > > +
> > > > +    rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk,
> > > > TLS_OFFLOAD_CTX_DIR_TX,
> > > > +                         &ctx->crypto_send,
> > > > +                         tcp_sk(sk)->write_seq);
> > > > +    if (rc) {
> > > > +        pr_err_ratelimited("The netdev has refused to offload
> > > > this socket\n");
> > > > +        goto out;
> > > > +    }
> > > > +
> > > > +    rc = 0;
> > > > +out:
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +static void destroy_record(struct tls_record_info *record)
> > > > +{
> > > > +    skb_frag_t *frag;
> > > > +    int nr_frags = record->num_frags;
> > > > +
> > > > +    while (nr_frags > 0) {
> > > > +        frag = &record->frags[nr_frags - 1];
> > > > +        __skb_frag_unref(frag);
> > > > +        --nr_frags;
> > > > +    }
> > > > +    kfree(record);
> > > > +}
> > > > +
> > > > +static void delete_all_records(struct tls_offload_context
> > > > *offload_ctx)
> > > > +{
> > > > +    struct tls_record_info *info, *temp;
> > > > +
> > > > +    list_for_each_entry_safe(info, temp, &offload_ctx-
> > > > >records_list, list) {
> > > > +        list_del(&info->list);
> > > > +        destroy_record(info);
> > > > +    }
> > > > +
> > > > +    offload_ctx->retransmit_hint = NULL;
> > > > +}
> > > > +
> > > > +static void tls_icsk_clean_acked(struct sock *sk, u32
> > > > acked_seq)
> > > > +{
> > > > +    struct tls_context *tls_ctx = tls_get_ctx(sk);
> > > > +    struct tls_offload_context *ctx;
> > > > +    struct tls_record_info *info, *temp;
> > > > +    unsigned long flags;
> > > > +    u64 deleted_records = 0;
> > > > +
> > > > +    if (!tls_ctx)
> > > > +        return;
> > > > +
> > > > +    ctx = tls_offload_ctx(tls_ctx);
> > > > +
> > > > +    spin_lock_irqsave(&ctx->lock, flags);
> > > > +    info = ctx->retransmit_hint;
> > > > +    if (info && !before(acked_seq, info->end_seq)) {
> > > > +        ctx->retransmit_hint = NULL;
> > > > +        list_del(&info->list);
> > > > +        destroy_record(info);
> > > > +        deleted_records++;
> > > > +    }
> > > > +
> > > > +    list_for_each_entry_safe(info, temp, &ctx->records_list,
> > > > list) {
> > > > +        if (before(acked_seq, info->end_seq))
> > > > +            break;
> > > > +        list_del(&info->list);
> > > > +
> > > > +        destroy_record(info);
> > > > +        deleted_records++;
> > > > +    }
> > > > +
> > > > +    ctx->unacked_record_sn += deleted_records;
> > > > +    spin_unlock_irqrestore(&ctx->lock, flags);
> > > > +}
> > > > +
> > > > +/* At this point, there should be no references on this
> > > > + * socket and no in-flight SKBs associated with this
> > > > + * socket, so it is safe to free all the resources.
> > > > + */
> > > > +void tls_device_sk_destruct(struct sock *sk)
> > > > +{
> > > > +    struct tls_context *tls_ctx = tls_get_ctx(sk);
> > > > +    struct tls_offload_context *ctx =
> > > > tls_offload_ctx(tls_ctx);
> > > > +
> > > > +    if (ctx->open_record)
> > > > +        destroy_record(ctx->open_record);
> > > > +
> > > > +    delete_all_records(ctx);
> > > > +    crypto_free_aead(ctx->aead_send);
> > > > +    ctx->sk_destruct(sk);
> > > > +
> > > > +    if (refcount_dec_and_test(&tls_ctx->refcount))
> > > > +        tls_device_queue_ctx_destruction(tls_ctx);
> > > > +}
> > > > +EXPORT_SYMBOL(tls_device_sk_destruct);
> > > > +
> > > > +static inline void tls_append_frag(struct tls_record_info
> > > > *record,
> > > > +                   struct page_frag *pfrag,
> > > > +                   int size)
> > > > +{
> > > > +    skb_frag_t *frag;
> > > > +
> > > > +    frag = &record->frags[record->num_frags - 1];
> > > > +    if (frag->page.p == pfrag->page &&
> > > > +        frag->page_offset + frag->size == pfrag->offset) {
> > > > +        frag->size += size;
> > > > +    } else {
> > > > +        ++frag;
> > > > +        frag->page.p = pfrag->page;
> > > > +        frag->page_offset = pfrag->offset;
> > > > +        frag->size = size;
> > > > +        ++record->num_frags;
> > > > +        get_page(pfrag->page);
> > > > +    }
> > > > +
> > > > +    pfrag->offset += size;
> > > > +    record->len += size;
> > > > +}
> > > > +
> > > > +static inline int tls_push_record(struct sock *sk,
> > > > +                  struct tls_context *ctx,
> > > > +                  struct tls_offload_context *offload_ctx,
> > > > +                  struct tls_record_info *record,
> > > > +                  struct page_frag *pfrag,
> > > > +                  int flags,
> > > > +                  unsigned char record_type)
> > > > +{
> > > > +    skb_frag_t *frag;
> > > > +    struct tcp_sock *tp = tcp_sk(sk);
> > > > +    struct page_frag fallback_frag;
> > > > +    struct page_frag  *tag_pfrag = pfrag;
> > > > +    int i;
> > > > +
> > > > +    /* fill prepand */
> > > > +    frag = &record->frags[0];
> > > > +    tls_fill_prepend(ctx,
> > > > +             skb_frag_address(frag),
> > > > +             record->len - ctx->prepend_size,
> > > > +             record_type);
> > > > +
> > > > +    if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag,
> > > > GFP_KERNEL))) {
> > > > +        /* HW doesn't care about the data in the tag
> > > > +         * so in case pfrag has no room
> > > > +         * for a tag and we can't allocate a new pfrag
> > > > +         * just use the page in the first frag
> > > > +         * rather then write a complicated fall back code.
> > > > +         */
> > > > +        tag_pfrag = &fallback_frag;
> > > > +        tag_pfrag->page = skb_frag_page(frag);
> > > > +        tag_pfrag->offset = 0;
> > > > +    }
> > > > +
> > > > +    tls_append_frag(record, tag_pfrag, ctx->tag_size);
> > > > +    record->end_seq = tp->write_seq + record->len;
> > > > +    spin_lock_irq(&offload_ctx->lock);
> > > > +    list_add_tail(&record->list, &offload_ctx->records_list);
> > > > +    spin_unlock_irq(&offload_ctx->lock);
> > > > +    offload_ctx->open_record = NULL;
> > > > +    set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
> > > > +    tls_advance_record_sn(sk, ctx);
> > > > +
> > > > +    for (i = 0; i < record->num_frags; i++) {
> > > > +        frag = &record->frags[i];
> > > > +        sg_unmark_end(&offload_ctx->sg_tx_data[i]);
> > > > +        sg_set_page(&offload_ctx->sg_tx_data[i],
> > > > skb_frag_page(frag),
> > > > +                frag->size, frag->page_offset);
> > > > +        sk_mem_charge(sk, frag->size);
> > > > +        get_page(skb_frag_page(frag));
> > > > +    }
> > > > +    sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags -
> > > > 1]);
> > > > +
> > > > +    /* all ready, send */
> > > > +    return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0,
> > > > flags);
> > > > +}
> > > > +
> > > > +static inline int tls_create_new_record(struct
> > > > tls_offload_context *offload_ctx,
> > > > +                    struct page_frag *pfrag,
> > > > +                    size_t prepend_size)
> > > > +{
> > > > +    skb_frag_t *frag;
> > > > +    struct tls_record_info *record;
> > > > +
> > > > +    record = kmalloc(sizeof(*record), GFP_KERNEL);
> > > > +    if (!record)
> > > > +        return -ENOMEM;
> > > > +
> > > > +    frag = &record->frags[0];
> > > > +    __skb_frag_set_page(frag, pfrag->page);
> > > > +    frag->page_offset = pfrag->offset;
> > > > +    skb_frag_size_set(frag, prepend_size);
> > > > +
> > > > +    get_page(pfrag->page);
> > > > +    pfrag->offset += prepend_size;
> > > > +
> > > > +    record->num_frags = 1;
> > > > +    record->len = prepend_size;
> > > > +    offload_ctx->open_record = record;
> > > > +    return 0;
> > > > +}
> > > > +
> > > > +static inline int tls_do_allocation(struct sock *sk,
> > > > +                    struct tls_offload_context *offload_ctx,
> > > > +                    struct page_frag *pfrag,
> > > > +                    size_t prepend_size)
> > > > +{
> > > > +    int ret;
> > > > +
> > > > +    if (!offload_ctx->open_record) {
> > > > +        if (unlikely(!skb_page_frag_refill(prepend_size,
> > > > pfrag,
> > > > +                           sk->sk_allocation))) {
> > > > +            sk->sk_prot->enter_memory_pressure(sk);
> > > > +            sk_stream_moderate_sndbuf(sk);
> > > > +            return -ENOMEM;
> > > > +        }
> > > > +
> > > > +        ret = tls_create_new_record(offload_ctx, pfrag,
> > > > prepend_size);
> > > > +        if (ret)
> > > > +            return ret;
> > > > +
> > > > +        if (pfrag->size > pfrag->offset)
> > > > +            return 0;
> > > > +    }
> > > > +
> > > > +    if (!sk_page_frag_refill(sk, pfrag))
> > > > +        return -ENOMEM;
> > > > +
> > > > +    return 0;
> > > > +}
> > > > +
> > > > +static int tls_push_data(struct sock *sk,
> > > > +             struct iov_iter *msg_iter,
> > > > +             size_t size, int flags,
> > > > +             unsigned char record_type)
> > > > +{
> > > > +    struct tls_context *tls_ctx = tls_get_ctx(sk);
> > > > +    struct tls_offload_context *ctx =
> > > > tls_offload_ctx(tls_ctx);
> > > > +    struct tls_record_info *record = ctx->open_record;
> > > > +    struct page_frag *pfrag;
> > > > +    int copy, rc = 0;
> > > > +    size_t orig_size = size;
> > > > +    u32 max_open_record_len;
> > > > +    long timeo;
> > > > +    int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
> > > > +    int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
> > > > +    bool done = false;
> > > > +
> > > > +    if (flags &
> > > > +        ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
> > > > MSG_SENDPAGE_NOTLAST))
> > > > +        return -ENOTSUPP;
> > > > +
> > > > +    if (sk->sk_err)
> > > > +        return -sk->sk_err;
> > > > +
> > > > +    timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
> > > > +    rc = tls_complete_pending_work(sk, tls_ctx, flags,
> > > > &timeo);
> > > > +    if (rc < 0)
> > > > +        return rc;
> > > > +
> > > > +    pfrag = sk_page_frag(sk);
> > > > +
> > > > +    /* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS
> > > > record, and
> > > > +     * we need to leave room for an authentication tag.
> > > > +     */
> > > > +    max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
> > > > +                  tls_ctx->prepend_size;
> > > > +    do {
> > > > +        if (tls_do_allocation(sk, ctx, pfrag,
> > > > +                      tls_ctx->prepend_size)) {
> > > > +            rc = sk_stream_wait_memory(sk, &timeo);
> > > > +            if (!rc)
> > > > +                continue;
> > > > +
> > > > +            record = ctx->open_record;
> > > > +            if (!record)
> > > > +                break;
> > > > +handle_error:
> > > > +            if (record_type != TLS_RECORD_TYPE_DATA) {
> > > > +                /* avoid sending partial
> > > > +                 * record with type !=
> > > > +                 * application_data
> > > > +                 */
> > > > +                size = orig_size;
> > > > +                destroy_record(record);
> > > > +                ctx->open_record = NULL;
> > > > +            } else if (record->len > tls_ctx->prepend_size) {
> > > > +                goto last_record;
> > > > +            }
> > > > +
> > > > +            break;
> > > > +        }
> > > > +
> > > > +        record = ctx->open_record;
> > > > +        copy = min_t(size_t, size, (pfrag->size - pfrag-
> > > > >offset));
> > > > +        copy = min_t(size_t, copy, (max_open_record_len -
> > > > record->len));
> > > > +
> > > > +        if (copy_from_iter_nocache(page_address(pfrag->page) +
> > > > +                           pfrag->offset,
> > > > +                       copy, msg_iter) != copy) {
> > > > +            rc = -EFAULT;
> > > > +            goto handle_error;
> > > > +        }
> > > > +        tls_append_frag(record, pfrag, copy);
> > > > +
> > > > +        size -= copy;
> > > > +        if (!size) {
> > > > +last_record:
> > > > +            tls_push_record_flags = flags;
> > > > +            if (more) {
> > > > +                tls_ctx->pending_open_record_frags =
> > > > +                        record->num_frags;
> > > > +                break;
> > > > +            }
> > > > +
> > > > +            done = true;
> > > > +        }
> > > > +
> > > > +        if ((done) || record->len >= max_open_record_len ||
> > > > +            (record->num_frags >= MAX_SKB_FRAGS - 1)) {
> > > > +            rc = tls_push_record(sk,
> > > > +                         tls_ctx,
> > > > +                         ctx,
> > > > +                         record,
> > > > +                         pfrag,
> > > > +                         tls_push_record_flags,
> > > > +                         record_type);
> > > > +            if (rc < 0)
> > > > +                break;
> > > > +        }
> > > > +    } while (!done);
> > > > +
> > > > +    if (orig_size - size > 0)
> > > > +        rc = orig_size - size;
> > > > +
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg,
> > > > size_t size)
> > > > +{
> > > > +    unsigned char record_type = TLS_RECORD_TYPE_DATA;
> > > > +    int rc = 0;
> > > > +
> > > > +    lock_sock(sk);
> > > > +
> > > > +    if (unlikely(msg->msg_controllen)) {
> > > > +        rc = tls_proccess_cmsg(sk, msg, &record_type);
> > > > +        if (rc)
> > > > +            goto out;
> > > > +    }
> > > > +
> > > > +    rc = tls_push_data(sk, &msg->msg_iter, size,
> > > > +               msg->msg_flags, record_type);
> > > > +
> > > > +out:
> > > > +    release_sock(sk);
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +int tls_device_sendpage(struct sock *sk, struct page *page,
> > > > +            int offset, size_t size, int flags)
> > > > +{
> > > > +    struct iov_iter    msg_iter;
> > > > +    struct kvec iov;
> > > > +    char *kaddr = kmap(page);
> > > > +    int rc = 0;
> > > > +
> > > > +    if (flags & MSG_SENDPAGE_NOTLAST)
> > > > +        flags |= MSG_MORE;
> > > > +
> > > > +    lock_sock(sk);
> > > > +
> > > > +    if (flags & MSG_OOB) {
> > > > +        rc = -ENOTSUPP;
> > > > +        goto out;
> > > > +    }
> > > > +
> > > > +    iov.iov_base = kaddr + offset;
> > > > +    iov.iov_len = size;
> > > > +    iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1,
> > > > size);
> > > > +    rc = tls_push_data(sk, &msg_iter, size,
> > > > +               flags, TLS_RECORD_TYPE_DATA);
> > > > +    kunmap(page);
> > > > +
> > > > +out:
> > > > +    release_sock(sk);
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +struct tls_record_info *tls_get_record(struct
> > > > tls_offload_context *context,
> > > > +                       u32 seq, u64 *p_record_sn)
> > > > +{
> > > > +    struct tls_record_info *info;
> > > > +    u64 record_sn = context->hint_record_sn;
> > > > +
> > > > +    info = context->retransmit_hint;
> > > > +    if (!info ||
> > > > +        before(seq, info->end_seq - info->len)) {
> > > > +        /* if retransmit_hint is irrelevant start
> > > > +         * from the begging of the list
> > > > +         */
> > > > +        info = list_first_entry(&context->records_list,
> > > > +                    struct tls_record_info, list);
> > > > +        record_sn = context->unacked_record_sn;
> > > > +    }
> > > > +
> > > > +    list_for_each_entry_from(info, &context->records_list,
> > > > list) {
> > > > +        if (before(seq, info->end_seq)) {
> > > > +            if (!context->retransmit_hint ||
> > > > +                after(info->end_seq,
> > > > +                  context->retransmit_hint->end_seq)) {
> > > > +                context->hint_record_sn = record_sn;
> > > > +                context->retransmit_hint = info;
> > > > +            }
> > > > +            *p_record_sn = record_sn;
> > > > +            return info;
> > > > +        }
> > > > +        record_sn++;
> > > > +    }
> > > > +
> > > > +    return NULL;
> > > > +}
> > > > +EXPORT_SYMBOL(tls_get_record);
> > > > +
> > > > +static int tls_device_push_pending_record(struct sock *sk, int
> > > > flags)
> > > > +{
> > > > +    struct iov_iter    msg_iter;
> > > > +
> > > > +    iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
> > > > +    return tls_push_data(sk, &msg_iter, 0, flags,
> > > > TLS_RECORD_TYPE_DATA);
> > > > +}
> > > > +
> > > > +int tls_set_device_offload(struct sock *sk, struct tls_context
> > > > *ctx)
> > > > +{
> > > > +    u16 nonece_size, tag_size, iv_size, rec_seq_size;
> > > > +    struct tls_record_info *start_marker_record;
> > > > +    struct tls_offload_context *offload_ctx;
> > > > +    struct tls_crypto_info *crypto_info;
> > > > +    struct net_device *netdev;
> > > > +    char *iv, *rec_seq;
> > > > +    struct sk_buff *skb;
> > > > +    int rc = -EINVAL;
> > > > +    __be64 rcd_sn;
> > > > +
> > > > +    if (!ctx)
> > > > +        goto out;
> > > > +
> > > > +    if (ctx->priv_ctx) {
> > > > +        rc = -EEXIST;
> > > > +        goto out;
> > > > +    }
> > > > +
> > > > +    /* We support starting offload on multiple sockets
> > > > +     * concurrently, So we only need a read lock here.
> > > > +     */
> > > > +    percpu_down_read(&device_offload_lock);
> > > > +    netdev = get_netdev_for_sock(sk);
> > > > +    if (!netdev) {
> > > > +        pr_err_ratelimited("%s: netdev not found\n",
> > > > __func__);
> > > > +        rc = -EINVAL;
> > > > +        goto release_lock;
> > > > +    }
> > > > +
> > > > +    if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
> > > > +        rc = -ENOTSUPP;
> > > > +        goto release_netdev;
> > > > +    }
> > > > +
> > > > +    /* Avoid offloading if the device is down
> > > > +     * We don't want to offload new flows after
> > > > +     * the NETDEV_DOWN event
> > > > +     */
> > > > +    if (!(netdev->flags & IFF_UP)) {
> > > > +        rc = -EINVAL;
> > > > +        goto release_lock;
> > > > +    }
> > > > +
> > > > +    crypto_info = &ctx->crypto_send;
> > > > +    switch (crypto_info->cipher_type) {
> > > > +    case TLS_CIPHER_AES_GCM_128: {
> > > > +        nonece_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
> > > > +        tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
> > > > +        iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
> > > > +        iv = ((struct tls12_crypto_info_aes_gcm_128
> > > > *)crypto_info)->iv;
> > > > +        rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
> > > > +        rec_seq =
> > > > +         ((struct tls12_crypto_info_aes_gcm_128
> > > > *)crypto_info)->rec_seq;
> > > > +        break;
> > > > +    }
> > > > +    default:
> > > > +        rc = -EINVAL;
> > > > +        goto release_netdev;
> > > > +    }
> > > > +
> > > > +    start_marker_record =
> > > > kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
> > > 
> > > Can we memory allocations and simple memory initializations
> > > ouside the global rwsem?
> > > 
> > 
> > Sure, we can move all memory allocations outside the lock.
> > 
> > > > +    if (!start_marker_record) {
> > > > +        rc = -ENOMEM;
> > > > +        goto release_netdev;
> > > > +    }
> > > > +
> > > > +    offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE,
> > > > GFP_KERNEL);
> > > > +    if (!offload_ctx)
> > > > +        goto free_marker_record;
> > > > +
> > > > +    ctx->priv_ctx = offload_ctx;
> > > > +    rc = attach_sock_to_netdev(sk, netdev, ctx);
> > > > +    if (rc)
> > > > +        goto free_offload_context;
> > > > +
> > > > +    ctx->netdev = netdev;
> > > > +    ctx->prepend_size = TLS_HEADER_SIZE + nonece_size;
> > > > +    ctx->tag_size = tag_size;
> > > > +    ctx->iv_size = iv_size;
> > > > +    ctx->iv = kmalloc(iv_size +
> > > > TLS_CIPHER_AES_GCM_128_SALT_SIZE,
> > > > +              GFP_KERNEL);
> > > > +    if (!ctx->iv) {
> > > > +        rc = -ENOMEM;
> > > > +        goto detach_sock;
> > > > +    }
> > > > +
> > > > +    memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv,
> > > > iv_size);
> > > > +
> > > > +    ctx->rec_seq_size = rec_seq_size;
> > > > +    ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
> > > > +    if (!ctx->rec_seq) {
> > > > +        rc = -ENOMEM;
> > > > +        goto free_iv;
> > > > +    }
> > > > +    memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
> > > > +
> > > > +    /* start at rec_seq - 1 to account for the start marker
> > > > record */
> > > > +    memcpy(&rcd_sn, ctx->rec_seq, sizeof(rcd_sn));
> > > > +    offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
> > > > +
> > > > +    rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
> > > > +    if (rc)
> > > > +        goto free_rec_seq;
> > > > +
> > > > +    start_marker_record->end_seq = tcp_sk(sk)->write_seq;
> > > > +    start_marker_record->len = 0;
> > > > +    start_marker_record->num_frags = 0;
> > > > +
> > > > +    INIT_LIST_HEAD(&offload_ctx->records_list);
> > > > +    list_add_tail(&start_marker_record->list, &offload_ctx-
> > > > >records_list);
> > > > +    spin_lock_init(&offload_ctx->lock);
> > > > +
> > > > +    inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked;
> > > > +    ctx->push_pending_record = tls_device_push_pending_record;
> > > > +    offload_ctx->sk_destruct = sk->sk_destruct;
> > > > +
> > > > +    /* TLS offload is greatly simplified if we don't send
> > > > +     * SKBs where only part of the payload needs to be
> > > > encrypted.
> > > > +     * So mark the last skb in the write queue as end of
> > > > record.
> > > > +     */
> > > > +    skb = tcp_write_queue_tail(sk);
> > > > +    if (skb)
> > > > +        TCP_SKB_CB(skb)->eor = 1;
> > > > +
> > > > +    refcount_set(&ctx->refcount, 1);
> > > > +    spin_lock_irq(&tls_device_lock);
> > > > +    list_add_tail(&ctx->list, &tls_device_list);
> > > > +    spin_unlock_irq(&tls_device_lock);
> > > > +
> > > > +    /* following this assignment tls_is_sk_tx_device_offloaded
> > > > +     * will return true and the context might be accessed
> > > > +     * by the netdev's xmit function.
> > > > +     */
> > > > +    smp_store_release(&sk->sk_destruct,
> > > > +              &tls_device_sk_destruct);
> > > > +    goto release_lock;
> > > > +
> > > > +free_rec_seq:
> > > > +    kfree(ctx->rec_seq);
> > > > +free_iv:
> > > > +    kfree(ctx->iv);
> > > > +detach_sock:
> > > > +    netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
> > > > TLS_OFFLOAD_CTX_DIR_TX);
> > > > +free_offload_context:
> > > > +    kfree(offload_ctx);
> > > > +    ctx->priv_ctx = NULL;
> > > > +free_marker_record:
> > > > +    kfree(start_marker_record);
> > > > +release_netdev:
> > > > +    dev_put(netdev);
> > > > +release_lock:
> > > > +    percpu_up_read(&device_offload_lock);
> > > > +out:
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +static int tls_device_register(struct net_device *dev)
> > > > +{
> > > > +    if ((dev->features & NETIF_F_HW_TLS_TX) && !dev-
> > > > >tlsdev_ops)
> > > > +        return NOTIFY_BAD;
> > > > +
> > > > +    return NOTIFY_DONE;
> > > > +}
> > > 
> > > This function is the same as tls_device_feat_change(). Can't we
> > > merge
> > > them together and avoid duplicating of code?
> > > 
> > 
> > Sure.
> > 
> > > > +static int tls_device_unregister(struct net_device *dev)
> > > > +{
> > > > +    return NOTIFY_DONE;
> > > > +}
> > > 
> > > This function does nothing, and next patches do not change it.
> > > Can't we remove it since so?
> > > 
> > 
> > Sure.
> > 
> > > > +static int tls_device_feat_change(struct net_device *dev)
> > > > +{
> > > > +    if ((dev->features & NETIF_F_HW_TLS_TX) && !dev-
> > > > >tlsdev_ops)
> > > > +        return NOTIFY_BAD;
> > > > +
> > > > +    return NOTIFY_DONE;
> > > > +}
> > > > +
> > > > +static int tls_device_down(struct net_device *netdev)
> > > > +{
> > > > +    struct tls_context *ctx, *tmp;
> > > > +    struct list_head list;
> > > > +    unsigned long flags;
> > > > +
> > > > +    if (!(netdev->features & NETIF_F_HW_TLS_TX))
> > > > +        return NOTIFY_DONE;
> > > 
> > > Can't we move this check in tls_dev_event() and use it for all
> > > types of events?
> > > Then we avoid duplicate code.
> > > 
> > 
> > No. Not all events require this check. Also, the result is
> > different for different events.
> 
> No. You always return NOTIFY_DONE, in case of !(netdev->features &
> NETIF_F_HW_TLS_TX).
> See below:
> 
> static int tls_check_dev_ops(struct net_device *dev) 
> {
> 	if (!dev->tlsdev_ops)
> 		return NOTIFY_BAD; 
> 
> 	return NOTIFY_DONE; 
> }
> 
> static int tls_device_down(struct net_device *netdev) 
> {
> 	struct tls_context *ctx, *tmp; 
> 	struct list_head list; 
> 	unsigned long flags; 
> 
> 	...
> 	return NOTIFY_DONE;
> }
> 
> static int tls_dev_event(struct notifier_block *this, unsigned long
> event, 
>         		 void *ptr) 
> { 
>         struct net_device *dev = netdev_notifier_info_to_dev(ptr); 
> 
> 	if (!(netdev->features & NETIF_F_HW_TLS_TX)) 
> 		return NOTIFY_DONE; 
>  
>         switch (event) { 
>         case NETDEV_REGISTER:
>         case NETDEV_FEAT_CHANGE: 
>         	return tls_check_dev_ops(dev); 
>  
>         case NETDEV_DOWN: 
>         	return tls_device_down(dev); 
>         } 
>         return NOTIFY_DONE; 
> } 
>  

Will fix in V2.

> > > > +
> > > > +    /* Request a write lock to block new offload attempts
> > > > +     */
> > > > +    percpu_down_write(&device_offload_lock);
> > > 
> > > What is the reason percpu_rwsem is chosen here? It looks like
> > > this primitive
> > > gives more advantages readers, then plain rwsem does. But it also
> > > gives
> > > disadvantages to writers. It would be good, unless
> > > tls_device_down() is called
> > > with rtnl_lock() held from netdevice notifier. But since
> > > netdevice notifier
> > > are called with rtnl_lock() held, percpu_rwsem will increase the
> > > time rtnl_lock()
> > > is locked.
> > 
> > We use the a rwsem to allow multiple (readers) invocations of
> > tls_set_device_offload, which is triggered by the user (persumably)
> > during the TLS handshake. This might be considered a fast-path.
> > 
> > However, we must block all calls to tls_set_device_offload while we
> > are processing NETDEV_DOWN events (writer).
> > 
> > As you've mentioned, the percpu rwsem is more efficient for
> > readers, especially on NUMA systems, where cache-line bouncing
> > occurs during reader acquire and reduces performance.
> 
> Hm, and who are the readers? It's used from do_tls_setsockopt_tx(),
> but it doesn't
> seem to be performance critical. Who else?
> 

it is performance critical since it is done in the socket handshake
phase, anyway I tend to agree with you that per cpu rwsem is an
overkill, will change it to regular rwsem in V2.

> > > 
> > > Can't we use plain rwsem here instead?
> > > 
> > 
> > Its a performance tradeoff. I'm not certain that the percpu rwsem
> > write side acquire is significantly worse than using the global
> > rwsem.
> > 
> > For now, while all of this is experimental, can we agree to focus
> > on the performance of readers? We can change it later if it becomes
> > a problem.
> 
> Same as above.
>  
> > > > +
> > > > +    spin_lock_irqsave(&tls_device_lock, flags);
> > > > +    INIT_LIST_HEAD(&list);
> > > 
> > > This may go outside the global spinlock.
> > > 
> > 
> > Sure.
> > 
> > > > +    list_for_each_entry_safe(ctx, tmp, &tls_device_list, list)
> > > > {
> > > > +        if (ctx->netdev != netdev ||
> > > > +            !refcount_inc_not_zero(&ctx->refcount))
> > > > +            continue;
> > > > +
> > > > +        list_move(&ctx->list, &list);
> > > > +    }
> > > > +    spin_unlock_irqrestore(&tls_device_lock, flags);
> > > > +
> > > > +    list_for_each_entry_safe(ctx, tmp, &list, list)    {
> > > > +        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
> > > > +                        TLS_OFFLOAD_CTX_DIR_TX);
> > > > +        ctx->netdev = NULL;
> > > > +        dev_put(netdev);
> > > > +        list_del_init(&ctx->list);
> > > > +
> > > > +        if (refcount_dec_and_test(&ctx->refcount))
> > > > +            tls_device_free_ctx(ctx);
> > > > +    }
> > > > +
> > > > +    percpu_up_write(&device_offload_lock);
> > > > +
> > > > +    flush_work(&tls_device_gc_work);
> > > > +
> > > > +    return NOTIFY_DONE;
> > > > +}
> > > > +
> > > > +static int tls_dev_event(struct notifier_block *this, unsigned
> > > > long event,
> > > > +             void *ptr)
> > > > +{
> > > > +    struct net_device *dev = netdev_notifier_info_to_dev(ptr);
> > > > +
> > > > +    switch (event) {
> > > > +    case NETDEV_REGISTER:
> > > > +        return tls_device_register(dev);
> > > > +
> > > > +    case NETDEV_UNREGISTER:
> > > > +        return tls_device_unregister(dev);
> > > > +
> > > > +    case NETDEV_FEAT_CHANGE:
> > > > +        return tls_device_feat_change(dev);
> > > > +
> > > > +    case NETDEV_DOWN:
> > > > +        return tls_device_down(dev);
> > > > +    }
> > > > +    return NOTIFY_DONE;
> > > > +}
> > > > +
> > > > +static struct notifier_block tls_dev_notifier = {
> > > > +    .notifier_call    = tls_dev_event,
> > > > +};
> > > > +
> > > > +void __init tls_device_init(void)
> > > > +{
> > > > +    register_netdevice_notifier(&tls_dev_notifier);
> > > > +}
> > > > +
> > > > +void __exit tls_device_cleanup(void)
> > > > +{
> > > > +    unregister_netdevice_notifier(&tls_dev_notifier);
> > > > +    flush_work(&tls_device_gc_work);
> > > > +}
> > > > diff --git a/net/tls/tls_device_fallback.c
> > > > b/net/tls/tls_device_fallback.c
> > > > new file mode 100644
> > > > index 000000000000..14d31a36885c
> > > > --- /dev/null
> > > > +++ b/net/tls/tls_device_fallback.c
> > > > @@ -0,0 +1,419 @@
> > > > +/* Copyright (c) 2018, Mellanox Technologies All rights
> > > > reserved.
> > > > + *
> > > > + *     Redistribution and use in source and binary forms, with
> > > > or
> > > > + *     without modification, are permitted provided that the
> > > > following
> > > > + *     conditions are met:
> > > > + *
> > > > + *      - Redistributions of source code must retain the above
> > > > + *        copyright notice, this list of conditions and the
> > > > following
> > > > + *        disclaimer.
> > > > + *
> > > > + *      - Redistributions in binary form must reproduce the
> > > > above
> > > > + *        copyright notice, this list of conditions and the
> > > > following
> > > > + *        disclaimer in the documentation and/or other
> > > > materials
> > > > + *        provided with the distribution.
> > > > + *
> > > > + *      - Neither the name of the Mellanox Technologies nor
> > > > the
> > > > + *        names of its contributors may be used to endorse or
> > > > promote
> > > > + *        products derived from this software without specific
> > > > prior written
> > > > + *        permission.
> > > > + *
> > > > + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
> > > > CONTRIBUTORS "AS IS"
> > > > + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
> > > > LIMITED TO,
> > > > + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
> > > > + * A PARTICULAR PURPOSE ARE DISCLAIMED.
> > > > + * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
> > > > LIABLE FOR
> > > > + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
> > > > CONSEQUENTIAL
> > > > + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
> > > > SUBSTITUTE GOODS OR
> > > > + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
> > > > INTERRUPTION)
> > > > + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
> > > > CONTRACT,
> > > > + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
> > > > OTHERWISE) ARISING
> > > > + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
> > > > OF THE
> > > > + * POSSIBILITY OF SUCH DAMAGE
> > > > + */
> > > > +
> > > > +#include <net/tls.h>
> > > > +#include <crypto/aead.h>
> > > > +#include <crypto/scatterwalk.h>
> > > > +#include <net/ip6_checksum.h>
> > > > +
> > > > +static void chain_to_walk(struct scatterlist *sg, struct
> > > > scatter_walk *walk)
> > > > +{
> > > > +    struct scatterlist *src = walk->sg;
> > > > +    int diff = walk->offset - src->offset;
> > > > +
> > > > +    sg_set_page(sg, sg_page(src),
> > > > +            src->length - diff, walk->offset);
> > > > +
> > > > +    scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
> > > > +}
> > > > +
> > > > +static int tls_enc_record(struct aead_request *aead_req,
> > > > +              struct crypto_aead *aead, char *aad, char *iv,
> > > > +              __be64 rcd_sn, struct scatter_walk *in,
> > > > +              struct scatter_walk *out, int *in_len)
> > > > +{
> > > > +    struct scatterlist sg_in[3];
> > > > +    struct scatterlist sg_out[3];
> > > > +    unsigned char buf[TLS_HEADER_SIZE +
> > > > TLS_CIPHER_AES_GCM_128_IV_SIZE];
> > > > +    u16 len;
> > > > +    int rc;
> > > > +
> > > > +    len = min_t(int, *in_len, ARRAY_SIZE(buf));
> > > > +
> > > > +    scatterwalk_copychunks(buf, in, len, 0);
> > > > +    scatterwalk_copychunks(buf, out, len, 1);
> > > > +
> > > > +    *in_len -= len;
> > > > +    if (!*in_len)
> > > > +        return 0;
> > > > +
> > > > +    scatterwalk_pagedone(in, 0, 1);
> > > > +    scatterwalk_pagedone(out, 1, 1);
> > > > +
> > > > +    len = buf[4] | (buf[3] << 8);
> > > > +    len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
> > > > +
> > > > +    tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
> > > > +             (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
> > > > +
> > > > +    memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf +
> > > > TLS_HEADER_SIZE,
> > > > +           TLS_CIPHER_AES_GCM_128_IV_SIZE);
> > > > +
> > > > +    sg_init_table(sg_in, ARRAY_SIZE(sg_in));
> > > > +    sg_init_table(sg_out, ARRAY_SIZE(sg_out));
> > > > +    sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
> > > > +    sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
> > > > +    chain_to_walk(sg_in + 1, in);
> > > > +    chain_to_walk(sg_out + 1, out);
> > > > +
> > > > +    *in_len -= len;
> > > > +    if (*in_len < 0) {
> > > > +        *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
> > > > +        if (*in_len < 0)
> > > > +        /* the input buffer doesn't contain the entire record.
> > > > +         * trim len accordingly. The resulting authentication
> > > > tag
> > > > +         * will contain garbage. but we don't care as we won't
> > > > +         * include any of it in the output skb
> > > > +         * Note that we assume the output buffer length
> > > > +         * is larger then input buffer length + tag size
> > > > +         */
> > > > +            len += *in_len;
> > > > +
> > > > +        *in_len = 0;
> > > > +    }
> > > > +
> > > > +    if (*in_len) {
> > > > +        scatterwalk_copychunks(NULL, in, len, 2);
> > > > +        scatterwalk_pagedone(in, 0, 1);
> > > > +        scatterwalk_copychunks(NULL, out, len, 2);
> > > > +        scatterwalk_pagedone(out, 1, 1);
> > > > +    }
> > > > +
> > > > +    len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
> > > > +    aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
> > > > +
> > > > +    rc = crypto_aead_encrypt(aead_req);
> > > > +
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +static void tls_init_aead_request(struct aead_request
> > > > *aead_req,
> > > > +                  struct crypto_aead *aead)
> > > > +{
> > > > +    aead_request_set_tfm(aead_req, aead);
> > > > +    aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
> > > > +}
> > > > +
> > > > +static struct aead_request *tls_alloc_aead_request(struct
> > > > crypto_aead *aead,
> > > > +                           gfp_t flags)
> > > > +{
> > > > +    unsigned int req_size = sizeof(struct aead_request) +
> > > > +        crypto_aead_reqsize(aead);
> > > > +    struct aead_request *aead_req;
> > > > +
> > > > +    aead_req = kzalloc(req_size, flags);
> > > > +    if (!aead_req)
> > > > +        return NULL;
> > > > +
> > > > +    tls_init_aead_request(aead_req, aead);
> > > > +    return aead_req;
> > > > +}
> > > > +
> > > > +static int tls_enc_records(struct aead_request *aead_req,
> > > > +               struct crypto_aead *aead, struct scatterlist
> > > > *sg_in,
> > > > +               struct scatterlist *sg_out, char *aad, char
> > > > *iv,
> > > > +               u64 rcd_sn, int len)
> > > > +{
> > > > +    struct scatter_walk in;
> > > > +    struct scatter_walk out;
> > > > +    int rc;
> > > > +
> > > > +    scatterwalk_start(&in, sg_in);
> > > > +    scatterwalk_start(&out, sg_out);
> > > > +
> > > > +    do {
> > > > +        rc = tls_enc_record(aead_req, aead, aad, iv,
> > > > +                    cpu_to_be64(rcd_sn), &in, &out, &len);
> > > > +        rcd_sn++;
> > > > +
> > > > +    } while (rc == 0 && len);
> > > > +
> > > > +    scatterwalk_done(&in, 0, 0);
> > > > +    scatterwalk_done(&out, 1, 0);
> > > > +
> > > > +    return rc;
> > > > +}
> > > > +
> > > > +static inline void update_chksum(struct sk_buff *skb, int
> > > > headln)
> > > > +{
> > > > +    /* Can't use icsk->icsk_af_ops->send_check here because
> > > > the ip addresses
> > > > +     * might have been changed by NAT.
> > > > +     */
> > > > +
> > > > +    const struct ipv6hdr *ipv6h;
> > > > +    const struct iphdr *iph;
> > > > +    struct tcphdr *th = tcp_hdr(skb);
> > > > +    int datalen = skb->len - headln;
> > > > +
> > > > +    /* We only changed the payload so if we are using partial
> > > > we don't
> > > > +     * need to update anything.
> > > > +     */
> > > > +    if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
> > > > +        return;
> > > > +
> > > > +    skb->ip_summed = CHECKSUM_PARTIAL;
> > > > +    skb->csum_start = skb_transport_header(skb) - skb->head;
> > > > +    skb->csum_offset = offsetof(struct tcphdr, check);
> > > > +
> > > > +    if (skb->sk->sk_family == AF_INET6) {
> > > > +        ipv6h = ipv6_hdr(skb);
> > > > +        th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h-
> > > > >daddr,
> > > > +                         datalen, IPPROTO_TCP, 0);
> > > > +    } else {
> > > > +        iph = ip_hdr(skb);
> > > > +        th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, 
> > > > datalen,
> > > > +                           IPPROTO_TCP, 0);
> > > > +    }
> > > > +}
> > > > +
> > > > +static void complete_skb(struct sk_buff *nskb, struct sk_buff
> > > > *skb, int headln)
> > > > +{
> > > > +    skb_copy_header(nskb, skb);
> > > > +
> > > > +    skb_put(nskb, skb->len);
> > > > +    memcpy(nskb->data, skb->data, headln);
> > > > +    update_chksum(nskb, headln);
> > > > +
> > > > +    nskb->destructor = skb->destructor;
> > > > +    nskb->sk = skb->sk;
> > > > +    skb->destructor = NULL;
> > > > +    skb->sk = NULL;
> > > > +    refcount_add(nskb->truesize - skb->truesize,
> > > > +             &nskb->sk->sk_wmem_alloc);
> > > > +}
> > > > +
> > > > +/* This function may be called after the user socket is
> > > > already
> > > > + * closed so make sure we don't use anything freed during
> > > > + * tls_sk_proto_close here
> > > > + */
> > > > +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct
> > > > sk_buff *skb)
> > > > +{
> > > > +    int tcp_header_size = tcp_hdrlen(skb);
> > > > +    int tcp_payload_offset = skb_transport_offset(skb) +
> > > > tcp_header_size;
> > > > +    int payload_len = skb->len - tcp_payload_offset;
> > > > +    struct tls_context *tls_ctx = tls_get_ctx(sk);
> > > > +    struct tls_offload_context *ctx =
> > > > tls_offload_ctx(tls_ctx);
> > > > +    int remaining, buf_len, resync_sgs, rc, i = 0;
> > > > +    void *buf, *dummy_buf, *iv, *aad;
> > > > +    struct scatterlist *sg_in;
> > > > +    struct scatterlist sg_out[3];
> > > > +    u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
> > > > +    struct aead_request *aead_req;
> > > > +    struct sk_buff *nskb = NULL;
> > > > +    struct tls_record_info *record;
> > > > +    unsigned long flags;
> > > > +    s32 sync_size;
> > > > +    u64 rcd_sn;
> > > > +
> > > > +    /* worst case is:
> > > > +     * MAX_SKB_FRAGS in tls_record_info
> > > > +     * MAX_SKB_FRAGS + 1 in SKB head an frags.
> > > > +     */
> > > > +    int sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;
> > > > +
> > > > +    if (!payload_len)
> > > > +        return skb;
> > > > +
> > > > +    sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in),
> > > > GFP_ATOMIC);
> > > > +    if (!sg_in)
> > > > +        goto free_orig;
> > > > +
> > > > +    sg_init_table(sg_in, sg_in_max_elements);
> > > > +    sg_init_table(sg_out, ARRAY_SIZE(sg_out));
> > > > +
> > > > +    spin_lock_irqsave(&ctx->lock, flags);
> > > > +    record = tls_get_record(ctx, tcp_seq, &rcd_sn);
> > > > +    if (!record) {
> > > > +        spin_unlock_irqrestore(&ctx->lock, flags);
> > > > +        WARN(1, "Record not found for seq %u\n", tcp_seq);
> > > > +        goto free_sg;
> > > > +    }
> > > > +
> > > > +    sync_size = tcp_seq - tls_record_start_seq(record);
> > > > +    if (sync_size < 0) {
> > > > +        int is_start_marker =
> > > > tls_record_is_start_marker(record);
> > > > +
> > > > +        spin_unlock_irqrestore(&ctx->lock, flags);
> > > > +        if (!is_start_marker)
> > > > +        /* This should only occur if the relevant record was
> > > > +         * already acked. In that case it should be ok
> > > > +         * to drop the packet and avoid retransmission.
> > > > +         *
> > > > +         * There is a corner case where the packet contains
> > > > +         * both an acked and a non-acked record.
> > > > +         * We currently don't handle that case and rely
> > > > +         * on TCP to retranmit a packet that doesn't contain
> > > > +         * already acked payload.
> > > > +         */
> > > > +            goto free_orig;
> > > > +
> > > > +        if (payload_len > -sync_size) {
> > > > +            WARN(1, "Fallback of partially offloaded packets
> > > > is not supported\n");
> > > > +            goto free_sg;
> > > > +        } else {
> > > > +            return skb;
> > > > +        }
> > > > +    }
> > > > +
> > > > +    remaining = sync_size;
> > > > +    while (remaining > 0) {
> > > > +        skb_frag_t *frag = &record->frags[i];
> > > > +
> > > > +        __skb_frag_ref(frag);
> > > > +        sg_set_page(sg_in + i, skb_frag_page(frag),
> > > > +                skb_frag_size(frag), frag->page_offset);
> > > > +
> > > > +        remaining -= skb_frag_size(frag);
> > > > +
> > > > +        if (remaining < 0)
> > > > +            sg_in[i].length += remaining;
> > > > +
> > > > +        i++;
> > > > +    }
> > > > +    spin_unlock_irqrestore(&ctx->lock, flags);
> > > > +    resync_sgs = i;
> > > > +
> > > > +    aead_req = tls_alloc_aead_request(ctx->aead_send,
> > > > GFP_ATOMIC);
> > > > +    if (!aead_req)
> > > > +        goto put_sg;
> > > > +
> > > > +    buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
> > > > +          TLS_CIPHER_AES_GCM_128_IV_SIZE +
> > > > +          TLS_AAD_SPACE_SIZE +
> > > > +          sync_size +
> > > > +          tls_ctx->tag_size;
> > > > +    buf = kmalloc(buf_len, GFP_ATOMIC);
> > > > +    if (!buf)
> > > > +        goto free_req;
> > > > +
> > > > +    nskb = alloc_skb(skb_headroom(skb) + skb->len,
> > > > GFP_ATOMIC);
> > > > +    if (!nskb)
> > > > +        goto free_buf;
> > > > +
> > > > +    skb_reserve(nskb, skb_headroom(skb));
> > > > +
> > > > +    iv = buf;
> > > > +
> > > > +    memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
> > > > +           TLS_CIPHER_AES_GCM_128_SALT_SIZE);
> > > > +    aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
> > > > +          TLS_CIPHER_AES_GCM_128_IV_SIZE;
> > > > +    dummy_buf = aad + TLS_AAD_SPACE_SIZE;
> > > > +
> > > > +    sg_set_buf(&sg_out[0], dummy_buf, sync_size);
> > > > +    sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset,
> > > > +           payload_len);
> > > > +    /* Add room for authentication tag produced by crypto */
> > > > +    dummy_buf += sync_size;
> > > > +    sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size);
> > > > +    rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset,
> > > > +              payload_len);
> > > > +    if (rc < 0)
> > > > +        goto free_nskb;
> > > > +
> > > > +    rc = tls_enc_records(aead_req, ctx->aead_send, sg_in,
> > > > sg_out, aad, iv,
> > > > +                 rcd_sn, sync_size + payload_len);
> > > > +    if (rc < 0)
> > > > +        goto free_nskb;
> > > > +
> > > > +    complete_skb(nskb, skb, tcp_payload_offset);
> > > > +
> > > > +    /* validate_xmit_skb_list assumes that if the skb wasn't
> > > > segmented
> > > > +     * nskb->prev will point to the skb itself
> > > > +     */
> > > > +    nskb->prev = nskb;
> > > > +free_buf:
> > > > +    kfree(buf);
> > > > +free_req:
> > > > +    kfree(aead_req);
> > > > +put_sg:
> > > > +    for (i = 0; i < resync_sgs; i++)
> > > > +        put_page(sg_page(&sg_in[i]));
> > > > +free_sg:
> > > > +    kfree(sg_in);
> > > > +free_orig:
> > > > +    kfree_skb(skb);
> > > > +    return nskb;
> > > > +
> > > > +free_nskb:
> > > > +    kfree_skb(nskb);
> > > > +    nskb = NULL;
> > > > +    goto free_buf;
> > > > +}
> > > > +
> > > > +static struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
> > > > +                         struct net_device *dev,
> > > > +                         struct sk_buff *skb)
> > > > +{
> > > > +    if (dev == tls_get_ctx(sk)->netdev)
> > > > +        return skb;
> > > > +
> > > > +    return tls_sw_fallback(sk, skb);
> > > > +}
> > > > +
> > > > +int tls_sw_fallback_init(struct sock *sk,
> > > > +             struct tls_offload_context *offload_ctx,
> > > > +             struct tls_crypto_info *crypto_info)
> > > > +{
> > > > +    int rc;
> > > > +    const u8 *key;
> > > > +
> > > > +    offload_ctx->aead_send =
> > > > +        crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
> > > > +    if (IS_ERR(offload_ctx->aead_send)) {
> > > > +        rc = PTR_ERR(offload_ctx->aead_send);
> > > > +        pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n",
> > > > rc);
> > > > +        offload_ctx->aead_send = NULL;
> > > > +        goto err_out;
> > > > +    }
> > > > +
> > > > +    key = ((struct tls12_crypto_info_aes_gcm_128
> > > > *)crypto_info)->key;
> > > > +
> > > > +    rc = crypto_aead_setkey(offload_ctx->aead_send, key,
> > > > +                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
> > > > +    if (rc)
> > > > +        goto free_aead;
> > > > +
> > > > +    rc = crypto_aead_setauthsize(offload_ctx->aead_send,
> > > > +                     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
> > > > +    if (rc)
> > > > +        goto free_aead;
> > > > +
> > > > +    sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
> > > > +    return 0;
> > > > +free_aead:
> > > > +    crypto_free_aead(offload_ctx->aead_send);
> > > > +err_out:
> > > > +    return rc;
> > > > +}
> > > > diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
> > > > index d824d548447e..e0dface33017 100644
> > > > --- a/net/tls/tls_main.c
> > > > +++ b/net/tls/tls_main.c
> > > > @@ -54,6 +54,9 @@ enum {
> > > >   enum {
> > > >       TLS_BASE_TX,
> > > >       TLS_SW_TX,
> > > > +#ifdef CONFIG_TLS_DEVICE
> > > > +    TLS_HW_TX,
> > > > +#endif
> > > >       TLS_NUM_CONFIG,
> > > >   };
> > > >   @@ -416,11 +419,19 @@ static int do_tls_setsockopt_tx(struct
> > > > sock *sk, char __user *optval,
> > > >           goto err_crypto_info;
> > > >       }
> > > >   -    /* currently SW is default, we will have ethtool in
> > > > future */
> > > > -    rc = tls_set_sw_offload(sk, ctx);
> > > > -    tx_conf = TLS_SW_TX;
> > > > -    if (rc)
> > > > -        goto err_crypto_info;
> > > > +#ifdef CONFIG_TLS_DEVICE
> > > > +    rc = tls_set_device_offload(sk, ctx);
> > > > +    tx_conf = TLS_HW_TX;
> > > > +    if (rc) {
> > > > +#else
> > > > +    {
> > > > +#endif
> > > > +        /* if HW offload fails fallback to SW */
> > > > +        rc = tls_set_sw_offload(sk, ctx);
> > > > +        tx_conf = TLS_SW_TX;
> > > > +        if (rc)
> > > > +            goto err_crypto_info;
> > > > +    }
> > > >         ctx->tx_conf = tx_conf;
> > > >       update_sk_prot(sk, ctx);
> > > > @@ -473,6 +484,12 @@ static void build_protos(struct proto
> > > > *prot, struct proto *base)
> > > >       prot[TLS_SW_TX] = prot[TLS_BASE_TX];
> > > >       prot[TLS_SW_TX].sendmsg        = tls_sw_sendmsg;
> > > >       prot[TLS_SW_TX].sendpage    = tls_sw_sendpage;
> > > > +
> > > > +#ifdef CONFIG_TLS_DEVICE
> > > > +    prot[TLS_HW_TX] = prot[TLS_SW_TX];
> > > > +    prot[TLS_HW_TX].sendmsg        = tls_device_sendmsg;
> > > > +    prot[TLS_HW_TX].sendpage    = tls_device_sendpage;
> > > > +#endif
> > > >   }
> > > >     static int tls_init(struct sock *sk)
> > > > @@ -531,6 +548,9 @@ static int __init tls_register(void)
> > > >   {
> > > >       build_protos(tls_prots[TLSV4], &tcp_prot);
> > > >   +#ifdef CONFIG_TLS_DEVICE
> > > > +    tls_device_init();
> > > > +#endif
> > > >       tcp_register_ulp(&tcp_tls_ulp_ops);
> > > >         return 0;
> > > > @@ -539,6 +559,9 @@ static int __init tls_register(void)
> > > >   static void __exit tls_unregister(void)
> > > >   {
> > > >       tcp_unregister_ulp(&tcp_tls_ulp_ops);
> > > > +#ifdef CONFIG_TLS_DEVICE
> > > > +    tls_device_cleanup();
> > > > +#endif
> > > >   }
> > > >     module_init(tls_register);
> 
> Thanks,
> Kirill
Boris Pismenny March 22, 2018, 12:38 p.m. UTC | #7
...
>>>
>>> Can't we move this check in tls_dev_event() and use it for all types of events?
>>> Then we avoid duplicate code.
>>>
>>
>> No. Not all events require this check. Also, the result is different for different events.
> 
> No. You always return NOTIFY_DONE, in case of !(netdev->features & NETIF_F_HW_TLS_TX).
> See below:
> 
> static int tls_check_dev_ops(struct net_device *dev)
> {
> 	if (!dev->tlsdev_ops)
> 		return NOTIFY_BAD;
> 
> 	return NOTIFY_DONE;
> }
> 
> static int tls_device_down(struct net_device *netdev)
> {
> 	struct tls_context *ctx, *tmp;
> 	struct list_head list;
> 	unsigned long flags;
> 
> 	...
> 	return NOTIFY_DONE;
> }
> 
> static int tls_dev_event(struct notifier_block *this, unsigned long event,
>          		 void *ptr)
> {
>          struct net_device *dev = netdev_notifier_info_to_dev(ptr);
> 
> 	if (!(netdev->features & NETIF_F_HW_TLS_TX))
> 		return NOTIFY_DONE;
>   
>          switch (event) {
>          case NETDEV_REGISTER:
>          case NETDEV_FEAT_CHANGE:
>          	return tls_check_dev_ops(dev);
>   
>          case NETDEV_DOWN:
>          	return tls_device_down(dev);
>          }
>          return NOTIFY_DONE;
> }
>  

Sure, will fix in V3.

>>>> +
>>>> +    /* Request a write lock to block new offload attempts
>>>> +     */
>>>> +    percpu_down_write(&device_offload_lock);
>>>
>>> What is the reason percpu_rwsem is chosen here? It looks like this primitive
>>> gives more advantages readers, then plain rwsem does. But it also gives
>>> disadvantages to writers. It would be good, unless tls_device_down() is called
>>> with rtnl_lock() held from netdevice notifier. But since netdevice notifier
>>> are called with rtnl_lock() held, percpu_rwsem will increase the time rtnl_lock()
>>> is locked.
>> We use the a rwsem to allow multiple (readers) invocations of tls_set_device_offload, which is triggered by the user (persumably) during the TLS handshake. This might be considered a fast-path.
>>
>> However, we must block all calls to tls_set_device_offload while we are processing NETDEV_DOWN events (writer).
>>
>> As you've mentioned, the percpu rwsem is more efficient for readers, especially on NUMA systems, where cache-line bouncing occurs during reader acquire and reduces performance.
> 
> Hm, and who are the readers? It's used from do_tls_setsockopt_tx(), but it doesn't
> seem to be performance critical. Who else?
> 

It depends on whether you consider the TLS handshake code as critical.
The readers are TCP connections processing the CCS message of the TLS 
handshake. They are providing key material to the kernel to start using 
Kernel TLS.


>>>
>>> Can't we use plain rwsem here instead?
>>>
>>
>> Its a performance tradeoff. I'm not certain that the percpu rwsem write side acquire is significantly worse than using the global rwsem.
>>
>> For now, while all of this is experimental, can we agree to focus on the performance of readers? We can change it later if it becomes a problem.
> 
> Same as above.
>   

Replaced with rwsem from V2.
Kirill Tkhai March 22, 2018, 1:03 p.m. UTC | #8
On 22.03.2018 15:38, Boris Pismenny wrote:
> ...
>>>>
>>>> Can't we move this check in tls_dev_event() and use it for all types of events?
>>>> Then we avoid duplicate code.
>>>>
>>>
>>> No. Not all events require this check. Also, the result is different for different events.
>>
>> No. You always return NOTIFY_DONE, in case of !(netdev->features & NETIF_F_HW_TLS_TX).
>> See below:
>>
>> static int tls_check_dev_ops(struct net_device *dev)
>> {
>>     if (!dev->tlsdev_ops)
>>         return NOTIFY_BAD;
>>
>>     return NOTIFY_DONE;
>> }
>>
>> static int tls_device_down(struct net_device *netdev)
>> {
>>     struct tls_context *ctx, *tmp;
>>     struct list_head list;
>>     unsigned long flags;
>>
>>     ...
>>     return NOTIFY_DONE;
>> }
>>
>> static int tls_dev_event(struct notifier_block *this, unsigned long event,
>>                   void *ptr)
>> {
>>          struct net_device *dev = netdev_notifier_info_to_dev(ptr);
>>
>>     if (!(netdev->features & NETIF_F_HW_TLS_TX))
>>         return NOTIFY_DONE;
>>            switch (event) {
>>          case NETDEV_REGISTER:
>>          case NETDEV_FEAT_CHANGE:
>>              return tls_check_dev_ops(dev);
>>            case NETDEV_DOWN:
>>              return tls_device_down(dev);
>>          }
>>          return NOTIFY_DONE;
>> }
>>  
> 
> Sure, will fix in V3.
> 
>>>>> +
>>>>> +    /* Request a write lock to block new offload attempts
>>>>> +     */
>>>>> +    percpu_down_write(&device_offload_lock);
>>>>
>>>> What is the reason percpu_rwsem is chosen here? It looks like this primitive
>>>> gives more advantages readers, then plain rwsem does. But it also gives
>>>> disadvantages to writers. It would be good, unless tls_device_down() is called
>>>> with rtnl_lock() held from netdevice notifier. But since netdevice notifier
>>>> are called with rtnl_lock() held, percpu_rwsem will increase the time rtnl_lock()
>>>> is locked.
>>> We use the a rwsem to allow multiple (readers) invocations of tls_set_device_offload, which is triggered by the user (persumably) during the TLS handshake. This might be considered a fast-path.
>>>
>>> However, we must block all calls to tls_set_device_offload while we are processing NETDEV_DOWN events (writer).
>>>
>>> As you've mentioned, the percpu rwsem is more efficient for readers, especially on NUMA systems, where cache-line bouncing occurs during reader acquire and reduces performance.
>>
>> Hm, and who are the readers? It's used from do_tls_setsockopt_tx(), but it doesn't
>> seem to be performance critical. Who else?
>>
> 
> It depends on whether you consider the TLS handshake code as critical.
> The readers are TCP connections processing the CCS message of the TLS handshake. They are providing key material to the kernel to start using Kernel TLS.

The thing is rtnl_lock() is critical for the rest of the system,
while TLS handshake is small subset of actions the system makes.

rtnl_lock() is used just almost everywhere, from netlink messages
to netdev ioctls.

Currently, you even just can't close raw socket without rtnl lock.
So, all of this is big reason to avoid doing rcu waitings under it.

Kirill

>>>>
>>>> Can't we use plain rwsem here instead?
>>>>
>>>
>>> Its a performance tradeoff. I'm not certain that the percpu rwsem write side acquire is significantly worse than using the global rwsem.
>>>
>>> For now, while all of this is experimental, can we agree to focus on the performance of readers? We can change it later if it becomes a problem.
>>
>> Same as above.
>>   
> 
> Replaced with rwsem from V2.
diff mbox series

Patch

diff --git a/include/net/tls.h b/include/net/tls.h
index 4913430ab807..ab98a6dc4929 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -77,6 +77,37 @@  struct tls_sw_context {
 	struct scatterlist sg_aead_out[2];
 };
 
+struct tls_record_info {
+	struct list_head list;
+	u32 end_seq;
+	int len;
+	int num_frags;
+	skb_frag_t frags[MAX_SKB_FRAGS];
+};
+
+struct tls_offload_context {
+	struct crypto_aead *aead_send;
+	spinlock_t lock;	/* protects records list */
+	struct list_head records_list;
+	struct tls_record_info *open_record;
+	struct tls_record_info *retransmit_hint;
+	u64 hint_record_sn;
+	u64 unacked_record_sn;
+
+	struct scatterlist sg_tx_data[MAX_SKB_FRAGS];
+	void (*sk_destruct)(struct sock *sk);
+	u8 driver_state[];
+	/* The TLS layer reserves room for driver specific state
+	 * Currently the belief is that there is not enough
+	 * driver specific state to justify another layer of indirection
+	 */
+#define TLS_DRIVER_STATE_SIZE (max_t(size_t, 8, sizeof(void *)))
+};
+
+#define TLS_OFFLOAD_CONTEXT_SIZE                                               \
+	(ALIGN(sizeof(struct tls_offload_context), sizeof(void *)) +           \
+	 TLS_DRIVER_STATE_SIZE)
+
 enum {
 	TLS_PENDING_CLOSED_RECORD
 };
@@ -87,6 +118,10 @@  struct tls_context {
 		struct tls12_crypto_info_aes_gcm_128 crypto_send_aes_gcm_128;
 	};
 
+	struct list_head list;
+	struct net_device *netdev;
+	refcount_t refcount;
+
 	void *priv_ctx;
 
 	u8 tx_conf:2;
@@ -131,9 +166,29 @@  int tls_sw_sendpage(struct sock *sk, struct page *page,
 void tls_sw_close(struct sock *sk, long timeout);
 void tls_sw_free_tx_resources(struct sock *sk);
 
-void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
-void tls_icsk_clean_acked(struct sock *sk);
+void tls_clear_device_offload(struct sock *sk, struct tls_context *ctx);
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_device_sendpage(struct sock *sk, struct page *page,
+			int offset, size_t size, int flags);
+void tls_device_sk_destruct(struct sock *sk);
+void tls_device_init(void);
+void tls_device_cleanup(void);
 
+struct tls_record_info *tls_get_record(struct tls_offload_context *context,
+				       u32 seq, u64 *p_record_sn);
+
+static inline bool tls_record_is_start_marker(struct tls_record_info *rec)
+{
+	return rec->len == 0;
+}
+
+static inline u32 tls_record_start_seq(struct tls_record_info *rec)
+{
+	return rec->end_seq - rec->len;
+}
+
+void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
 int tls_push_sg(struct sock *sk, struct tls_context *ctx,
 		struct scatterlist *sg, u16 first_offset,
 		int flags);
@@ -170,6 +225,13 @@  static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
 	return tls_ctx->pending_open_record_frags;
 }
 
+static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
+{
+	return sk_fullsock(sk) &&
+	       /* matches smp_store_release in tls_set_device_offload */
+	       smp_load_acquire(&sk->sk_destruct) == &tls_device_sk_destruct;
+}
+
 static inline void tls_err_abort(struct sock *sk)
 {
 	sk->sk_err = EBADMSG;
@@ -257,4 +319,8 @@  static inline struct tls_offload_context *tls_offload_ctx(
 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 		      unsigned char *record_type);
 
+int tls_sw_fallback_init(struct sock *sk,
+			 struct tls_offload_context *offload_ctx,
+			 struct tls_crypto_info *crypto_info);
+
 #endif /* _TLS_OFFLOAD_H */
diff --git a/net/tls/Kconfig b/net/tls/Kconfig
index eb583038c67e..9d3ef820bb16 100644
--- a/net/tls/Kconfig
+++ b/net/tls/Kconfig
@@ -13,3 +13,13 @@  config TLS
 	encryption handling of the TLS protocol to be done in-kernel.
 
 	If unsure, say N.
+
+config TLS_DEVICE
+	bool "Transport Layer Security HW offload"
+	depends on TLS
+	select SOCK_VALIDATE_XMIT
+	default n
+	---help---
+	Enable kernel support for HW offload of the TLS protocol.
+
+	If unsure, say N.
diff --git a/net/tls/Makefile b/net/tls/Makefile
index a930fd1c4f7b..4d6b728a67d0 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -5,3 +5,5 @@ 
 obj-$(CONFIG_TLS) += tls.o
 
 tls-y := tls_main.o tls_sw.o
+
+tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
new file mode 100644
index 000000000000..c0d4e11a4286
--- /dev/null
+++ b/net/tls/tls_device.c
@@ -0,0 +1,804 @@ 
+/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
+ *
+ *     Redistribution and use in source and binary forms, with or
+ *     without modification, are permitted provided that the following
+ *     conditions are met:
+ *
+ *      - Redistributions of source code must retain the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer.
+ *
+ *      - Redistributions in binary form must reproduce the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer in the documentation and/or other materials
+ *        provided with the distribution.
+ *
+ *      - Neither the name of the Mellanox Technologies nor the
+ *        names of its contributors may be used to endorse or promote
+ *        products derived from this software without specific prior written
+ *        permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
+ * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
+ * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE
+ */
+
+#include <linux/module.h>
+#include <net/tcp.h>
+#include <net/inet_common.h>
+#include <linux/highmem.h>
+#include <linux/netdevice.h>
+
+#include <net/tls.h>
+#include <crypto/aead.h>
+
+/* device_offload_lock is used to synchronize tls_dev_add
+ * against NETDEV_DOWN notifications.
+ */
+DEFINE_STATIC_PERCPU_RWSEM(device_offload_lock);
+
+static void tls_device_gc_task(struct work_struct *work);
+
+static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
+static LIST_HEAD(tls_device_gc_list);
+static LIST_HEAD(tls_device_list);
+static DEFINE_SPINLOCK(tls_device_lock);
+
+static void tls_device_free_ctx(struct tls_context *ctx)
+{
+	struct tls_offload_context *offlad_ctx = tls_offload_ctx(ctx);
+
+	kfree(offlad_ctx);
+	kfree(ctx);
+}
+
+static void tls_device_gc_task(struct work_struct *work)
+{
+	struct tls_context *ctx, *tmp;
+	struct list_head gc_list;
+	unsigned long flags;
+
+	spin_lock_irqsave(&tls_device_lock, flags);
+	INIT_LIST_HEAD(&gc_list);
+	list_splice_init(&tls_device_gc_list, &gc_list);
+	spin_unlock_irqrestore(&tls_device_lock, flags);
+
+	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
+		struct net_device *netdev = ctx->netdev;
+
+		if (netdev) {
+			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+							TLS_OFFLOAD_CTX_DIR_TX);
+			dev_put(netdev);
+		}
+
+		list_del(&ctx->list);
+		tls_device_free_ctx(ctx);
+	}
+}
+
+static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&tls_device_lock, flags);
+	list_move_tail(&ctx->list, &tls_device_gc_list);
+
+	/* schedule_work inside the spinlock
+	 * to make sure tls_device_down waits for that work.
+	 */
+	schedule_work(&tls_device_gc_work);
+
+	spin_unlock_irqrestore(&tls_device_lock, flags);
+}
+
+/* We assume that the socket is already connected */
+static struct net_device *get_netdev_for_sock(struct sock *sk)
+{
+	struct inet_sock *inet = inet_sk(sk);
+	struct net_device *netdev = NULL;
+
+	netdev = dev_get_by_index(sock_net(sk), inet->cork.fl.flowi_oif);
+
+	return netdev;
+}
+
+static int attach_sock_to_netdev(struct sock *sk, struct net_device *netdev,
+				 struct tls_context *ctx)
+{
+	int rc;
+
+	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
+					     &ctx->crypto_send,
+					     tcp_sk(sk)->write_seq);
+	if (rc) {
+		pr_err_ratelimited("The netdev has refused to offload this socket\n");
+		goto out;
+	}
+
+	rc = 0;
+out:
+	return rc;
+}
+
+static void destroy_record(struct tls_record_info *record)
+{
+	skb_frag_t *frag;
+	int nr_frags = record->num_frags;
+
+	while (nr_frags > 0) {
+		frag = &record->frags[nr_frags - 1];
+		__skb_frag_unref(frag);
+		--nr_frags;
+	}
+	kfree(record);
+}
+
+static void delete_all_records(struct tls_offload_context *offload_ctx)
+{
+	struct tls_record_info *info, *temp;
+
+	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
+		list_del(&info->list);
+		destroy_record(info);
+	}
+
+	offload_ctx->retransmit_hint = NULL;
+}
+
+static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx;
+	struct tls_record_info *info, *temp;
+	unsigned long flags;
+	u64 deleted_records = 0;
+
+	if (!tls_ctx)
+		return;
+
+	ctx = tls_offload_ctx(tls_ctx);
+
+	spin_lock_irqsave(&ctx->lock, flags);
+	info = ctx->retransmit_hint;
+	if (info && !before(acked_seq, info->end_seq)) {
+		ctx->retransmit_hint = NULL;
+		list_del(&info->list);
+		destroy_record(info);
+		deleted_records++;
+	}
+
+	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
+		if (before(acked_seq, info->end_seq))
+			break;
+		list_del(&info->list);
+
+		destroy_record(info);
+		deleted_records++;
+	}
+
+	ctx->unacked_record_sn += deleted_records;
+	spin_unlock_irqrestore(&ctx->lock, flags);
+}
+
+/* At this point, there should be no references on this
+ * socket and no in-flight SKBs associated with this
+ * socket, so it is safe to free all the resources.
+ */
+void tls_device_sk_destruct(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+
+	if (ctx->open_record)
+		destroy_record(ctx->open_record);
+
+	delete_all_records(ctx);
+	crypto_free_aead(ctx->aead_send);
+	ctx->sk_destruct(sk);
+
+	if (refcount_dec_and_test(&tls_ctx->refcount))
+		tls_device_queue_ctx_destruction(tls_ctx);
+}
+EXPORT_SYMBOL(tls_device_sk_destruct);
+
+static inline void tls_append_frag(struct tls_record_info *record,
+				   struct page_frag *pfrag,
+				   int size)
+{
+	skb_frag_t *frag;
+
+	frag = &record->frags[record->num_frags - 1];
+	if (frag->page.p == pfrag->page &&
+	    frag->page_offset + frag->size == pfrag->offset) {
+		frag->size += size;
+	} else {
+		++frag;
+		frag->page.p = pfrag->page;
+		frag->page_offset = pfrag->offset;
+		frag->size = size;
+		++record->num_frags;
+		get_page(pfrag->page);
+	}
+
+	pfrag->offset += size;
+	record->len += size;
+}
+
+static inline int tls_push_record(struct sock *sk,
+				  struct tls_context *ctx,
+				  struct tls_offload_context *offload_ctx,
+				  struct tls_record_info *record,
+				  struct page_frag *pfrag,
+				  int flags,
+				  unsigned char record_type)
+{
+	skb_frag_t *frag;
+	struct tcp_sock *tp = tcp_sk(sk);
+	struct page_frag fallback_frag;
+	struct page_frag  *tag_pfrag = pfrag;
+	int i;
+
+	/* fill prepand */
+	frag = &record->frags[0];
+	tls_fill_prepend(ctx,
+			 skb_frag_address(frag),
+			 record->len - ctx->prepend_size,
+			 record_type);
+
+	if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag, GFP_KERNEL))) {
+		/* HW doesn't care about the data in the tag
+		 * so in case pfrag has no room
+		 * for a tag and we can't allocate a new pfrag
+		 * just use the page in the first frag
+		 * rather then write a complicated fall back code.
+		 */
+		tag_pfrag = &fallback_frag;
+		tag_pfrag->page = skb_frag_page(frag);
+		tag_pfrag->offset = 0;
+	}
+
+	tls_append_frag(record, tag_pfrag, ctx->tag_size);
+	record->end_seq = tp->write_seq + record->len;
+	spin_lock_irq(&offload_ctx->lock);
+	list_add_tail(&record->list, &offload_ctx->records_list);
+	spin_unlock_irq(&offload_ctx->lock);
+	offload_ctx->open_record = NULL;
+	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
+	tls_advance_record_sn(sk, ctx);
+
+	for (i = 0; i < record->num_frags; i++) {
+		frag = &record->frags[i];
+		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
+		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
+			    frag->size, frag->page_offset);
+		sk_mem_charge(sk, frag->size);
+		get_page(skb_frag_page(frag));
+	}
+	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
+
+	/* all ready, send */
+	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
+}
+
+static inline int tls_create_new_record(struct tls_offload_context *offload_ctx,
+					struct page_frag *pfrag,
+					size_t prepend_size)
+{
+	skb_frag_t *frag;
+	struct tls_record_info *record;
+
+	record = kmalloc(sizeof(*record), GFP_KERNEL);
+	if (!record)
+		return -ENOMEM;
+
+	frag = &record->frags[0];
+	__skb_frag_set_page(frag, pfrag->page);
+	frag->page_offset = pfrag->offset;
+	skb_frag_size_set(frag, prepend_size);
+
+	get_page(pfrag->page);
+	pfrag->offset += prepend_size;
+
+	record->num_frags = 1;
+	record->len = prepend_size;
+	offload_ctx->open_record = record;
+	return 0;
+}
+
+static inline int tls_do_allocation(struct sock *sk,
+				    struct tls_offload_context *offload_ctx,
+				    struct page_frag *pfrag,
+				    size_t prepend_size)
+{
+	int ret;
+
+	if (!offload_ctx->open_record) {
+		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
+						   sk->sk_allocation))) {
+			sk->sk_prot->enter_memory_pressure(sk);
+			sk_stream_moderate_sndbuf(sk);
+			return -ENOMEM;
+		}
+
+		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
+		if (ret)
+			return ret;
+
+		if (pfrag->size > pfrag->offset)
+			return 0;
+	}
+
+	if (!sk_page_frag_refill(sk, pfrag))
+		return -ENOMEM;
+
+	return 0;
+}
+
+static int tls_push_data(struct sock *sk,
+			 struct iov_iter *msg_iter,
+			 size_t size, int flags,
+			 unsigned char record_type)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+	struct tls_record_info *record = ctx->open_record;
+	struct page_frag *pfrag;
+	int copy, rc = 0;
+	size_t orig_size = size;
+	u32 max_open_record_len;
+	long timeo;
+	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
+	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
+	bool done = false;
+
+	if (flags &
+	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
+		return -ENOTSUPP;
+
+	if (sk->sk_err)
+		return -sk->sk_err;
+
+	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
+	if (rc < 0)
+		return rc;
+
+	pfrag = sk_page_frag(sk);
+
+	/* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and
+	 * we need to leave room for an authentication tag.
+	 */
+	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
+			      tls_ctx->prepend_size;
+	do {
+		if (tls_do_allocation(sk, ctx, pfrag,
+				      tls_ctx->prepend_size)) {
+			rc = sk_stream_wait_memory(sk, &timeo);
+			if (!rc)
+				continue;
+
+			record = ctx->open_record;
+			if (!record)
+				break;
+handle_error:
+			if (record_type != TLS_RECORD_TYPE_DATA) {
+				/* avoid sending partial
+				 * record with type !=
+				 * application_data
+				 */
+				size = orig_size;
+				destroy_record(record);
+				ctx->open_record = NULL;
+			} else if (record->len > tls_ctx->prepend_size) {
+				goto last_record;
+			}
+
+			break;
+		}
+
+		record = ctx->open_record;
+		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
+		copy = min_t(size_t, copy, (max_open_record_len - record->len));
+
+		if (copy_from_iter_nocache(page_address(pfrag->page) +
+					       pfrag->offset,
+					   copy, msg_iter) != copy) {
+			rc = -EFAULT;
+			goto handle_error;
+		}
+		tls_append_frag(record, pfrag, copy);
+
+		size -= copy;
+		if (!size) {
+last_record:
+			tls_push_record_flags = flags;
+			if (more) {
+				tls_ctx->pending_open_record_frags =
+						record->num_frags;
+				break;
+			}
+
+			done = true;
+		}
+
+		if ((done) || record->len >= max_open_record_len ||
+		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
+			rc = tls_push_record(sk,
+					     tls_ctx,
+					     ctx,
+					     record,
+					     pfrag,
+					     tls_push_record_flags,
+					     record_type);
+			if (rc < 0)
+				break;
+		}
+	} while (!done);
+
+	if (orig_size - size > 0)
+		rc = orig_size - size;
+
+	return rc;
+}
+
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	int rc = 0;
+
+	lock_sock(sk);
+
+	if (unlikely(msg->msg_controllen)) {
+		rc = tls_proccess_cmsg(sk, msg, &record_type);
+		if (rc)
+			goto out;
+	}
+
+	rc = tls_push_data(sk, &msg->msg_iter, size,
+			   msg->msg_flags, record_type);
+
+out:
+	release_sock(sk);
+	return rc;
+}
+
+int tls_device_sendpage(struct sock *sk, struct page *page,
+			int offset, size_t size, int flags)
+{
+	struct iov_iter	msg_iter;
+	struct kvec iov;
+	char *kaddr = kmap(page);
+	int rc = 0;
+
+	if (flags & MSG_SENDPAGE_NOTLAST)
+		flags |= MSG_MORE;
+
+	lock_sock(sk);
+
+	if (flags & MSG_OOB) {
+		rc = -ENOTSUPP;
+		goto out;
+	}
+
+	iov.iov_base = kaddr + offset;
+	iov.iov_len = size;
+	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
+	rc = tls_push_data(sk, &msg_iter, size,
+			   flags, TLS_RECORD_TYPE_DATA);
+	kunmap(page);
+
+out:
+	release_sock(sk);
+	return rc;
+}
+
+struct tls_record_info *tls_get_record(struct tls_offload_context *context,
+				       u32 seq, u64 *p_record_sn)
+{
+	struct tls_record_info *info;
+	u64 record_sn = context->hint_record_sn;
+
+	info = context->retransmit_hint;
+	if (!info ||
+	    before(seq, info->end_seq - info->len)) {
+		/* if retransmit_hint is irrelevant start
+		 * from the begging of the list
+		 */
+		info = list_first_entry(&context->records_list,
+					struct tls_record_info, list);
+		record_sn = context->unacked_record_sn;
+	}
+
+	list_for_each_entry_from(info, &context->records_list, list) {
+		if (before(seq, info->end_seq)) {
+			if (!context->retransmit_hint ||
+			    after(info->end_seq,
+				  context->retransmit_hint->end_seq)) {
+				context->hint_record_sn = record_sn;
+				context->retransmit_hint = info;
+			}
+			*p_record_sn = record_sn;
+			return info;
+		}
+		record_sn++;
+	}
+
+	return NULL;
+}
+EXPORT_SYMBOL(tls_get_record);
+
+static int tls_device_push_pending_record(struct sock *sk, int flags)
+{
+	struct iov_iter	msg_iter;
+
+	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
+	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
+}
+
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+{
+	u16 nonece_size, tag_size, iv_size, rec_seq_size;
+	struct tls_record_info *start_marker_record;
+	struct tls_offload_context *offload_ctx;
+	struct tls_crypto_info *crypto_info;
+	struct net_device *netdev;
+	char *iv, *rec_seq;
+	struct sk_buff *skb;
+	int rc = -EINVAL;
+	__be64 rcd_sn;
+
+	if (!ctx)
+		goto out;
+
+	if (ctx->priv_ctx) {
+		rc = -EEXIST;
+		goto out;
+	}
+
+	/* We support starting offload on multiple sockets
+	 * concurrently, So we only need a read lock here.
+	 */
+	percpu_down_read(&device_offload_lock);
+	netdev = get_netdev_for_sock(sk);
+	if (!netdev) {
+		pr_err_ratelimited("%s: netdev not found\n", __func__);
+		rc = -EINVAL;
+		goto release_lock;
+	}
+
+	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
+		rc = -ENOTSUPP;
+		goto release_netdev;
+	}
+
+	/* Avoid offloading if the device is down
+	 * We don't want to offload new flows after
+	 * the NETDEV_DOWN event
+	 */
+	if (!(netdev->flags & IFF_UP)) {
+		rc = -EINVAL;
+		goto release_lock;
+	}
+
+	crypto_info = &ctx->crypto_send;
+	switch (crypto_info->cipher_type) {
+	case TLS_CIPHER_AES_GCM_128: {
+		nonece_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
+		rec_seq =
+		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
+		break;
+	}
+	default:
+		rc = -EINVAL;
+		goto release_netdev;
+	}
+
+	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
+	if (!start_marker_record) {
+		rc = -ENOMEM;
+		goto release_netdev;
+	}
+
+	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
+	if (!offload_ctx)
+		goto free_marker_record;
+
+	ctx->priv_ctx = offload_ctx;
+	rc = attach_sock_to_netdev(sk, netdev, ctx);
+	if (rc)
+		goto free_offload_context;
+
+	ctx->netdev = netdev;
+	ctx->prepend_size = TLS_HEADER_SIZE + nonece_size;
+	ctx->tag_size = tag_size;
+	ctx->iv_size = iv_size;
+	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			  GFP_KERNEL);
+	if (!ctx->iv) {
+		rc = -ENOMEM;
+		goto detach_sock;
+	}
+
+	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+
+	ctx->rec_seq_size = rec_seq_size;
+	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+	if (!ctx->rec_seq) {
+		rc = -ENOMEM;
+		goto free_iv;
+	}
+	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
+
+	/* start at rec_seq - 1 to account for the start marker record */
+	memcpy(&rcd_sn, ctx->rec_seq, sizeof(rcd_sn));
+	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
+
+	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
+	if (rc)
+		goto free_rec_seq;
+
+	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
+	start_marker_record->len = 0;
+	start_marker_record->num_frags = 0;
+
+	INIT_LIST_HEAD(&offload_ctx->records_list);
+	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
+	spin_lock_init(&offload_ctx->lock);
+
+	inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked;
+	ctx->push_pending_record = tls_device_push_pending_record;
+	offload_ctx->sk_destruct = sk->sk_destruct;
+
+	/* TLS offload is greatly simplified if we don't send
+	 * SKBs where only part of the payload needs to be encrypted.
+	 * So mark the last skb in the write queue as end of record.
+	 */
+	skb = tcp_write_queue_tail(sk);
+	if (skb)
+		TCP_SKB_CB(skb)->eor = 1;
+
+	refcount_set(&ctx->refcount, 1);
+	spin_lock_irq(&tls_device_lock);
+	list_add_tail(&ctx->list, &tls_device_list);
+	spin_unlock_irq(&tls_device_lock);
+
+	/* following this assignment tls_is_sk_tx_device_offloaded
+	 * will return true and the context might be accessed
+	 * by the netdev's xmit function.
+	 */
+	smp_store_release(&sk->sk_destruct,
+			  &tls_device_sk_destruct);
+	goto release_lock;
+
+free_rec_seq:
+	kfree(ctx->rec_seq);
+free_iv:
+	kfree(ctx->iv);
+detach_sock:
+	netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
+free_offload_context:
+	kfree(offload_ctx);
+	ctx->priv_ctx = NULL;
+free_marker_record:
+	kfree(start_marker_record);
+release_netdev:
+	dev_put(netdev);
+release_lock:
+	percpu_up_read(&device_offload_lock);
+out:
+	return rc;
+}
+
+static int tls_device_register(struct net_device *dev)
+{
+	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
+		return NOTIFY_BAD;
+
+	return NOTIFY_DONE;
+}
+
+static int tls_device_unregister(struct net_device *dev)
+{
+	return NOTIFY_DONE;
+}
+
+static int tls_device_feat_change(struct net_device *dev)
+{
+	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
+		return NOTIFY_BAD;
+
+	return NOTIFY_DONE;
+}
+
+static int tls_device_down(struct net_device *netdev)
+{
+	struct tls_context *ctx, *tmp;
+	struct list_head list;
+	unsigned long flags;
+
+	if (!(netdev->features & NETIF_F_HW_TLS_TX))
+		return NOTIFY_DONE;
+
+	/* Request a write lock to block new offload attempts
+	 */
+	percpu_down_write(&device_offload_lock);
+
+	spin_lock_irqsave(&tls_device_lock, flags);
+	INIT_LIST_HEAD(&list);
+
+	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
+		if (ctx->netdev != netdev ||
+		    !refcount_inc_not_zero(&ctx->refcount))
+			continue;
+
+		list_move(&ctx->list, &list);
+	}
+	spin_unlock_irqrestore(&tls_device_lock, flags);
+
+	list_for_each_entry_safe(ctx, tmp, &list, list)	{
+		netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+						TLS_OFFLOAD_CTX_DIR_TX);
+		ctx->netdev = NULL;
+		dev_put(netdev);
+		list_del_init(&ctx->list);
+
+		if (refcount_dec_and_test(&ctx->refcount))
+			tls_device_free_ctx(ctx);
+	}
+
+	percpu_up_write(&device_offload_lock);
+
+	flush_work(&tls_device_gc_work);
+
+	return NOTIFY_DONE;
+}
+
+static int tls_dev_event(struct notifier_block *this, unsigned long event,
+			 void *ptr)
+{
+	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
+
+	switch (event) {
+	case NETDEV_REGISTER:
+		return tls_device_register(dev);
+
+	case NETDEV_UNREGISTER:
+		return tls_device_unregister(dev);
+
+	case NETDEV_FEAT_CHANGE:
+		return tls_device_feat_change(dev);
+
+	case NETDEV_DOWN:
+		return tls_device_down(dev);
+	}
+	return NOTIFY_DONE;
+}
+
+static struct notifier_block tls_dev_notifier = {
+	.notifier_call	= tls_dev_event,
+};
+
+void __init tls_device_init(void)
+{
+	register_netdevice_notifier(&tls_dev_notifier);
+}
+
+void __exit tls_device_cleanup(void)
+{
+	unregister_netdevice_notifier(&tls_dev_notifier);
+	flush_work(&tls_device_gc_work);
+}
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
new file mode 100644
index 000000000000..14d31a36885c
--- /dev/null
+++ b/net/tls/tls_device_fallback.c
@@ -0,0 +1,419 @@ 
+/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
+ *
+ *     Redistribution and use in source and binary forms, with or
+ *     without modification, are permitted provided that the following
+ *     conditions are met:
+ *
+ *      - Redistributions of source code must retain the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer.
+ *
+ *      - Redistributions in binary form must reproduce the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer in the documentation and/or other materials
+ *        provided with the distribution.
+ *
+ *      - Neither the name of the Mellanox Technologies nor the
+ *        names of its contributors may be used to endorse or promote
+ *        products derived from this software without specific prior written
+ *        permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
+ * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
+ * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE
+ */
+
+#include <net/tls.h>
+#include <crypto/aead.h>
+#include <crypto/scatterwalk.h>
+#include <net/ip6_checksum.h>
+
+static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
+{
+	struct scatterlist *src = walk->sg;
+	int diff = walk->offset - src->offset;
+
+	sg_set_page(sg, sg_page(src),
+		    src->length - diff, walk->offset);
+
+	scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
+}
+
+static int tls_enc_record(struct aead_request *aead_req,
+			  struct crypto_aead *aead, char *aad, char *iv,
+			  __be64 rcd_sn, struct scatter_walk *in,
+			  struct scatter_walk *out, int *in_len)
+{
+	struct scatterlist sg_in[3];
+	struct scatterlist sg_out[3];
+	unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
+	u16 len;
+	int rc;
+
+	len = min_t(int, *in_len, ARRAY_SIZE(buf));
+
+	scatterwalk_copychunks(buf, in, len, 0);
+	scatterwalk_copychunks(buf, out, len, 1);
+
+	*in_len -= len;
+	if (!*in_len)
+		return 0;
+
+	scatterwalk_pagedone(in, 0, 1);
+	scatterwalk_pagedone(out, 1, 1);
+
+	len = buf[4] | (buf[3] << 8);
+	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
+
+	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
+		     (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
+
+	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
+	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
+	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
+	sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
+	sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
+	chain_to_walk(sg_in + 1, in);
+	chain_to_walk(sg_out + 1, out);
+
+	*in_len -= len;
+	if (*in_len < 0) {
+		*in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+		if (*in_len < 0)
+		/* the input buffer doesn't contain the entire record.
+		 * trim len accordingly. The resulting authentication tag
+		 * will contain garbage. but we don't care as we won't
+		 * include any of it in the output skb
+		 * Note that we assume the output buffer length
+		 * is larger then input buffer length + tag size
+		 */
+			len += *in_len;
+
+		*in_len = 0;
+	}
+
+	if (*in_len) {
+		scatterwalk_copychunks(NULL, in, len, 2);
+		scatterwalk_pagedone(in, 0, 1);
+		scatterwalk_copychunks(NULL, out, len, 2);
+		scatterwalk_pagedone(out, 1, 1);
+	}
+
+	len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
+
+	rc = crypto_aead_encrypt(aead_req);
+
+	return rc;
+}
+
+static void tls_init_aead_request(struct aead_request *aead_req,
+				  struct crypto_aead *aead)
+{
+	aead_request_set_tfm(aead_req, aead);
+	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+}
+
+static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
+						   gfp_t flags)
+{
+	unsigned int req_size = sizeof(struct aead_request) +
+		crypto_aead_reqsize(aead);
+	struct aead_request *aead_req;
+
+	aead_req = kzalloc(req_size, flags);
+	if (!aead_req)
+		return NULL;
+
+	tls_init_aead_request(aead_req, aead);
+	return aead_req;
+}
+
+static int tls_enc_records(struct aead_request *aead_req,
+			   struct crypto_aead *aead, struct scatterlist *sg_in,
+			   struct scatterlist *sg_out, char *aad, char *iv,
+			   u64 rcd_sn, int len)
+{
+	struct scatter_walk in;
+	struct scatter_walk out;
+	int rc;
+
+	scatterwalk_start(&in, sg_in);
+	scatterwalk_start(&out, sg_out);
+
+	do {
+		rc = tls_enc_record(aead_req, aead, aad, iv,
+				    cpu_to_be64(rcd_sn), &in, &out, &len);
+		rcd_sn++;
+
+	} while (rc == 0 && len);
+
+	scatterwalk_done(&in, 0, 0);
+	scatterwalk_done(&out, 1, 0);
+
+	return rc;
+}
+
+static inline void update_chksum(struct sk_buff *skb, int headln)
+{
+	/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
+	 * might have been changed by NAT.
+	 */
+
+	const struct ipv6hdr *ipv6h;
+	const struct iphdr *iph;
+	struct tcphdr *th = tcp_hdr(skb);
+	int datalen = skb->len - headln;
+
+	/* We only changed the payload so if we are using partial we don't
+	 * need to update anything.
+	 */
+	if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
+		return;
+
+	skb->ip_summed = CHECKSUM_PARTIAL;
+	skb->csum_start = skb_transport_header(skb) - skb->head;
+	skb->csum_offset = offsetof(struct tcphdr, check);
+
+	if (skb->sk->sk_family == AF_INET6) {
+		ipv6h = ipv6_hdr(skb);
+		th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
+					     datalen, IPPROTO_TCP, 0);
+	} else {
+		iph = ip_hdr(skb);
+		th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
+					       IPPROTO_TCP, 0);
+	}
+}
+
+static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
+{
+	skb_copy_header(nskb, skb);
+
+	skb_put(nskb, skb->len);
+	memcpy(nskb->data, skb->data, headln);
+	update_chksum(nskb, headln);
+
+	nskb->destructor = skb->destructor;
+	nskb->sk = skb->sk;
+	skb->destructor = NULL;
+	skb->sk = NULL;
+	refcount_add(nskb->truesize - skb->truesize,
+		     &nskb->sk->sk_wmem_alloc);
+}
+
+/* This function may be called after the user socket is already
+ * closed so make sure we don't use anything freed during
+ * tls_sk_proto_close here
+ */
+static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
+{
+	int tcp_header_size = tcp_hdrlen(skb);
+	int tcp_payload_offset = skb_transport_offset(skb) + tcp_header_size;
+	int payload_len = skb->len - tcp_payload_offset;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+	int remaining, buf_len, resync_sgs, rc, i = 0;
+	void *buf, *dummy_buf, *iv, *aad;
+	struct scatterlist *sg_in;
+	struct scatterlist sg_out[3];
+	u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
+	struct aead_request *aead_req;
+	struct sk_buff *nskb = NULL;
+	struct tls_record_info *record;
+	unsigned long flags;
+	s32 sync_size;
+	u64 rcd_sn;
+
+	/* worst case is:
+	 * MAX_SKB_FRAGS in tls_record_info
+	 * MAX_SKB_FRAGS + 1 in SKB head an frags.
+	 */
+	int sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;
+
+	if (!payload_len)
+		return skb;
+
+	sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
+	if (!sg_in)
+		goto free_orig;
+
+	sg_init_table(sg_in, sg_in_max_elements);
+	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
+
+	spin_lock_irqsave(&ctx->lock, flags);
+	record = tls_get_record(ctx, tcp_seq, &rcd_sn);
+	if (!record) {
+		spin_unlock_irqrestore(&ctx->lock, flags);
+		WARN(1, "Record not found for seq %u\n", tcp_seq);
+		goto free_sg;
+	}
+
+	sync_size = tcp_seq - tls_record_start_seq(record);
+	if (sync_size < 0) {
+		int is_start_marker = tls_record_is_start_marker(record);
+
+		spin_unlock_irqrestore(&ctx->lock, flags);
+		if (!is_start_marker)
+		/* This should only occur if the relevant record was
+		 * already acked. In that case it should be ok
+		 * to drop the packet and avoid retransmission.
+		 *
+		 * There is a corner case where the packet contains
+		 * both an acked and a non-acked record.
+		 * We currently don't handle that case and rely
+		 * on TCP to retranmit a packet that doesn't contain
+		 * already acked payload.
+		 */
+			goto free_orig;
+
+		if (payload_len > -sync_size) {
+			WARN(1, "Fallback of partially offloaded packets is not supported\n");
+			goto free_sg;
+		} else {
+			return skb;
+		}
+	}
+
+	remaining = sync_size;
+	while (remaining > 0) {
+		skb_frag_t *frag = &record->frags[i];
+
+		__skb_frag_ref(frag);
+		sg_set_page(sg_in + i, skb_frag_page(frag),
+			    skb_frag_size(frag), frag->page_offset);
+
+		remaining -= skb_frag_size(frag);
+
+		if (remaining < 0)
+			sg_in[i].length += remaining;
+
+		i++;
+	}
+	spin_unlock_irqrestore(&ctx->lock, flags);
+	resync_sgs = i;
+
+	aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
+	if (!aead_req)
+		goto put_sg;
+
+	buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
+		  TLS_CIPHER_AES_GCM_128_IV_SIZE +
+		  TLS_AAD_SPACE_SIZE +
+		  sync_size +
+		  tls_ctx->tag_size;
+	buf = kmalloc(buf_len, GFP_ATOMIC);
+	if (!buf)
+		goto free_req;
+
+	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
+	if (!nskb)
+		goto free_buf;
+
+	skb_reserve(nskb, skb_headroom(skb));
+
+	iv = buf;
+
+	memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
+	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
+	      TLS_CIPHER_AES_GCM_128_IV_SIZE;
+	dummy_buf = aad + TLS_AAD_SPACE_SIZE;
+
+	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
+	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset,
+		   payload_len);
+	/* Add room for authentication tag produced by crypto */
+	dummy_buf += sync_size;
+	sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size);
+	rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset,
+			  payload_len);
+	if (rc < 0)
+		goto free_nskb;
+
+	rc = tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
+			     rcd_sn, sync_size + payload_len);
+	if (rc < 0)
+		goto free_nskb;
+
+	complete_skb(nskb, skb, tcp_payload_offset);
+
+	/* validate_xmit_skb_list assumes that if the skb wasn't segmented
+	 * nskb->prev will point to the skb itself
+	 */
+	nskb->prev = nskb;
+free_buf:
+	kfree(buf);
+free_req:
+	kfree(aead_req);
+put_sg:
+	for (i = 0; i < resync_sgs; i++)
+		put_page(sg_page(&sg_in[i]));
+free_sg:
+	kfree(sg_in);
+free_orig:
+	kfree_skb(skb);
+	return nskb;
+
+free_nskb:
+	kfree_skb(nskb);
+	nskb = NULL;
+	goto free_buf;
+}
+
+static struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
+					     struct net_device *dev,
+					     struct sk_buff *skb)
+{
+	if (dev == tls_get_ctx(sk)->netdev)
+		return skb;
+
+	return tls_sw_fallback(sk, skb);
+}
+
+int tls_sw_fallback_init(struct sock *sk,
+			 struct tls_offload_context *offload_ctx,
+			 struct tls_crypto_info *crypto_info)
+{
+	int rc;
+	const u8 *key;
+
+	offload_ctx->aead_send =
+	    crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
+	if (IS_ERR(offload_ctx->aead_send)) {
+		rc = PTR_ERR(offload_ctx->aead_send);
+		pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
+		offload_ctx->aead_send = NULL;
+		goto err_out;
+	}
+
+	key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+
+	rc = crypto_aead_setkey(offload_ctx->aead_send, key,
+				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	if (rc)
+		goto free_aead;
+
+	rc = crypto_aead_setauthsize(offload_ctx->aead_send,
+				     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+	if (rc)
+		goto free_aead;
+
+	sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
+	return 0;
+free_aead:
+	crypto_free_aead(offload_ctx->aead_send);
+err_out:
+	return rc;
+}
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d824d548447e..e0dface33017 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -54,6 +54,9 @@  enum {
 enum {
 	TLS_BASE_TX,
 	TLS_SW_TX,
+#ifdef CONFIG_TLS_DEVICE
+	TLS_HW_TX,
+#endif
 	TLS_NUM_CONFIG,
 };
 
@@ -416,11 +419,19 @@  static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
 		goto err_crypto_info;
 	}
 
-	/* currently SW is default, we will have ethtool in future */
-	rc = tls_set_sw_offload(sk, ctx);
-	tx_conf = TLS_SW_TX;
-	if (rc)
-		goto err_crypto_info;
+#ifdef CONFIG_TLS_DEVICE
+	rc = tls_set_device_offload(sk, ctx);
+	tx_conf = TLS_HW_TX;
+	if (rc) {
+#else
+	{
+#endif
+		/* if HW offload fails fallback to SW */
+		rc = tls_set_sw_offload(sk, ctx);
+		tx_conf = TLS_SW_TX;
+		if (rc)
+			goto err_crypto_info;
+	}
 
 	ctx->tx_conf = tx_conf;
 	update_sk_prot(sk, ctx);
@@ -473,6 +484,12 @@  static void build_protos(struct proto *prot, struct proto *base)
 	prot[TLS_SW_TX] = prot[TLS_BASE_TX];
 	prot[TLS_SW_TX].sendmsg		= tls_sw_sendmsg;
 	prot[TLS_SW_TX].sendpage	= tls_sw_sendpage;
+
+#ifdef CONFIG_TLS_DEVICE
+	prot[TLS_HW_TX] = prot[TLS_SW_TX];
+	prot[TLS_HW_TX].sendmsg		= tls_device_sendmsg;
+	prot[TLS_HW_TX].sendpage	= tls_device_sendpage;
+#endif
 }
 
 static int tls_init(struct sock *sk)
@@ -531,6 +548,9 @@  static int __init tls_register(void)
 {
 	build_protos(tls_prots[TLSV4], &tcp_prot);
 
+#ifdef CONFIG_TLS_DEVICE
+	tls_device_init();
+#endif
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 
 	return 0;
@@ -539,6 +559,9 @@  static int __init tls_register(void)
 static void __exit tls_unregister(void)
 {
 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
+#ifdef CONFIG_TLS_DEVICE
+	tls_device_cleanup();
+#endif
 }
 
 module_init(tls_register);