diff mbox series

[bpf,v2,1/3] bpf: tls, implement unhash to avoid transition out of ESTABLISHED

Message ID 155620818860.22884.14832636768748270693.stgit@john-XPS-13-9360
State Changes Requested
Delegated to: BPF Maintainers
Headers show
Series sockmap/ktls fixes | expand

Commit Message

John Fastabend April 25, 2019, 4:03 p.m. UTC
It is possible (via shutdown()) for TCP socks to go through TCP_CLOSE
state via tcp_disconnect() without calling into close callback. This
would allow a kTLS enabled socket to exist outside of ESTABLISHED
state which is not supported.

Solve this the same way we solved the sock{map|hash} case by adding
an unhash hook to tear down the TLS state.

In the process we also make the close hook more robust. We add a put
call into the close path, also in the unhash path, to remove the
reference to ulp data after free. It's no longer valid and may confuse
things later if the socket (re)enters kTLS code paths. Second we add
an 'if(ctx)' check to ensure the ctx is still valid and not released
from a previous unhash/close path.

Fixes: d91c3e17f75f2 ("net/tls: Only attach to sockets in ESTABLISHED state")
Reported-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
---
 include/net/tls.h  |   14 ++++++++++++-
 net/tls/tls_main.c |   55 +++++++++++++++++++++++++++++++++++++++-------------
 net/tls/tls_sw.c   |   13 +++++++++---
 3 files changed, 63 insertions(+), 19 deletions(-)

Comments

