@@ -71,6 +71,29 @@ struct tls_sw_context {
struct scatterlist sg_aead_out[2];
};
+struct tls_record_info {
+ struct list_head list;
+ u32 end_seq;
+ int len;
+ int num_frags;
+ skb_frag_t frags[MAX_SKB_FRAGS];
+};
+
+struct tls_offload_context {
+ struct crypto_aead *aead_send;
+
+ struct list_head records_list;
+ struct scatterlist sg_tx_data[MAX_SKB_FRAGS];
+ void (*sk_destruct)(struct sock *sk);
+ struct tls_record_info *open_record;
+ struct tls_record_info *retransmit_hint;
+ u64 hint_record_sn;
+ u64 unacked_record_sn;
+
+ u32 expected_seq;
+ spinlock_t lock; /* protects records list */
+};
+
enum {
TLS_PENDING_CLOSED_RECORD
};
@@ -81,6 +104,9 @@ struct tls_context {
struct tls12_crypto_info_aes_gcm_128 crypto_send_aes_gcm_128;
};
+ struct list_head gclist;
+ struct sock *sk;
+ struct net_device *netdev;
void *priv_ctx;
u8 tx_conf:2;
@@ -125,9 +151,23 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
void tls_sw_close(struct sock *sk, long timeout);
void tls_sw_free_tx_resources(struct sock *sk);
-void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
-void tls_icsk_clean_acked(struct sock *sk);
+void tls_clear_device_offload(struct sock *sk, struct tls_context *ctx);
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_device_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+void tls_device_sk_destruct(struct sock *sk);
+void tls_device_cleanup(void);
+
+struct tls_record_info *tls_get_record(struct tls_offload_context *context,
+ u32 seq, u64 *p_record_sn);
+static inline bool tls_record_is_start_marker(struct tls_record_info *rec)
+{
+ return rec->len == 0;
+}
+
+void tls_sk_destruct(struct sock *sk, struct tls_context *ctx);
int tls_push_sg(struct sock *sk, struct tls_context *ctx,
struct scatterlist *sg, u16 first_offset,
int flags);
@@ -164,6 +204,13 @@ static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
return tls_ctx->pending_open_record_frags;
}
+static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
+{
+ /* matches smp_store_release in tls_set_device_offload */
+ return smp_load_acquire(&sk->sk_destruct) ==
+ &tls_device_sk_destruct;
+}
+
static inline void tls_err_abort(struct sock *sk)
{
sk->sk_err = -EBADMSG;
@@ -251,4 +298,8 @@ static inline struct tls_offload_context *tls_offload_ctx(
int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type);
+int tls_sw_fallback_init(struct sock *sk,
+ struct tls_offload_context *offload_ctx,
+ struct tls_crypto_info *crypto_info);
+
#endif /* _TLS_OFFLOAD_H */
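
Driver-facing usage: tls_is_sk_tx_device_offloaded() and tls_get_record() are the
interface a NIC driver consumes on its transmit path. Below is a minimal sketch;
foo_xmit() and the HW-programming step are hypothetical, and only the tls_*
helpers come from this patch:

static netdev_tx_t foo_xmit(struct sk_buff *skb, struct net_device *dev)
{
	/* sketch only; assumes <net/tls.h> and <net/tcp.h> */
	struct sock *sk = skb->sk;

	if (sk && tls_is_sk_tx_device_offloaded(sk)) {
		struct tls_offload_context *ctx =
			tls_offload_ctx(tls_get_ctx(sk));
		struct tls_record_info *record;
		unsigned long flags;
		u64 record_sn;

		spin_lock_irqsave(&ctx->lock, flags);
		record = tls_get_record(ctx, ntohl(tcp_hdr(skb)->seq),
					&record_sn);
		if (record) {
			/* program record_sn and record->frags into the
			 * HW crypto context for this segment
			 */
		}
		spin_unlock_irqrestore(&ctx->lock, flags);
	}

	/* normal descriptor setup and doorbell */
	return NETDEV_TX_OK;
}

Since tls_get_record() caches retransmit_hint, in-order transmission resolves in
O(1); only out-of-order retransmits walk the record list.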
@@ -13,3 +13,12 @@ config TLS
encryption handling of the TLS protocol to be done in-kernel.
If unsure, say N.
+
+config TLS_DEVICE
+ bool "Transport Layer Security HW offload"
+ depends on TLS
+ default n
+ ---help---
+ Enable kernel support for HW offload of the TLS protocol.
+
+ If unsure, say N.
@@ -5,3 +5,6 @@
obj-$(CONFIG_TLS) += tls.o
tls-y := tls_main.o tls_sw.o
+
+tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o
new file mode 100644
@@ -0,0 +1,692 @@
+/* Copyright (c) 2016-2017, Mellanox Technologies All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or
+ * without modification, are permitted provided that the following
+ * conditions are met:
+ *
+ * - Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * - Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * - Neither the name of the Mellanox Technologies nor the
+ * names of its contributors may be used to endorse or promote
+ * products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
+ * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
+ * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE
+ */
+
+#include <linux/module.h>
+#include <net/tcp.h>
+#include <net/inet_common.h>
+#include <linux/highmem.h>
+#include <linux/netdevice.h>
+
+#include <net/tls.h>
+#include <crypto/aead.h>
+
+static void tls_device_gc_task(struct work_struct *work);
+
+static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
+static LIST_HEAD(tls_device_gc_list);
+static DEFINE_SPINLOCK(tls_device_gc_lock);
+
+static void tls_device_gc_task(struct work_struct *work)
+{
+ struct tls_context *ctx, *tmp;
+ struct list_head gc_list;
+ unsigned long flags;
+
+ spin_lock_irqsave(&tls_device_gc_lock, flags);
+ INIT_LIST_HEAD(&gc_list);
+ list_splice_init(&tls_device_gc_list, &gc_list);
+ spin_unlock_irqrestore(&tls_device_gc_lock, flags);
+
+ list_for_each_entry_safe(ctx, tmp, &gc_list, gclist) {
+ struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx);
+ void (*sk_destruct)(struct sock *sk) = offload_ctx->sk_destruct;
+ struct net_device *netdev = ctx->netdev;
+ struct sock *sk = ctx->sk;
+
+ netdev->tlsdev_ops->tls_dev_del(netdev, sk,
+ TLS_OFFLOAD_CTX_DIR_TX);
+
+ list_del(&ctx->gclist);
+ kfree(offload_ctx);
+ kfree(ctx);
+ sk_destruct(sk);
+ }
+}
+
+static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
+{
+ unsigned long flags;
+
+ spin_lock_irqsave(&tls_device_gc_lock, flags);
+ list_add_tail(&ctx->gclist, &tls_device_gc_list);
+ spin_unlock_irqrestore(&tls_device_gc_lock, flags);
+
+ schedule_work(&tls_device_gc_work);
+}
+
+/* We assume that the socket is already connected */
+static struct net_device *get_netdev_for_sock(struct sock *sk)
+{
+ struct inet_sock *inet = inet_sk(sk);
+ struct net_device *netdev = NULL;
+
+ netdev = dev_get_by_index(sock_net(sk), inet->cork.fl.flowi_oif);
+
+ return netdev;
+}
+
+static void detach_sock_from_netdev(struct sock *sk, struct tls_context *ctx)
+{
+ struct net_device *netdev;
+
+ netdev = get_netdev_for_sock(sk);
+ if (!netdev) {
+ pr_err("got offloaded socket with no netdev\n");
+ return;
+ }
+
+ if (!netdev->tlsdev_ops) {
+ pr_err("%s: netdev %s with no TLS offload\n", __func__,
+ netdev->name);
+ dev_put(netdev);
+ return;
+ }
+
+ netdev->tlsdev_ops->tls_dev_del(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX);
+ dev_put(netdev);
+}
+
+static int attach_sock_to_netdev(struct sock *sk, struct net_device *netdev,
+ struct tls_context *ctx)
+{
+ int rc;
+
+ rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk,
+ TLS_OFFLOAD_CTX_DIR_TX,
+ &ctx->crypto_send);
+ if (rc)
+ pr_err("The netdev has refused to offload this socket\n");
+
+ return rc;
+}
+
+static void destroy_record(struct tls_record_info *record)
+{
+ skb_frag_t *frag;
+ int nr_frags = record->num_frags;
+
+ while (nr_frags > 0) {
+ frag = &record->frags[nr_frags - 1];
+ __skb_frag_unref(frag);
+ --nr_frags;
+ }
+ kfree(record);
+}
+
+static void delete_all_records(struct tls_offload_context *offload_ctx)
+{
+ struct tls_record_info *info, *temp;
+
+ list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
+ list_del(&info->list);
+ destroy_record(info);
+ }
+
+ offload_ctx->retransmit_hint = NULL;
+}
+
+static void tls_icsk_clean_acked(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_offload_context *ctx;
+ struct tcp_sock *tp = tcp_sk(sk);
+ struct tls_record_info *info, *temp;
+ unsigned long flags;
+ u64 deleted_records = 0;
+
+ if (!tls_ctx)
+ return;
+
+ ctx = tls_offload_ctx(tls_ctx);
+
+ spin_lock_irqsave(&ctx->lock, flags);
+ info = ctx->retransmit_hint;
+ if (info && !before(tp->snd_una, info->end_seq)) {
+ ctx->retransmit_hint = NULL;
+ list_del(&info->list);
+ destroy_record(info);
+ deleted_records++;
+ }
+
+ list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
+ if (before(tp->snd_una, info->end_seq))
+ break;
+ list_del(&info->list);
+
+ destroy_record(info);
+ deleted_records++;
+ }
+
+ ctx->unacked_record_sn += deleted_records;
+ spin_unlock_irqrestore(&ctx->lock, flags);
+}
+
+/* At this point, there should be no references on this
+ * socket and no in-flight SKBs associated with this
+ * socket, so it is safe to free all the resources.
+ */
+void tls_device_sk_destruct(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+
+ if (ctx->open_record)
+ destroy_record(ctx->open_record);
+
+ delete_all_records(ctx);
+ crypto_free_aead(ctx->aead_send);
+
+ tls_device_queue_ctx_destruction(tls_ctx);
+}
+EXPORT_SYMBOL(tls_device_sk_destruct);
+
+static inline void tls_append_frag(struct tls_record_info *record,
+ struct page_frag *pfrag,
+ int size)
+{
+ skb_frag_t *frag;
+
+ frag = &record->frags[record->num_frags - 1];
+ if (frag->page.p == pfrag->page &&
+ frag->page_offset + frag->size == pfrag->offset) {
+ frag->size += size;
+ } else {
+ ++frag;
+ frag->page.p = pfrag->page;
+ frag->page_offset = pfrag->offset;
+ frag->size = size;
+ ++record->num_frags;
+ get_page(pfrag->page);
+ }
+
+ pfrag->offset += size;
+ record->len += size;
+}
+
+static inline int tls_push_record(struct sock *sk,
+ struct tls_context *ctx,
+ struct tls_offload_context *offload_ctx,
+ struct tls_record_info *record,
+ struct page_frag *pfrag,
+ int flags,
+ unsigned char record_type)
+{
+ skb_frag_t *frag;
+ struct tcp_sock *tp = tcp_sk(sk);
+ struct page_frag fallback_frag;
+ struct page_frag *tag_pfrag = pfrag;
+ int i;
+
+ /* fill prepend */
+ frag = &record->frags[0];
+ tls_fill_prepend(ctx,
+ skb_frag_address(frag),
+ record->len - ctx->prepend_size,
+ record_type);
+
+ if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag,
+ GFP_KERNEL))) {
+ /* HW doesn't care about the data in the tag,
+ * so if pfrag has no room for a tag and we
+ * can't allocate a new pfrag, just reuse the
+ * page of the first frag rather than writing
+ * complicated fallback code.
+ */
+ tag_pfrag = &fallback_frag;
+ tag_pfrag->page = skb_frag_page(frag);
+ tag_pfrag->offset = 0;
+ }
+
+ tls_append_frag(record, tag_pfrag, ctx->tag_size);
+ record->end_seq = tp->write_seq + record->len;
+ spin_lock_irq(&offload_ctx->lock);
+ list_add_tail(&record->list, &offload_ctx->records_list);
+ spin_unlock_irq(&offload_ctx->lock);
+ offload_ctx->open_record = NULL;
+ set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
+ tls_advance_record_sn(sk, ctx);
+
+ for (i = 0; i < record->num_frags; i++) {
+ frag = &record->frags[i];
+ sg_unmark_end(&offload_ctx->sg_tx_data[i]);
+ sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
+ frag->size, frag->page_offset);
+ sk_mem_charge(sk, frag->size);
+ get_page(skb_frag_page(frag));
+ }
+ sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
+
+ /* all ready, send */
+ return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
+}
+
+static inline int tls_create_new_record(struct tls_offload_context *offload_ctx,
+ struct page_frag *pfrag,
+ size_t prepend_size)
+{
+ skb_frag_t *frag;
+ struct tls_record_info *record;
+
+ record = kmalloc(sizeof(*record), GFP_KERNEL);
+ if (!record)
+ return -ENOMEM;
+
+ frag = &record->frags[0];
+ __skb_frag_set_page(frag, pfrag->page);
+ frag->page_offset = pfrag->offset;
+ skb_frag_size_set(frag, prepend_size);
+
+ get_page(pfrag->page);
+ pfrag->offset += prepend_size;
+
+ record->num_frags = 1;
+ record->len = prepend_size;
+ offload_ctx->open_record = record;
+ return 0;
+}
+
+static inline int tls_do_allocation(struct sock *sk,
+ struct tls_offload_context *offload_ctx,
+ struct page_frag *pfrag,
+ size_t prepend_size)
+{
+ int ret;
+
+ if (!offload_ctx->open_record) {
+ if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
+ sk->sk_allocation))) {
+ sk->sk_prot->enter_memory_pressure(sk);
+ sk_stream_moderate_sndbuf(sk);
+ return -ENOMEM;
+ }
+
+ ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
+ if (ret)
+ return ret;
+
+ if (pfrag->size > pfrag->offset)
+ return 0;
+ }
+
+ if (!sk_page_frag_refill(sk, pfrag))
+ return -ENOMEM;
+
+ return 0;
+}
+
+static int tls_push_data(struct sock *sk,
+ struct iov_iter *msg_iter,
+ size_t size, int flags,
+ unsigned char record_type)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+ struct tls_record_info *record = ctx->open_record;
+ struct page_frag *pfrag;
+ int copy, rc = 0;
+ size_t orig_size = size;
+ u32 max_open_record_len;
+ long timeo;
+ int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
+ int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
+ bool done = false;
+
+ if (sk->sk_err)
+ return -sk->sk_err;
+
+ timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+ rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
+ if (rc < 0)
+ return rc;
+
+ pfrag = sk_page_frag(sk);
+
+ /* The TLS header and nonce (prepend_size bytes) are not counted
+ * against TLS_MAX_PAYLOAD_SIZE; room for the authentication tag
+ * is added when the record is pushed.
+ */
+ max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
+ tls_ctx->prepend_size;
+ do {
+ if (tls_do_allocation(sk, ctx, pfrag,
+ tls_ctx->prepend_size)) {
+ rc = sk_stream_wait_memory(sk, &timeo);
+ if (!rc)
+ continue;
+
+ record = ctx->open_record;
+ if (!record)
+ break;
+handle_error:
+ if (record_type != TLS_RECORD_TYPE_DATA) {
+ /* avoid sending a partial record
+ * with type != application_data
+ */
+ size = orig_size;
+ destroy_record(record);
+ ctx->open_record = NULL;
+ } else if (record->len > tls_ctx->prepend_size) {
+ goto last_record;
+ }
+
+ break;
+ }
+
+ record = ctx->open_record;
+ copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
+ copy = min_t(size_t, copy, (max_open_record_len - record->len));
+
+ if (copy_from_iter_nocache(page_address(pfrag->page) + pfrag->offset,
+ copy, msg_iter) != copy) {
+ rc = -EFAULT;
+ goto handle_error;
+ }
+ tls_append_frag(record, pfrag, copy);
+
+ size -= copy;
+ if (!size) {
+last_record:
+ tls_push_record_flags = flags;
+ if (more) {
+ tls_ctx->pending_open_record_frags =
+ record->num_frags;
+ break;
+ }
+
+ done = true;
+ }
+
+ if (done || record->len >= max_open_record_len ||
+ record->num_frags >= MAX_SKB_FRAGS - 1) {
+ rc = tls_push_record(sk, tls_ctx, ctx, record,
+ pfrag, tls_push_record_flags,
+ record_type);
+ if (rc < 0)
+ break;
+ }
+ } while (!done);
+
+ if (orig_size - size > 0)
+ rc = orig_size - size;
+
+ return rc;
+}
+
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+ unsigned char record_type = TLS_RECORD_TYPE_DATA;
+ int rc = 0;
+
+ lock_sock(sk);
+
+ if (unlikely(msg->msg_controllen)) {
+ rc = tls_proccess_cmsg(sk, msg, &record_type);
+ if (rc)
+ goto out;
+ }
+
+ rc = tls_push_data(sk, &msg->msg_iter, size,
+ msg->msg_flags, record_type);
+
+out:
+ release_sock(sk);
+ return rc;
+}
+
+int tls_device_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags)
+{
+ struct iov_iter msg_iter;
+ struct kvec iov;
+ char *kaddr;
+ int rc = 0;
+
+ if (flags & MSG_SENDPAGE_NOTLAST)
+ flags |= MSG_MORE;
+
+ lock_sock(sk);
+
+ if (flags & MSG_OOB) {
+ rc = -ENOTSUPP;
+ goto out;
+ }
+
+ kaddr = kmap(page);
+ iov.iov_base = kaddr + offset;
+ iov.iov_len = size;
+ iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
+ rc = tls_push_data(sk, &msg_iter, size,
+ flags, TLS_RECORD_TYPE_DATA);
+ kunmap(page);
+
+out:
+ release_sock(sk);
+ return rc;
+}
+
+struct tls_record_info *tls_get_record(struct tls_offload_context *context,
+ u32 seq, u64 *p_record_sn)
+{
+ struct tls_record_info *info;
+ u64 record_sn = context->hint_record_sn;
+
+ info = context->retransmit_hint;
+ if (!info ||
+ before(seq, info->end_seq - info->len)) {
+ /* if retransmit_hint is irrelevant, start
+ * from the beginning of the list
+ */
+ info = list_first_entry(&context->records_list,
+ struct tls_record_info, list);
+ record_sn = context->unacked_record_sn;
+ }
+
+ list_for_each_entry_from(info, &context->records_list, list) {
+ if (before(seq, info->end_seq)) {
+ if (!context->retransmit_hint ||
+ after(info->end_seq,
+ context->retransmit_hint->end_seq)) {
+ context->hint_record_sn = record_sn;
+ context->retransmit_hint = info;
+ }
+ *p_record_sn = record_sn;
+ return info;
+ }
+ record_sn++;
+ }
+
+ return NULL;
+}
+EXPORT_SYMBOL(tls_get_record);
+
+static int tls_device_push_pending_record(struct sock *sk, int flags)
+{
+ struct iov_iter msg_iter;
+
+ iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
+ return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
+}
+
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+{
+ struct tls_crypto_info *crypto_info;
+ struct tls_offload_context *offload_ctx;
+ struct tls_record_info *start_marker_record;
+ u16 nonce_size, tag_size, iv_size, rec_seq_size;
+ char *iv, *rec_seq;
+ int rc;
+ struct net_device *netdev;
+ struct sk_buff *skb;
+
+ if (!ctx) {
+ rc = -EINVAL;
+ goto out;
+ }
+
+ if (ctx->priv_ctx) {
+ rc = -EEXIST;
+ goto out;
+ }
+
+ netdev = get_netdev_for_sock(sk);
+ if (!netdev) {
+ pr_err("%s: netdev not found\n", __func__);
+ rc = -EINVAL;
+ goto out;
+ }
+
+ if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
+ rc = -ENOTSUPP;
+ goto release_netdev;
+ }
+
+ crypto_info = &ctx->crypto_send;
+ switch (crypto_info->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128: {
+ nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
+ rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
+ rec_seq =
+ ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
+ break;
+ }
+ default:
+ rc = -EINVAL;
+ goto release_netdev;
+ }
+
+ start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
+ if (!start_marker_record) {
+ rc = -ENOMEM;
+ goto release_netdev;
+ }
+
+ rc = attach_sock_to_netdev(sk, netdev, ctx);
+ if (rc)
+ goto free_marker_record;
+
+ ctx->netdev = netdev;
+ ctx->sk = sk;
+
+ ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
+ ctx->tag_size = tag_size;
+ ctx->iv_size = iv_size;
+ ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+ GFP_KERNEL);
+ if (!ctx->iv) {
+ rc = -ENOMEM;
+ goto detach_sock;
+ }
+
+ memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+
+ ctx->rec_seq_size = rec_seq_size;
+ ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+ if (!ctx->rec_seq) {
+ rc = -ENOMEM;
+ goto err_iv;
+ }
+ memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
+
+ offload_ctx = kzalloc(sizeof(*offload_ctx), GFP_KERNEL);
+ if (!offload_ctx) {
+ rc = -ENOMEM;
+ goto err_iv;
+ }
+ ctx->priv_ctx = offload_ctx;
+
+ memcpy(&offload_ctx->unacked_record_sn, rec_seq,
+ sizeof(offload_ctx->unacked_record_sn));
+
+ /* start at rec_seq - 1 to account for the start marker record */
+ offload_ctx->unacked_record_sn =
+ be64_to_cpu(offload_ctx->unacked_record_sn) - 1;
+
+ rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
+ if (rc)
+ goto err_iv;
+
+ start_marker_record->end_seq = tcp_sk(sk)->write_seq;
+ start_marker_record->len = 0;
+ start_marker_record->num_frags = 0;
+
+ INIT_LIST_HEAD(&offload_ctx->records_list);
+ list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
+ spin_lock_init(&offload_ctx->lock);
+
+ inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked;
+ ctx->push_pending_record = tls_device_push_pending_record;
+ offload_ctx->sk_destruct = sk->sk_destruct;
+
+ /* TLS offload is greatly simplified if we don't send
+ * SKBs where only part of the payload needs to be encrypted.
+ * So mark the last skb in the write queue as end of record.
+ */
+ skb = tcp_write_queue_tail(sk);
+ if (skb)
+ TCP_SKB_CB(skb)->eor = 1;
+
+ /* Following this assignment, tls_is_sk_tx_device_offloaded
+ * will return true and the context might be accessed
+ * by the netdev's xmit function.
+ */
+ smp_store_release(&sk->sk_destruct,
+ &tls_device_sk_destruct);
+ goto release_netdev;
+
+err_iv:
+ kfree(ctx->iv);
+detach_sock:
+ detach_sock_from_netdev(sk, ctx);
+free_marker_record:
+ kfree(start_marker_record);
+release_netdev:
+ dev_put(netdev);
+out:
+ return rc;
+}
+
+void __exit tls_device_cleanup(void)
+{
+ flush_work(&tls_device_gc_work);
+}
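
tls_set_device_offload() drives the device through netdev->tlsdev_ops. A
skeleton of what a driver might register, with the op signatures inferred from
the call sites above (the foo_* names and the enum type name
tls_offload_ctx_dir are assumptions, not part of this file):

static int foo_tls_dev_add(struct net_device *netdev, struct sock *sk,
			   enum tls_offload_ctx_dir direction,
			   struct tls_crypto_info *crypto_info)
{
	/* only TX offload of AES-GCM-128 exists in this series */
	if (direction != TLS_OFFLOAD_CTX_DIR_TX ||
	    crypto_info->cipher_type != TLS_CIPHER_AES_GCM_128)
		return -EOPNOTSUPP;

	/* install key, salt and initial record sequence number in HW
	 * for this connection; a non-zero return makes the caller
	 * fall back to the SW path
	 */
	return 0;
}

static void foo_tls_dev_del(struct net_device *netdev, struct sock *sk,
			    enum tls_offload_ctx_dir direction)
{
	/* release the HW crypto state for this connection */
}

static const struct tlsdev_ops foo_tlsdev_ops = {
	.tls_dev_add	= foo_tls_dev_add,
	.tls_dev_del	= foo_tls_dev_del,
};

The driver also sets netdev->tlsdev_ops = &foo_tlsdev_ops and advertises
NETIF_F_HW_TLS_TX in its probe path; without the feature flag,
tls_set_device_offload() bails out with ENOTSUPP.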
new file mode 100644
@@ -0,0 +1,382 @@
+/* Copyright (c) 2016-2017, Mellanox Technologies All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or
+ * without modification, are permitted provided that the following
+ * conditions are met:
+ *
+ * - Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * - Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * - Neither the name of the Mellanox Technologies nor the
+ * names of its contributors may be used to endorse or promote
+ * products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
+ * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
+ * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE
+ */
+
+#include <net/tls.h>
+#include <crypto/aead.h>
+#include <crypto/scatterwalk.h>
+
+static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
+{
+ struct scatterlist *src = walk->sg;
+ int diff = walk->offset - src->offset;
+
+ sg_set_page(sg, sg_page(src),
+ src->length - diff, walk->offset);
+
+ scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
+}
+
+static int tls_enc_record(struct aead_request *aead_req,
+ struct crypto_aead *aead, char *aad, char *iv,
+ __be64 rcd_sn, struct scatter_walk *in,
+ struct scatter_walk *out, int *in_len)
+{
+ struct scatterlist sg_in[3];
+ struct scatterlist sg_out[3];
+ unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
+ u16 len;
+ int rc;
+
+ len = min_t(int, *in_len, ARRAY_SIZE(buf));
+
+ scatterwalk_copychunks(buf, in, len, 0);
+ scatterwalk_copychunks(buf, out, len, 1);
+
+ *in_len -= len;
+ if (!*in_len)
+ return 0;
+
+ scatterwalk_pagedone(in, 0, 1);
+ scatterwalk_pagedone(out, 1, 1);
+
+ len = buf[4] | (buf[3] << 8);
+ len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
+
+ tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
+ (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
+
+ memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
+ TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+ sg_init_table(sg_in, ARRAY_SIZE(sg_in));
+ sg_init_table(sg_out, ARRAY_SIZE(sg_out));
+ sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
+ sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
+ chain_to_walk(sg_in + 1, in);
+ chain_to_walk(sg_out + 1, out);
+
+ *in_len -= len;
+ if (*in_len < 0) {
+ *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ if (*in_len < 0)
+ /* The input buffer doesn't contain the entire record;
+ * trim len accordingly. The resulting authentication
+ * tag will contain garbage, but we don't care since
+ * we won't include any of it in the output skb.
+ * Note that we assume the output buffer length is
+ * larger than the input buffer length plus tag size.
+ */
+ len += *in_len;
+
+ *in_len = 0;
+ }
+
+ if (*in_len) {
+ scatterwalk_copychunks(NULL, in, len, 2);
+ scatterwalk_pagedone(in, 0, 1);
+ scatterwalk_copychunks(NULL, out, len, 2);
+ scatterwalk_pagedone(out, 1, 1);
+ }
+
+ len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
+
+ rc = crypto_aead_encrypt(aead_req);
+
+ return rc;
+}
+
+static void tls_init_aead_request(struct aead_request *aead_req,
+ struct crypto_aead *aead)
+{
+ aead_request_set_tfm(aead_req, aead);
+ aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+ /* Clear the CRYPTO_TFM_REQ_MAY_SLEEP flag to avoid a
+ * "sleeping function called from invalid context" warning,
+ * as this can be called from the transmit path.
+ */
+ aead_request_set_callback(aead_req, 0, NULL, NULL);
+}
+
+static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
+ gfp_t flags)
+{
+ unsigned int req_size = sizeof(struct aead_request) +
+ crypto_aead_reqsize(aead);
+ struct aead_request *aead_req;
+
+ aead_req = kzalloc(req_size, flags);
+ if (!aead_req)
+ return NULL;
+
+ tls_init_aead_request(aead_req, aead);
+ return aead_req;
+}
+
+static int tls_enc_records(struct aead_request *aead_req,
+ struct crypto_aead *aead, struct scatterlist *sg_in,
+ struct scatterlist *sg_out, char *aad, char *iv,
+ u64 rcd_sn, int len)
+{
+ struct scatter_walk in;
+ struct scatter_walk out;
+ int rc;
+
+ scatterwalk_start(&in, sg_in);
+ scatterwalk_start(&out, sg_out);
+
+ do {
+ rc = tls_enc_record(aead_req, aead, aad, iv,
+ cpu_to_be64(rcd_sn), &in, &out, &len);
+ rcd_sn++;
+
+ } while (rc == 0 && len);
+
+ scatterwalk_done(&in, 0, 0);
+ scatterwalk_done(&out, 1, 0);
+
+ return rc;
+}
+
+static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
+{
+ skb_copy_header(nskb, skb);
+
+ skb_put(nskb, skb->len);
+ memcpy(nskb->data, skb->data, headln);
+
+ /* All TLS offload devices support CHECKSUM_PARTIAL,
+ * and since the pseudo header didn't change we
+ * don't have to update the checksum.
+ */
+ BUG_ON(skb->ip_summed != CHECKSUM_PARTIAL);
+
+ nskb->destructor = skb->destructor;
+ nskb->sk = skb->sk;
+ skb->destructor = NULL;
+ skb->sk = NULL;
+ refcount_add(nskb->truesize - skb->truesize,
+ &nskb->sk->sk_wmem_alloc);
+}
+
+/* This function may be called after the user socket is already
+ * closed, so make sure we don't use anything freed during
+ * tls_sk_proto_close here.
+ */
+struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+ struct tls_record_info *record;
+ u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
+ s32 sync_size;
+ int remaining;
+ unsigned long flags;
+ struct sk_buff *nskb = NULL;
+ int i = 0;
+ struct scatterlist sg_in[2 * (MAX_SKB_FRAGS + 1)];
+ struct scatterlist sg_out[3];
+ struct aead_request *aead_req;
+ int tcp_header_size = tcp_hdrlen(skb);
+ int tcp_payload_offset = skb_transport_offset(skb) + tcp_header_size;
+ void *buf, *dummy_buf, *iv, *aad;
+ int buf_len;
+ int resync_sgs;
+ int rc;
+ int payload_len = skb->len - tcp_payload_offset;
+ u64 rcd_sn;
+
+ if (!payload_len)
+ return skb;
+
+ sg_init_table(sg_in, ARRAY_SIZE(sg_in));
+ sg_init_table(sg_out, ARRAY_SIZE(sg_out));
+
+ spin_lock_irqsave(&ctx->lock, flags);
+ record = tls_get_record(ctx, tcp_seq, &rcd_sn);
+ if (!record) {
+ spin_unlock_irqrestore(&ctx->lock, flags);
+ WARN(1, "Record not found for seq %u\n", tcp_seq);
+ goto free_orig;
+ }
+
+ sync_size = tcp_seq - (record->end_seq - record->len);
+ if (sync_size < 0) {
+ spin_unlock_irqrestore(&ctx->lock, flags);
+ if (!tls_record_is_start_marker(record))
+ /* This should only occur if the relevant record was
+ * already acked. In that case it should be ok
+ * to drop the packet and avoid retransmission.
+ *
+ * There is a corner case where the packet contains
+ * both an acked and a non-acked record.
+ * We currently don't handle that case and rely
+ * on TCP to retransmit a packet that doesn't contain
+ * already acked payload.
+ */
+ goto free_orig;
+
+ if (payload_len > -sync_size) {
+ WARN(1, "Fallback of partially offloaded packets is not supported\n");
+ goto free_orig;
+ } else {
+ return skb;
+ }
+ }
+
+ remaining = sync_size;
+ while (remaining > 0) {
+ skb_frag_t *frag = &record->frags[i];
+
+ __skb_frag_ref(frag);
+ sg_set_page(sg_in + i, skb_frag_page(frag),
+ skb_frag_size(frag), frag->page_offset);
+
+ remaining -= skb_frag_size(frag);
+
+ if (remaining < 0)
+ sg_in[i].length += remaining;
+
+ i++;
+ }
+ spin_unlock_irqrestore(&ctx->lock, flags);
+ resync_sgs = i;
+
+ aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
+ if (!aead_req)
+ goto put_sg;
+
+ buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
+ TLS_CIPHER_AES_GCM_128_IV_SIZE +
+ TLS_AAD_SPACE_SIZE +
+ sync_size +
+ tls_ctx->tag_size;
+ buf = kmalloc(buf_len, GFP_ATOMIC);
+ if (!buf)
+ goto free_req;
+
+ nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
+ if (!nskb)
+ goto free_buf;
+
+ skb_reserve(nskb, skb_headroom(skb));
+
+ iv = buf;
+
+ memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
+ TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+ aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
+ TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ dummy_buf = aad + TLS_AAD_SPACE_SIZE;
+
+ sg_set_buf(&sg_out[0], dummy_buf, sync_size);
+ sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset,
+ payload_len);
+ /* Add room for authentication tag produced by crypto */
+ dummy_buf += sync_size;
+ sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size);
+ rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset,
+ payload_len);
+ if (rc < 0)
+ goto free_nskb;
+
+ rc = tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
+ rcd_sn, sync_size + payload_len);
+ if (rc < 0)
+ goto free_nskb;
+
+ complete_skb(nskb, skb, tcp_payload_offset);
+
+free_buf:
+ kfree(buf);
+free_req:
+ kfree(aead_req);
+put_sg:
+ for (i = 0; i < resync_sgs; i++)
+ put_page(sg_page(&sg_in[i]));
+free_orig:
+ kfree_skb(skb);
+ return nskb;
+
+free_nskb:
+ kfree_skb(nskb);
+ nskb = NULL;
+ goto free_buf;
+}
+
+static struct sk_buff *
+tls_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
+{
+ if (dev == tls_get_ctx(sk)->netdev)
+ return skb;
+
+ return tls_sw_fallback(sk, skb);
+}
+
+int tls_sw_fallback_init(struct sock *sk,
+ struct tls_offload_context *offload_ctx,
+ struct tls_crypto_info *crypto_info)
+{
+ int rc;
+
+ offload_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
+ if (IS_ERR(offload_ctx->aead_send)) {
+ pr_err("crypto_alloc_aead failed\n");
+ rc = PTR_ERR(offload_ctx->aead_send);
+ offload_ctx->aead_send = NULL;
+ goto err_out;
+ }
+
+ rc = crypto_aead_setkey(
+ offload_ctx->aead_send,
+ ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key,
+ TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+ if (rc)
+ goto free_aead;
+
+ rc = crypto_aead_setauthsize(offload_ctx->aead_send,
+ TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ if (rc)
+ goto free_aead;
+
+ sk->sk_offload_check = tls_validate_xmit;
+ /* tls_is_sk_tx_device_offloaded will only return true after
+ * tls_set_device_offload() publishes sk->sk_destruct, so
+ * ndo_start_xmit cannot see a half-initialized context.
+ */
+ return 0;
+free_aead:
+ crypto_free_aead(offload_ctx->aead_send);
+err_out:
+ return rc;
+}
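
The SW fallback is reached through the sk->sk_offload_check hook installed by
tls_sw_fallback_init(). A sketch of the expected call on the way to a device;
the exact caller lives outside this patch, and only the hook signature is taken
from tls_validate_xmit():

/* hypothetical validation step before an skb is queued to dev */
static struct sk_buff *foo_validate_skb(struct net_device *dev,
					struct sk_buff *skb)
{
	struct sock *sk = skb->sk;

	if (sk && tls_is_sk_tx_device_offloaded(sk) && sk->sk_offload_check)
		/* tls_validate_xmit() passes the skb through when dev
		 * is the offloading netdev and otherwise returns a
		 * SW-encrypted copy; NULL means the skb was dropped
		 */
		skb = sk->sk_offload_check(sk, dev, skb);

	return skb;
}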
@@ -48,6 +48,9 @@
enum {
TLS_BASE_TX,
TLS_SW_TX,
+#ifdef CONFIG_TLS_DEVICE
+ TLS_HW_TX,
+#endif
TLS_NUM_CONFIG,
};
@@ -401,11 +404,19 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
goto out;
}
- /* currently SW is default, we will have ethtool in future */
- rc = tls_set_sw_offload(sk, ctx);
- tx_conf = TLS_SW_TX;
- if (rc)
- goto err_crypto_info;
+#ifdef CONFIG_TLS_DEVICE
+ rc = tls_set_device_offload(sk, ctx);
+ tx_conf = TLS_HW_TX;
+ if (rc) {
+#else
+ {
+#endif
+ /* if HW offload fails, fall back to SW */
+ rc = tls_set_sw_offload(sk, ctx);
+ tx_conf = TLS_SW_TX;
+ if (rc)
+ goto err_crypto_info;
+ }
ctx->tx_conf = tx_conf;
update_sk_prot(sk, ctx);
@@ -487,6 +498,12 @@ static void build_protos(struct proto *prot, struct proto *base)
prot[TLS_SW_TX] = prot[TLS_BASE_TX];
prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg;
prot[TLS_SW_TX].sendpage = tls_sw_sendpage;
+
+#ifdef CONFIG_TLS_DEVICE
+ prot[TLS_HW_TX] = prot[TLS_SW_TX];
+ prot[TLS_HW_TX].sendmsg = tls_device_sendmsg;
+ prot[TLS_HW_TX].sendpage = tls_device_sendpage;
+#endif
}
static int __init tls_register(void)
@@ -501,6 +518,9 @@ static int __init tls_register(void)
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
+#ifdef CONFIG_TLS_DEVICE
+ tls_device_cleanup();
+#endif
}
module_init(tls_register);
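
For reference, the userspace sequence that reaches do_tls_setsockopt_tx():
attach the ULP, then hand the keys negotiated by the TLS handshake to the
kernel. A sketch against the kTLS uapi this series builds on (constants from
linux/tls.h; error handling trimmed). With CONFIG_TLS_DEVICE the kernel first
tries the HW path and transparently falls back to SW:

#include <linux/tls.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <string.h>

#ifndef SOL_TLS
#define SOL_TLS 282		/* from linux/socket.h */
#endif
#ifndef TCP_ULP
#define TCP_ULP 31		/* from linux/tcp.h */
#endif

static int enable_ktls_tx(int fd, const unsigned char *key,
			  const unsigned char *iv,
			  const unsigned char *salt,
			  const unsigned char *rec_seq)
{
	struct tls12_crypto_info_aes_gcm_128 ci;

	/* the TLS handshake must already have completed on fd */
	if (setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")))
		return -1;

	memset(&ci, 0, sizeof(ci));
	ci.info.version = TLS_1_2_VERSION;
	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

	return setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
}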