diff mbox series

[net-next,2/6] tls: Move cipher info to a separate struct

Message ID 20180320175346.GA23821@davejwatson-mba.local
State Changes Requested, archived
Delegated to: David Miller
Headers show
Series TLS Rx | expand

Commit Message

Dave Watson March 20, 2018, 5:53 p.m. UTC
Separate tx crypto parameters to a separate cipher_context struct.
The same parameters will be used for rx using the same struct.

tls_advance_record_sn is modified to only take the cipher info.

Signed-off-by: Dave Watson <davejwatson@fb.com>
---
 include/net/tls.h  | 26 +++++++++++++-----------
 net/tls/tls_main.c |  8 ++++----
 net/tls/tls_sw.c   | 58 ++++++++++++++++++++++++++++--------------------------
 3 files changed, 49 insertions(+), 43 deletions(-)
diff mbox series

Patch

diff --git a/include/net/tls.h b/include/net/tls.h
index 4913430..019e52d 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -81,6 +81,16 @@  enum {
 	TLS_PENDING_CLOSED_RECORD
 };
 
+struct cipher_context {
+	u16 prepend_size;
+	u16 tag_size;
+	u16 overhead_size;
+	u16 iv_size;
+	char *iv;
+	u16 rec_seq_size;
+	char *rec_seq;
+};
+
 struct tls_context {
 	union {
 		struct tls_crypto_info crypto_send;
@@ -91,13 +101,7 @@  struct tls_context {
 
 	u8 tx_conf:2;
 
-	u16 prepend_size;
-	u16 tag_size;
-	u16 overhead_size;
-	u16 iv_size;
-	char *iv;
-	u16 rec_seq_size;
-	char *rec_seq;
+	struct cipher_context tx;
 
 	struct scatterlist *partially_sent_record;
 	u16 partially_sent_offset;
@@ -190,7 +194,7 @@  static inline bool tls_bigint_increment(unsigned char *seq, int len)
 }
 
 static inline void tls_advance_record_sn(struct sock *sk,
-					 struct tls_context *ctx)
+					 struct cipher_context *ctx)
 {
 	if (tls_bigint_increment(ctx->rec_seq, ctx->rec_seq_size))
 		tls_err_abort(sk);
@@ -203,9 +207,9 @@  static inline void tls_fill_prepend(struct tls_context *ctx,
 			     size_t plaintext_len,
 			     unsigned char record_type)
 {
-	size_t pkt_len, iv_size = ctx->iv_size;
+	size_t pkt_len, iv_size = ctx->tx.iv_size;
 
-	pkt_len = plaintext_len + iv_size + ctx->tag_size;
+	pkt_len = plaintext_len + iv_size + ctx->tx.tag_size;
 
 	/* we cover nonce explicit here as well, so buf should be of
 	 * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE
@@ -217,7 +221,7 @@  static inline void tls_fill_prepend(struct tls_context *ctx,
 	buf[3] = pkt_len >> 8;
 	buf[4] = pkt_len & 0xFF;
 	memcpy(buf + TLS_NONCE_OFFSET,
-	       ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv_size);
+	       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv_size);
 }
 
 static inline void tls_make_aad(char *buf,
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d824d54..c671560 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -259,8 +259,8 @@  static void tls_sk_proto_close(struct sock *sk, long timeout)
 		}
 	}
 
-	kfree(ctx->rec_seq);
-	kfree(ctx->iv);
+	kfree(ctx->tx.rec_seq);
+	kfree(ctx->tx.iv);
 
 	if (ctx->tx_conf == TLS_SW_TX)
 		tls_sw_free_tx_resources(sk);
@@ -319,9 +319,9 @@  static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
 		}
 		lock_sock(sk);
 		memcpy(crypto_info_aes_gcm_128->iv,
-		       ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+		       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
-		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->rec_seq,
+		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
 		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 		release_sock(sk);
 		if (copy_to_user(optval,
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index d58f675..dd4441d 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -79,7 +79,7 @@  static void trim_both_sgl(struct sock *sk, int target_size)
 		target_size);
 
 	if (target_size > 0)
-		target_size += tls_ctx->overhead_size;
+		target_size += tls_ctx->tx.overhead_size;
 
 	trim_sg(sk, ctx->sg_encrypted_data,
 		&ctx->sg_encrypted_num_elem,
@@ -207,21 +207,21 @@  static int tls_do_encryption(struct tls_context *tls_ctx,
 	if (!aead_req)
 		return -ENOMEM;
 
-	ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size;
-	ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size;
+	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
+	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
 
 	aead_request_set_tfm(aead_req, ctx->aead_send);
 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
 	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
-			       data_len, tls_ctx->iv);
+			       data_len, tls_ctx->tx.iv);
 
 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
 				  crypto_req_done, &ctx->async_wait);
 
 	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
 
-	ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size;
-	ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size;
+	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
+	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
 
 	kfree(aead_req);
 	return rc;
@@ -238,7 +238,7 @@  static int tls_push_record(struct sock *sk, int flags,
 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
 
 	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
-		     tls_ctx->rec_seq, tls_ctx->rec_seq_size,
+		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
 		     record_type);
 
 	tls_fill_prepend(tls_ctx,
@@ -271,7 +271,7 @@  static int tls_push_record(struct sock *sk, int flags,
 	if (rc < 0 && rc != -EAGAIN)
 		tls_err_abort(sk);
 
-	tls_advance_record_sn(sk, tls_ctx);
+	tls_advance_record_sn(sk, &tls_ctx->tx);
 	return rc;
 }
 
@@ -412,7 +412,7 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		}
 
 		required_size = ctx->sg_plaintext_size + try_to_copy +
-				tls_ctx->overhead_size;
+				tls_ctx->tx.overhead_size;
 
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
@@ -475,7 +475,7 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 				&ctx->sg_encrypted_num_elem,
 				&ctx->sg_encrypted_size,
 				ctx->sg_plaintext_size +
-				tls_ctx->overhead_size);
+				tls_ctx->tx.overhead_size);
 		}
 
 		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
@@ -567,7 +567,7 @@  int tls_sw_sendpage(struct sock *sk, struct page *page,
 			full_record = true;
 		}
 		required_size = ctx->sg_plaintext_size + copy +
-			      tls_ctx->overhead_size;
+			      tls_ctx->tx.overhead_size;
 
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
@@ -699,24 +699,26 @@  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 		goto free_priv;
 	}
 
-	ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
-	ctx->tag_size = tag_size;
-	ctx->overhead_size = ctx->prepend_size + ctx->tag_size;
-	ctx->iv_size = iv_size;
-	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL);
-	if (!ctx->iv) {
+	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
+	ctx->tx.tag_size = tag_size;
+	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
+	ctx->tx.iv_size = iv_size;
+	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			     GFP_KERNEL);
+	if (!ctx->tx.iv) {
 		rc = -ENOMEM;
 		goto free_priv;
 	}
-	memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-	ctx->rec_seq_size = rec_seq_size;
-	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
-	if (!ctx->rec_seq) {
+	memcpy(ctx->tx.iv, gcm_128_info->salt,
+	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+	ctx->tx.rec_seq_size = rec_seq_size;
+	ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+	if (!ctx->tx.rec_seq) {
 		rc = -ENOMEM;
 		goto free_iv;
 	}
-	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
+	memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);
 
 	sg_init_table(sw_ctx->sg_encrypted_data,
 		      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
@@ -752,7 +754,7 @@  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size);
+	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tx.tag_size);
 	if (!rc)
 		return 0;
 
@@ -760,11 +762,11 @@  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 	crypto_free_aead(sw_ctx->aead_send);
 	sw_ctx->aead_send = NULL;
 free_rec_seq:
-	kfree(ctx->rec_seq);
-	ctx->rec_seq = NULL;
+	kfree(ctx->tx.rec_seq);
+	ctx->tx.rec_seq = NULL;
 free_iv:
-	kfree(ctx->iv);
-	ctx->iv = NULL;
+	kfree(ctx->tx.iv);
+	ctx->tx.iv = NULL;
 free_priv:
 	kfree(ctx->priv_ctx);
 	ctx->priv_ctx = NULL;