Jakub Kicinski April 25, 2019, 7:29 p.m. UTC | #1
On Thu, 25 Apr 2019 09:03:08 -0700, John Fastabend wrote:
> +static void tls_sk_proto_unhash(struct sock *sk)
> +{
> +	struct tls_context *ctx = tls_get_ctx(sk);
> +	void (*sk_proto_unhash)(struct sock *sk);
> +	bool free_ctx;
> +
> +	if (!ctx)
> +		return sk->sk_prot->unhash(sk);
> +	sk_proto_unhash = ctx->sk_proto_unhash;
> +	free_ctx = tls_sk_proto_destroy(sk, ctx, false);
> +	tls_put_ctx(sk);

Oh, I think you can't put_ctx() unconditionally;
when free_ctx is false, tls_device_sk_destruct()
needs the ctx pointer.

I think this explains the offload crashing.

> +	if (sk_proto_unhash)
> +		sk_proto_unhash(sk);
> +	if (free_ctx)
> +		tls_ctx_free(ctx);
> +}
>  
> -skip_tx_cleanup:
> +static void tls_sk_proto_close(struct sock *sk, long timeout)
> +{
> +	void (*sk_proto_close)(struct sock *sk, long timeout);
> +	struct tls_context *ctx = tls_get_ctx(sk);
> +	bool free_ctx;
> +
> +	if (!ctx)
> +		return sk->sk_prot->destroy(sk);
> +
> +	lock_sock(sk);
> +	sk_proto_close = ctx->sk_proto_close;
> +	free_ctx = tls_sk_proto_destroy(sk, ctx, true);
> +	tls_put_ctx(sk);
John Fastabend April 25, 2019, 7:32 p.m. UTC | #2
On 4/25/19 12:29 PM, Jakub Kicinski wrote:
> On Thu, 25 Apr 2019 09:03:08 -0700, John Fastabend wrote:
>> +static void tls_sk_proto_unhash(struct sock *sk)
>> +{
>> +	struct tls_context *ctx = tls_get_ctx(sk);
>> +	void (*sk_proto_unhash)(struct sock *sk);
>> +	bool free_ctx;
>> +
>> +	if (!ctx)
>> +		return sk->sk_prot->unhash(sk);
>> +	sk_proto_unhash = ctx->sk_proto_unhash;
>> +	free_ctx = tls_sk_proto_destroy(sk, ctx, false);
>> +	tls_put_ctx(sk);
> 
> Oh, I think you can't put_ctx() unconditionally,
> when free_ctx is false, tls_device_sk_destruct() 
> needs it the ctx pointer.
> 
> I think this explains the offload crashing.
> 

ugh yeah. So we need to _not_ free it from tls_sk_proto_destroy,
do the put_ctx, and then finally free it. Otherwise we can't
restore the sk_proto fields. v3 on its way. Thanks.

>> +	if (sk_proto_unhash)
>> +		sk_proto_unhash(sk);
>> +	if (free_ctx)
>> +		tls_ctx_free(ctx);
>> +}
>>  
>> -skip_tx_cleanup:
>> +static void tls_sk_proto_close(struct sock *sk, long timeout)
>> +{
>> +	void (*sk_proto_close)(struct sock *sk, long timeout);
>> +	struct tls_context *ctx = tls_get_ctx(sk);
>> +	bool free_ctx;
>> +
>> +	if (!ctx)
>> +		return sk->sk_prot->destroy(sk);
>> +
>> +	lock_sock(sk);
>> +	sk_proto_close = ctx->sk_proto_close;
>> +	free_ctx = tls_sk_proto_destroy(sk, ctx, true);
>> +	tls_put_ctx(sk);
John Fastabend April 25, 2019, 7:35 p.m. UTC | #3
On 4/25/19 12:32 PM, John Fastabend wrote:
> On 4/25/19 12:29 PM, Jakub Kicinski wrote:
>> On Thu, 25 Apr 2019 09:03:08 -0700, John Fastabend wrote:
>>> +static void tls_sk_proto_unhash(struct sock *sk)
>>> +{
>>> +	struct tls_context *ctx = tls_get_ctx(sk);
>>> +	void (*sk_proto_unhash)(struct sock *sk);
>>> +	bool free_ctx;
>>> +
>>> +	if (!ctx)
>>> +		return sk->sk_prot->unhash(sk);
>>> +	sk_proto_unhash = ctx->sk_proto_unhash;
>>> +	free_ctx = tls_sk_proto_destroy(sk, ctx, false);
>>> +	tls_put_ctx(sk);
>>
>> Oh, I think you can't put_ctx() unconditionally,
>> when free_ctx is false, tls_device_sk_destruct() 
>> needs it the ctx pointer.
>>
>> I think this explains the offload crashing.
>>
> 
> ugh yeah. So we need to _not_ free it from tls_sk_proto_destroy
> do the put_ctx and then finally free it. Otherwise we can't
> restore the sk_proto fields. v3 on its way. Thanks.
> 

I'm going to throw that patch I sent earlier in this thread
on the series as well. It's the minimal set to get things working
again for me. Will follow up with some selftests so we don't get
here again.

>>> +	if (sk_proto_unhash)
>>> +		sk_proto_unhash(sk);
>>> +	if (free_ctx)
>>> +		tls_ctx_free(ctx);
>>> +}
>>>  
>>> -skip_tx_cleanup:
>>> +static void tls_sk_proto_close(struct sock *sk, long timeout)
>>> +{
>>> +	void (*sk_proto_close)(struct sock *sk, long timeout);
>>> +	struct tls_context *ctx = tls_get_ctx(sk);
>>> +	bool free_ctx;
>>> +
>>> +	if (!ctx)
>>> +		return sk->sk_prot->destroy(sk);
>>> +
>>> +	lock_sock(sk);
>>> +	sk_proto_close = ctx->sk_proto_close;
>>> +	free_ctx = tls_sk_proto_destroy(sk, ctx, true);
>>> +	tls_put_ctx(sk);
>
Jakub Kicinski April 25, 2019, 7:41 p.m. UTC | #4
On Thu, 25 Apr 2019 12:35:58 -0700, John Fastabend wrote:
> On 4/25/19 12:32 PM, John Fastabend wrote:
> > On 4/25/19 12:29 PM, Jakub Kicinski wrote:  
> >> On Thu, 25 Apr 2019 09:03:08 -0700, John Fastabend wrote:  
> >>> +static void tls_sk_proto_unhash(struct sock *sk)
> >>> +{
> >>> +	struct tls_context *ctx = tls_get_ctx(sk);
> >>> +	void (*sk_proto_unhash)(struct sock *sk);
> >>> +	bool free_ctx;
> >>> +
> >>> +	if (!ctx)
> >>> +		return sk->sk_prot->unhash(sk);
> >>> +	sk_proto_unhash = ctx->sk_proto_unhash;
> >>> +	free_ctx = tls_sk_proto_destroy(sk, ctx, false);
> >>> +	tls_put_ctx(sk);  
> >>
> >> Oh, I think you can't put_ctx() unconditionally,
> >> when free_ctx is false, tls_device_sk_destruct() 
> >> needs it the ctx pointer.
> >>
> >> I think this explains the offload crashing.
> >>  
> > 
> > ugh yeah. So we need to _not_ free it from tls_sk_proto_destroy
> > do the put_ctx and then finally free it. Otherwise we can't
> > restore the sk_proto fields. v3 on its way. Thanks.
> >   
> 
> I'm going to throw that patch I sent earlier in this thread
> on the series as well. Its the minimal set to get things working
> again for me. Will follow up some selftests so we don't get
> here again.

SGTM, I've been racking my brain trying to come up with a good test for
the offload stuff, because it's really hard to test that without actual
HW.  I don't see any other way than adding full on per-packet crypto
logic into netdevsim or such :/  Trying to lie about having offloaded
the crypto breaks down in corner cases.

> >>> +	if (sk_proto_unhash)
> >>> +		sk_proto_unhash(sk);
> >>> +	if (free_ctx)
> >>> +		tls_ctx_free(ctx);
> >>> +}
> >>>  
> >>> -skip_tx_cleanup:
> >>> +static void tls_sk_proto_close(struct sock *sk, long timeout)
> >>> +{
> >>> +	void (*sk_proto_close)(struct sock *sk, long timeout);
> >>> +	struct tls_context *ctx = tls_get_ctx(sk);
> >>> +	bool free_ctx;
> >>> +
> >>> +	if (!ctx)
> >>> +		return sk->sk_prot->destroy(sk);
> >>> +
> >>> +	lock_sock(sk);
> >>> +	sk_proto_close = ctx->sk_proto_close;
> >>> +	free_ctx = tls_sk_proto_destroy(sk, ctx, true);
> >>> +	tls_put_ctx(sk);
diff mbox series

Patch

diff --git a/include/net/tls.h b/include/net/tls.h
index d9d0ac66f040..ae13ea19b375 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -266,6 +266,8 @@  struct tls_context {
 	void (*sk_write_space)(struct sock *sk);
 	void (*sk_destruct)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);
+	void (*sk_proto_unhash)(struct sock *sk);
+	struct proto *sk_proto;
 
 	int  (*setsockopt)(struct sock *sk, int level,
 			   int optname, char __user *optval,
@@ -303,7 +305,7 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
 		    int offset, size_t size, int flags);
 void tls_sw_close(struct sock *sk, long timeout);
-void tls_sw_free_resources_tx(struct sock *sk);
+void tls_sw_free_resources_tx(struct sock *sk, bool locked);
 void tls_sw_free_resources_rx(struct sock *sk);
 void tls_sw_release_resources_rx(struct sock *sk);
 int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
@@ -504,6 +506,16 @@  static inline void xor_iv_with_seq(int version, char *iv, char *seq)
 	}
 }
 
+static inline void tls_put_ctx(struct sock *sk)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context *ctx = icsk->icsk_ulp_data;
+
+	if (!ctx)
+		return;
+	sk->sk_prot = ctx->sk_proto;
+	icsk->icsk_ulp_data = NULL;
+}
 
 static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
 		const struct tls_context *tls_ctx)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 7e546b8ec000..54842d0ddbb5 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -261,23 +261,16 @@  static void tls_ctx_free(struct tls_context *ctx)
 	kfree(ctx);
 }
 
-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static bool tls_sk_proto_destroy(struct sock *sk,
+				 struct tls_context *ctx, bool locked)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
 	long timeo = sock_sndtimeo(sk, 0);
-	void (*sk_proto_close)(struct sock *sk, long timeout);
-	bool free_ctx = false;
-
-	lock_sock(sk);
-	sk_proto_close = ctx->sk_proto_close;
 
 	if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
-		goto skip_tx_cleanup;
+		return false;
 
-	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
-		free_ctx = true;
-		goto skip_tx_cleanup;
-	}
+	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)
+		return true;
 
 	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
 		tls_handle_open_record(sk, 0);
@@ -286,7 +279,7 @@  static void tls_sk_proto_close(struct sock *sk, long timeout)
 	if (ctx->tx_conf == TLS_SW) {
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
-		tls_sw_free_resources_tx(sk);
+		tls_sw_free_resources_tx(sk, locked);
 #ifdef CONFIG_TLS_DEVICE
 	} else if (ctx->tx_conf == TLS_HW) {
 		tls_device_free_resources_tx(sk);
@@ -310,8 +303,39 @@  static void tls_sk_proto_close(struct sock *sk, long timeout)
 		tls_ctx_free(ctx);
 		ctx = NULL;
 	}
+	return false;
+}
+
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+	struct tls_context *ctx = tls_get_ctx(sk);
+	void (*sk_proto_unhash)(struct sock *sk);
+	bool free_ctx;
+
+	if (!ctx)
+		return sk->sk_prot->unhash(sk);
+	sk_proto_unhash = ctx->sk_proto_unhash;
+	free_ctx = tls_sk_proto_destroy(sk, ctx, false);
+	tls_put_ctx(sk);
+	if (sk_proto_unhash)
+		sk_proto_unhash(sk);
+	if (free_ctx)
+		tls_ctx_free(ctx);
+}
 
-skip_tx_cleanup:
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+	void (*sk_proto_close)(struct sock *sk, long timeout);
+	struct tls_context *ctx = tls_get_ctx(sk);
+	bool free_ctx;
+
+	if (!ctx)
+		return sk->sk_prot->destroy(sk);
+
+	lock_sock(sk);
+	sk_proto_close = ctx->sk_proto_close;
+	free_ctx = tls_sk_proto_destroy(sk, ctx, true);
+	tls_put_ctx(sk);
 	release_sock(sk);
 	sk_proto_close(sk, timeout);
 	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
@@ -609,6 +633,8 @@  static struct tls_context *create_ctx(struct sock *sk)
 	ctx->setsockopt = sk->sk_prot->setsockopt;
 	ctx->getsockopt = sk->sk_prot->getsockopt;
 	ctx->sk_proto_close = sk->sk_prot->close;
+	ctx->sk_proto_unhash = sk->sk_prot->unhash;
+	ctx->sk_proto = sk->sk_prot;
 	return ctx;
 }
 
@@ -732,6 +758,7 @@  static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
 	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
 	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;
+	prot[TLS_BASE][TLS_BASE].unhash		= tls_sk_proto_unhash;
 
 	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index f780b473827b..0577633c319b 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -2044,7 +2044,7 @@  static void tls_data_ready(struct sock *sk)
 	}
 }
 
-void tls_sw_free_resources_tx(struct sock *sk)
+void tls_sw_free_resources_tx(struct sock *sk, bool locked)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
@@ -2055,9 +2055,11 @@  void tls_sw_free_resources_tx(struct sock *sk)
 	if (atomic_read(&ctx->encrypt_pending))
 		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
 
-	release_sock(sk);
+	if (locked)
+		release_sock(sk);
 	cancel_delayed_work_sync(&ctx->tx_work.work);
-	lock_sock(sk);
+	if (locked)
+		lock_sock(sk);
 
 	/* Tx whatever records we can transmit and abandon the rest */
 	tls_tx_records(sk, -1);
@@ -2080,7 +2082,10 @@  void tls_sw_free_resources_tx(struct sock *sk)
 		kfree(rec);
 	}
 
-	crypto_free_aead(ctx->aead_send);
+	if (ctx->aead_send) {
+		crypto_free_aead(ctx->aead_send);
+		ctx->aead_send = NULL;
+	}
 	tls_free_open_rec(sk);
 
 	kfree(ctx);