diff mbox series

[bpf,v4,1/4] bpf: tls,implement unhash to avoid transition out of ESTABLISHED

Message ID 155746426913.20677.2783358822817593806.stgit@john-XPS-13-9360
State Changes Requested
Delegated to: BPF Maintainers
Headers show
Series sockmap/ktls fixes | expand

Commit Message

John Fastabend May 10, 2019, 4:57 a.m. UTC
It is possible (via shutdown()) for TCP socks to go through TCP_CLOSE
state via tcp_disconnect() without calling into close callback. This
would allow a kTLS enabled socket to exist outside of ESTABLISHED
state which is not supported.

Solve this the same way we solved the sock{map|hash} case by adding
an unhash hook to tear down the TLS state.

In the process we also make the close hook more robust. We add a put
call into the close path, also in the unhash path, to remove the
reference to ulp data after free. It's no longer valid and may confuse
things later if the socket (re)enters kTLS code paths. Second we add
an 'if(ctx)' check to ensure the ctx is still valid and not released
from a previous unhash/close path.

Fixes: d91c3e17f75f2 ("net/tls: Only attach to sockets in ESTABLISHED state")
Reported-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
---
 include/net/tls.h    |   28 +++++++---
 net/tls/tls_device.c |   10 ++--
 net/tls/tls_main.c   |   82 ++++++++++++++++++++---------
 net/tls/tls_sw.c     |  140 +++++++++++++++++++++++++++++---------------------
 4 files changed, 161 insertions(+), 99 deletions(-)

Comments

Jakub Kicinski May 10, 2019, 4:53 p.m. UTC | #1
On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:
>  #ifdef CONFIG_TLS_DEVICE
>  	if (ctx->rx_conf == TLS_HW)
>  		tls_device_offload_cleanup_rx(sk);
> -
> -	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
> -#else
> -	{
>  #endif
> +
> +	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW)

Did you try to build without CONFIG_TLS_DEVICE?

I think someone spent too much time in Verilog land and decided 
it's a good idea to hide enum values under an ifdef:

$ git grep -C4 TLS_HW,
include/net/tls.h-enum {
include/net/tls.h-      TLS_BASE,
include/net/tls.h-      TLS_SW,
include/net/tls.h-#ifdef CONFIG_TLS_DEVICE
include/net/tls.h:      TLS_HW,
include/net/tls.h-#endif
include/net/tls.h-      TLS_HW_RECORD,
include/net/tls.h-      TLS_NUM_CONFIG,
include/net/tls.h-};

:(

> +		return true;
> +	return false;
> +}
Jakub Kicinski May 10, 2019, 5 p.m. UTC | #2
On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:
> @@ -2042,12 +2060,14 @@ void tls_sw_free_resources_tx(struct sock *sk)
>  	if (atomic_read(&ctx->encrypt_pending))
>  		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
>  
> -	release_sock(sk);
> +	if (locked)
> +		release_sock(sk);
>  	cancel_delayed_work_sync(&ctx->tx_work.work);

So in the splat I got (on a slightly hacked up kernel) it seemed like
unhash may be called in atomic context:

[  783.232150]  tls_sk_proto_unhash+0x72/0x110 [tls]
[  783.237497]  tcp_set_state+0x484/0x640
[  783.241776]  ? __sk_mem_reduce_allocated+0x72/0x4a0
[  783.247317]  ? tcp_recv_timestamp+0x5c0/0x5c0
[  783.252265]  ? tcp_write_queue_purge+0xa6a/0x1180
[  783.257614]  tcp_done+0xac/0x260
[  783.261309]  tcp_reset+0xbe/0x350
[  783.265101]  tcp_validate_incoming+0xd9d/0x1530

I may have been unclear off-list, I only tested that the patch no longer
crashes the offload :(

> -	lock_sock(sk);
> +	if (locked)
> +		lock_sock(sk);
>  
>  	/* Tx whatever records we can transmit and abandon the rest */
> -	tls_tx_records(sk, -1);
> +	tls_tx_records(sk, tls_ctx, -1);
>  
>  	/* Free up un-sent records in tx_list. First, free
>  	 * the partially sent record if any at head of tx_list.
John Fastabend May 10, 2019, 11:03 p.m. UTC | #3
Jakub Kicinski wrote:
> On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:
> > @@ -2042,12 +2060,14 @@ void tls_sw_free_resources_tx(struct sock *sk)
> >  	if (atomic_read(&ctx->encrypt_pending))
> >  		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
> >  
> > -	release_sock(sk);
> > +	if (locked)
> > +		release_sock(sk);
> >  	cancel_delayed_work_sync(&ctx->tx_work.work);
> 
> So in the splat I got (on a slightly hacked up kernel) it seemed like
> unhash may be called in atomic context:
> 
> [  783.232150]  tls_sk_proto_unhash+0x72/0x110 [tls]
> [  783.237497]  tcp_set_state+0x484/0x640
> [  783.241776]  ? __sk_mem_reduce_allocated+0x72/0x4a0
> [  783.247317]  ? tcp_recv_timestamp+0x5c0/0x5c0
> [  783.252265]  ? tcp_write_queue_purge+0xa6a/0x1180
> [  783.257614]  tcp_done+0xac/0x260
> [  783.261309]  tcp_reset+0xbe/0x350
> [  783.265101]  tcp_validate_incoming+0xd9d/0x1530
> 
> I may have been unclear off-list, I only tested the patch no longer
> crashes the offload :(
> 

Yep, I misread and thought it was resolved here as well. OK I'll dig into
it. I'm not seeing it from selftests but I guess that means we are missing
a testcase. :( yet another version I guess.

Thanks,
John


> > -	lock_sock(sk);
> > +	if (locked)
> > +		lock_sock(sk);
> >  
> >  	/* Tx whatever records we can transmit and abandon the rest */
> > -	tls_tx_records(sk, -1);
> > +	tls_tx_records(sk, tls_ctx, -1);
> >  
> >  	/* Free up un-sent records in tx_list. First, free
> >  	 * the partially sent record if any at head of tx_list.
>
John Fastabend May 14, 2019, 10:34 p.m. UTC | #4
John Fastabend wrote:
> Jakub Kicinski wrote:
> > On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:
> > > @@ -2042,12 +2060,14 @@ void tls_sw_free_resources_tx(struct sock *sk)
> > >  	if (atomic_read(&ctx->encrypt_pending))
> > >  		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
> > >  
> > > -	release_sock(sk);
> > > +	if (locked)
> > > +		release_sock(sk);
> > >  	cancel_delayed_work_sync(&ctx->tx_work.work);
> > 
> > So in the splat I got (on a slightly hacked up kernel) it seemed like
> > unhash may be called in atomic context:
> > 
> > [  783.232150]  tls_sk_proto_unhash+0x72/0x110 [tls]
> > [  783.237497]  tcp_set_state+0x484/0x640
> > [  783.241776]  ? __sk_mem_reduce_allocated+0x72/0x4a0
> > [  783.247317]  ? tcp_recv_timestamp+0x5c0/0x5c0
> > [  783.252265]  ? tcp_write_queue_purge+0xa6a/0x1180
> > [  783.257614]  tcp_done+0xac/0x260
> > [  783.261309]  tcp_reset+0xbe/0x350
> > [  783.265101]  tcp_validate_incoming+0xd9d/0x1530
> > 
> > I may have been unclear off-list, I only tested the patch no longer
> > crashes the offload :(
> > 
> 
> Yep, I misread and thought it was resolved here as well. OK I'll dig into
> it. I'm not seeing it from selftests but I guess that means we are missing
> a testcase. :( yet another version I guess.
> 

Seems we need to call release_sock in the unhash case as well. Will
send a new patch shortly.

.John

> Thanks,
> John
>
Jakub Kicinski May 14, 2019, 10:58 p.m. UTC | #5
On Tue, 14 May 2019 15:34:55 -0700, John Fastabend wrote:
> John Fastabend wrote:
> > Jakub Kicinski wrote:  
> > > On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:  
> > > > @@ -2042,12 +2060,14 @@ void tls_sw_free_resources_tx(struct sock *sk)
> > > >  	if (atomic_read(&ctx->encrypt_pending))
> > > >  		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
> > > >  
> > > > -	release_sock(sk);
> > > > +	if (locked)
> > > > +		release_sock(sk);
> > > >  	cancel_delayed_work_sync(&ctx->tx_work.work);  
> > > 
> > > So in the splat I got (on a slightly hacked up kernel) it seemed like
> > > unhash may be called in atomic context:
> > > 
> > > [  783.232150]  tls_sk_proto_unhash+0x72/0x110 [tls]
> > > [  783.237497]  tcp_set_state+0x484/0x640
> > > [  783.241776]  ? __sk_mem_reduce_allocated+0x72/0x4a0
> > > [  783.247317]  ? tcp_recv_timestamp+0x5c0/0x5c0
> > > [  783.252265]  ? tcp_write_queue_purge+0xa6a/0x1180
> > > [  783.257614]  tcp_done+0xac/0x260
> > > [  783.261309]  tcp_reset+0xbe/0x350
> > > [  783.265101]  tcp_validate_incoming+0xd9d/0x1530
> > > 
> > > I may have been unclear off-list, I only tested the patch no longer
> > > crashes the offload :(
> > >   
> > 
> > Yep, I misread and thought it was resolved here as well. OK I'll dig into
> > it. I'm not seeing it from selftests but I guess that means we are missing
> > a testcase. :( yet another version I guess.
> >   
> 
> Seems we need to call release_sock in the unhash case as well. Will
> send a new patch shortly.

My reading of the stack trace was that unhash gets called from
tcp_reset(), IOW from soft IRQ, so we can't cancel_delayed_work_sync()
in tls_sw_free_resources_tx(), no?
John Fastabend May 15, 2019, 4:17 a.m. UTC | #6
Jakub Kicinski wrote:
> On Tue, 14 May 2019 15:34:55 -0700, John Fastabend wrote:
> > John Fastabend wrote:
> > > Jakub Kicinski wrote:  
> > > > On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:  
> > > > > @@ -2042,12 +2060,14 @@ void tls_sw_free_resources_tx(struct sock *sk)
> > > > >  	if (atomic_read(&ctx->encrypt_pending))
> > > > >  		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
> > > > >  
> > > > > -	release_sock(sk);
> > > > > +	if (locked)
> > > > > +		release_sock(sk);
> > > > >  	cancel_delayed_work_sync(&ctx->tx_work.work);  
> > > > 
> > > > So in the splat I got (on a slightly hacked up kernel) it seemed like
> > > > unhash may be called in atomic context:
> > > > 
> > > > [  783.232150]  tls_sk_proto_unhash+0x72/0x110 [tls]
> > > > [  783.237497]  tcp_set_state+0x484/0x640
> > > > [  783.241776]  ? __sk_mem_reduce_allocated+0x72/0x4a0
> > > > [  783.247317]  ? tcp_recv_timestamp+0x5c0/0x5c0
> > > > [  783.252265]  ? tcp_write_queue_purge+0xa6a/0x1180
> > > > [  783.257614]  tcp_done+0xac/0x260
> > > > [  783.261309]  tcp_reset+0xbe/0x350
> > > > [  783.265101]  tcp_validate_incoming+0xd9d/0x1530
> > > > 
> > > > I may have been unclear off-list, I only tested the patch no longer
> > > > crashes the offload :(
> > > >   
> > > 
> > > Yep, I misread and thought it was resolved here as well. OK I'll dig into
> > > it. I'm not seeing it from selftests but I guess that means we are missing
> > > a testcase. :( yet another version I guess.
> > >   
> > 
> > Seems we need to call release_sock in the unhash case as well. Will
> > send a new patch shortly.
> 
> My reading of the stack trace was that unhash gets called from
> tcp_reset(), IOW from soft IRQ, so we can't cancel_delayed_work_sync()
> in tls_sw_free_resources_tx(), no?

Well, the tcp_close() path has the lock held and can also call unhash(). Anyway,
dropping the sock lock in the middle of the block seems a bit suspect
to me. I think we can defer the free until after the sock is released; this
is how it was solved on the sockmap side.
Jakub Kicinski May 22, 2019, 4:57 p.m. UTC | #7
On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:
> It is possible (via shutdown()) for TCP socks to go through TCP_CLOSE
> state via tcp_disconnect() without calling into close callback. This
> would allow a kTLS enabled socket to exist outside of ESTABLISHED
> state which is not supported.
> 
> Solve this the same way we solved the sock{map|hash} case by adding
> an unhash hook to remove tear down the TLS state.
> 
> In the process we also make the close hook more robust. We add a put
> call into the close path, also in the unhash path, to remove the
> reference to ulp data after free. Its no longer valid and may confuse
> things later if the socket (re)enters kTLS code paths. Second we add
> an 'if(ctx)' check to ensure the ctx is still valid and not released
> from a previous unhash/close path.
> 
> Fixes: d91c3e17f75f2 ("net/tls: Only attach to sockets in ESTABLISHED state")
> Reported-by: Eric Dumazet <edumazet@google.com>
> Signed-off-by: John Fastabend <john.fastabend@gmail.com>

Looks like David Beckett managed to trigger another nasty on the
release path :/

    BUG: kernel NULL pointer dereference, address: 0000000000000012
    PGD 0 P4D 0
    Oops: 0000 [#1] SMP PTI
    CPU: 7 PID: 0 Comm: swapper/7 Not tainted
    5.2.0-rc1-00139-g14629453a6d3 #21 RIP: 0010:tcp_peek_len+0x10/0x60
    RSP: 0018:ffffc02e41c54b98 EFLAGS: 00010246
    RAX: 0000000000000000 RBX: ffff9cf924c4e030 RCX: 0000000000000051
    RDX: 0000000000000000 RSI: 000000000000000c RDI: ffff9cf97128f480
    RBP: ffff9cf9365e0300 R08: ffff9cf94fe7d2c0 R09: 0000000000000000
    R10: 000000000000036b R11: ffff9cf939735e00 R12: ffff9cf91ad9ae40
    R13: ffff9cf924c4e000 R14: ffff9cf9a8fcbaae R15: 0000000000000020
    FS: 0000000000000000(0000) GS:ffff9cf9af7c0000(0000)
    knlGS:0000000000000000 CS: 0010 DS: 0000 ES: 0000 CR0:
    0000000080050033 CR2: 0000000000000012 CR3: 000000013920a003 CR4:
    00000000003606e0 DR0: 0000000000000000 DR1: 0000000000000000 DR2:
    0000000000000000 DR3: 0000000000000000 DR6: 00000000fffe0ff0 DR7:
    0000000000000400 Call Trace:
     <IRQ>
     strp_data_ready+0x48/0x90
     tls_data_ready+0x22/0xd0 [tls]
     tcp_rcv_established+0x569/0x620
     tcp_v4_do_rcv+0x127/0x1e0
     tcp_v4_rcv+0xad7/0xbf0
     ip_protocol_deliver_rcu+0x2c/0x1c0
     ip_local_deliver_finish+0x41/0x50
     ip_local_deliver+0x6b/0xe0
     ? ip_protocol_deliver_rcu+0x1c0/0x1c0
     ip_rcv+0x52/0xd0
     ? ip_rcv_finish_core.isra.20+0x380/0x380
     __netif_receive_skb_one_core+0x7e/0x90
     netif_receive_skb_internal+0x42/0xf0
     napi_gro_receive+0xed/0x150
     nfp_net_poll+0x7a2/0xd30 [nfp]
     ? kmem_cache_free_bulk+0x286/0x310
     net_rx_action+0x149/0x3b0
     __do_softirq+0xe3/0x30a
     ? handle_irq_event_percpu+0x6a/0x80
     irq_exit+0xe8/0xf0
     do_IRQ+0x85/0xd0
     common_interrupt+0xf/0xf
     </IRQ>
    RIP: 0010:cpuidle_enter_state+0xbc/0x450

If I read this right strparser calls sock->ops->peek_len(sock), but the
sock->sk is already NULL.  I'm guessing this is because inet_release()
does:

		sock->sk = NULL;
		sk->sk_prot->close(sk, timeout);

And I don't really see a way for ktls to know that sock->sk is about to
be cleared, and therefore no way to stop strparser.  Or for strparser
to always do the check, given tcp_peek_len() will do another dereference
of sock->sk :S

That's mostly a guess, it takes me half an hour of ktls connections
running to repro.

Any advice would be appreciated..  Can we move the sock->sk assignment
after close?..

diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 5183a2daba64..aff93e7cdb31 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -428,8 +428,8 @@ int inet_release(struct socket *sock)
                if (sock_flag(sk, SOCK_LINGER) &&
                    !(current->flags & PF_EXITING))
                        timeout = sk->sk_lingertime;
-               sock->sk = NULL;
                sk->sk_prot->close(sk, timeout);
+               sock->sk = NULL;
        }
        return 0;
 }

I don't see IPv6 clearing this pointer, perhaps we don't have to?
We tested it and it seems to work, but this is pre-git code, so
it's hard to tell what the reason to clear was :)
John Fastabend May 22, 2019, 9:57 p.m. UTC | #8
Jakub Kicinski wrote:
> On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:

[...]

> 
> Looks like David Beckett managed to trigger another nasty on the
> release path :/
> 
>     BUG: kernel NULL pointer dereference, address: 0000000000000012
>     PGD 0 P4D 0
>     Oops: 0000 [#1] SMP PTI
>     CPU: 7 PID: 0 Comm: swapper/7 Not tainted
>     5.2.0-rc1-00139-g14629453a6d3 #21 RIP: 0010:tcp_peek_len+0x10/0x60
>     RSP: 0018:ffffc02e41c54b98 EFLAGS: 00010246
>     RAX: 0000000000000000 RBX: ffff9cf924c4e030 RCX: 0000000000000051
>     RDX: 0000000000000000 RSI: 000000000000000c RDI: ffff9cf97128f480
>     RBP: ffff9cf9365e0300 R08: ffff9cf94fe7d2c0 R09: 0000000000000000
>     R10: 000000000000036b R11: ffff9cf939735e00 R12: ffff9cf91ad9ae40
>     R13: ffff9cf924c4e000 R14: ffff9cf9a8fcbaae R15: 0000000000000020
>     FS: 0000000000000000(0000) GS:ffff9cf9af7c0000(0000)
>     knlGS:0000000000000000 CS: 0010 DS: 0000 ES: 0000 CR0:
>     0000000080050033 CR2: 0000000000000012 CR3: 000000013920a003 CR4:
>     00000000003606e0 DR0: 0000000000000000 DR1: 0000000000000000 DR2:
>     0000000000000000 DR3: 0000000000000000 DR6: 00000000fffe0ff0 DR7:
>     0000000000000400 Call Trace:
>      <IRQ>
>      strp_data_ready+0x48/0x90
>      tls_data_ready+0x22/0xd0 [tls]
>      tcp_rcv_established+0x569/0x620
>      tcp_v4_do_rcv+0x127/0x1e0
>      tcp_v4_rcv+0xad7/0xbf0
>      ip_protocol_deliver_rcu+0x2c/0x1c0
>      ip_local_deliver_finish+0x41/0x50
>      ip_local_deliver+0x6b/0xe0
>      ? ip_protocol_deliver_rcu+0x1c0/0x1c0
>      ip_rcv+0x52/0xd0
>      ? ip_rcv_finish_core.isra.20+0x380/0x380
>      __netif_receive_skb_one_core+0x7e/0x90
>      netif_receive_skb_internal+0x42/0xf0
>      napi_gro_receive+0xed/0x150
>      nfp_net_poll+0x7a2/0xd30 [nfp]
>      ? kmem_cache_free_bulk+0x286/0x310
>      net_rx_action+0x149/0x3b0
>      __do_softirq+0xe3/0x30a
>      ? handle_irq_event_percpu+0x6a/0x80
>      irq_exit+0xe8/0xf0
>      do_IRQ+0x85/0xd0
>      common_interrupt+0xf/0xf
>      </IRQ>
>     RIP: 0010:cpuidle_enter_state+0xbc/0x450
> 
> If I read this right strparser calls sock->ops->peek_len(sock), but the
> sock->sk is already NULL.  I'm guess this is because inet_release()
> does:
> 
> 		sock->sk = NULL;
> 		sk->sk_prot->close(sk, timeout);
> 
> And I don't really see a way for ktls to know that sock->sk is about to
> be cleared, and therefore no way to stop strparser.  Or for strparser
> to always do the check, given tcp_peek_len() will do another dereference
> of sock->sk :S
> 
> That's mostly a guess, it takes me half an hour of ktls connections
> running to repro.
> 
> Any advice would be appreciated..  Can we move the sock->sk assignment
> after close?..
> 
> diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> index 5183a2daba64..aff93e7cdb31 100644
> --- a/net/ipv4/af_inet.c
> +++ b/net/ipv4/af_inet.c
> @@ -428,8 +428,8 @@ int inet_release(struct socket *sock)
>                 if (sock_flag(sk, SOCK_LINGER) &&
>                     !(current->flags & PF_EXITING))
>                         timeout = sk->sk_lingertime;
> -               sock->sk = NULL;
>                 sk->sk_prot->close(sk, timeout);
> +               sock->sk = NULL;
>         }
>         return 0;
>  }
> 
> I don't see IPv6 clearing this pointer, perhaps we don't have to?
> We tested it and it seems to works, but this is pre-git code, so
> it's hard to tell what the reason to clear was :)

How about making strp_peek_len tolerant of a null sock->sk?

diff --git a/net/strparser/strparser.c b/net/strparser/strparser.c
index e137698e8aef..79518f93d2d8 100644
--- a/net/strparser/strparser.c
+++ b/net/strparser/strparser.c
@@ -84,9 +84,10 @@ static void strp_parser_err(struct strparser *strp, int err,
 static inline int strp_peek_len(struct strparser *strp)
 {
        if (strp->sk) {
-               struct socket *sock = strp->sk->sk_socket;
+               struct socket *sock = READ_ONCE(strp->sk->sk_socket);
 
-               return sock->ops->peek_len(sock);
+               if (likely(sock))
+                       return sock->ops->peek_len(sock);
        }
Jakub Kicinski May 22, 2019, 10:15 p.m. UTC | #9
On Wed, 22 May 2019 14:57:33 -0700, John Fastabend wrote:
> Jakub Kicinski wrote:
> > On Thu, 09 May 2019 21:57:49 -0700, John Fastabend wrote:  
> 
> [...]
> 
> > 
> > Looks like David Beckett managed to trigger another nasty on the
> > release path :/
> > 
> >     BUG: kernel NULL pointer dereference, address: 0000000000000012
> >     PGD 0 P4D 0
> >     Oops: 0000 [#1] SMP PTI
> >     CPU: 7 PID: 0 Comm: swapper/7 Not tainted
> >     5.2.0-rc1-00139-g14629453a6d3 #21 RIP: 0010:tcp_peek_len+0x10/0x60
> >     RSP: 0018:ffffc02e41c54b98 EFLAGS: 00010246
> >     RAX: 0000000000000000 RBX: ffff9cf924c4e030 RCX: 0000000000000051
> >     RDX: 0000000000000000 RSI: 000000000000000c RDI: ffff9cf97128f480
> >     RBP: ffff9cf9365e0300 R08: ffff9cf94fe7d2c0 R09: 0000000000000000
> >     R10: 000000000000036b R11: ffff9cf939735e00 R12: ffff9cf91ad9ae40
> >     R13: ffff9cf924c4e000 R14: ffff9cf9a8fcbaae R15: 0000000000000020
> >     FS: 0000000000000000(0000) GS:ffff9cf9af7c0000(0000)
> >     knlGS:0000000000000000 CS: 0010 DS: 0000 ES: 0000 CR0:
> >     0000000080050033 CR2: 0000000000000012 CR3: 000000013920a003 CR4:
> >     00000000003606e0 DR0: 0000000000000000 DR1: 0000000000000000 DR2:
> >     0000000000000000 DR3: 0000000000000000 DR6: 00000000fffe0ff0 DR7:
> >     0000000000000400 Call Trace:
> >      <IRQ>
> >      strp_data_ready+0x48/0x90
> >      tls_data_ready+0x22/0xd0 [tls]
> >      tcp_rcv_established+0x569/0x620
> >      tcp_v4_do_rcv+0x127/0x1e0
> >      tcp_v4_rcv+0xad7/0xbf0
> >      ip_protocol_deliver_rcu+0x2c/0x1c0
> >      ip_local_deliver_finish+0x41/0x50
> >      ip_local_deliver+0x6b/0xe0
> >      ? ip_protocol_deliver_rcu+0x1c0/0x1c0
> >      ip_rcv+0x52/0xd0
> >      ? ip_rcv_finish_core.isra.20+0x380/0x380
> >      __netif_receive_skb_one_core+0x7e/0x90
> >      netif_receive_skb_internal+0x42/0xf0
> >      napi_gro_receive+0xed/0x150
> >      nfp_net_poll+0x7a2/0xd30 [nfp]
> >      ? kmem_cache_free_bulk+0x286/0x310
> >      net_rx_action+0x149/0x3b0
> >      __do_softirq+0xe3/0x30a
> >      ? handle_irq_event_percpu+0x6a/0x80
> >      irq_exit+0xe8/0xf0
> >      do_IRQ+0x85/0xd0
> >      common_interrupt+0xf/0xf
> >      </IRQ>
> >     RIP: 0010:cpuidle_enter_state+0xbc/0x450
> > 
> > If I read this right strparser calls sock->ops->peek_len(sock), but the
> > sock->sk is already NULL.  I'm guess this is because inet_release()
> > does:
> > 
> > 		sock->sk = NULL;
> > 		sk->sk_prot->close(sk, timeout);
> > 
> > And I don't really see a way for ktls to know that sock->sk is about to
> > be cleared, and therefore no way to stop strparser.  Or for strparser
> > to always do the check, given tcp_peek_len() will do another dereference
> > of sock->sk :S
> > 
> > That's mostly a guess, it takes me half an hour of ktls connections
> > running to repro.
> > 
> > Any advice would be appreciated..  Can we move the sock->sk assignment
> > after close?..
> > 
> > diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> > index 5183a2daba64..aff93e7cdb31 100644
> > --- a/net/ipv4/af_inet.c
> > +++ b/net/ipv4/af_inet.c
> > @@ -428,8 +428,8 @@ int inet_release(struct socket *sock)
> >                 if (sock_flag(sk, SOCK_LINGER) &&
> >                     !(current->flags & PF_EXITING))
> >                         timeout = sk->sk_lingertime;
> > -               sock->sk = NULL;
> >                 sk->sk_prot->close(sk, timeout);
> > +               sock->sk = NULL;
> >         }
> >         return 0;
> >  }
> > 
> > I don't see IPv6 clearing this pointer, perhaps we don't have to?

Correction here, IPv6 just calls the IPv4 code, that's why IPv6 was
also fixed after my change.

> > We tested it and it seems to works, but this is pre-git code, so
> > it's hard to tell what the reason to clear was :)  
> 
> How about making strp_peek_len tolerant of a null sock->sk?
> 
> diff --git a/net/strparser/strparser.c b/net/strparser/strparser.c
> index e137698e8aef..79518f93d2d8 100644
> --- a/net/strparser/strparser.c
> +++ b/net/strparser/strparser.c
> @@ -84,9 +84,10 @@ static void strp_parser_err(struct strparser *strp, int err,
>  static inline int strp_peek_len(struct strparser *strp)
>  {
>         if (strp->sk) {
> -               struct socket *sock = strp->sk->sk_socket;
> +               struct socket *sock = READ_ONCE(strp->sk->sk_socket);
>  
> -               return sock->ops->peek_len(sock);
> +               if (likely(sock))
> +                       return sock->ops->peek_len(sock);
>         }

Mmm..  I'm not sure - sk->sk_socket doesn't get cleared AFAICT, 
the NULL deref is on sk_state of sock->sk so sock is non-NULL here,
then:

int tcp_peek_len(struct socket *sock)
{
	return tcp_inq(sock->sk);
}
EXPORT_SYMBOL(tcp_peek_len);

Will pass NULL to tcp_inq, which then does:

static inline int tcp_inq(struct sock *sk)
{
	struct tcp_sock *tp = tcp_sk(sk);
	int answ;

	if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) {
		answ = 0;

And sk->sk_state is what crashes the machine.
diff mbox series

Patch

diff --git a/include/net/tls.h b/include/net/tls.h
index 5934246b2c6f..05d8cd5a3297 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -250,11 +250,14 @@  struct tls_context {
 	bool in_tcp_sendpages;
 	bool pending_open_record_frags;
 
-	int (*push_pending_record)(struct sock *sk, int flags);
+	int (*push_pending_record)(struct sock *sk,
+				   struct tls_context *ctx, int flags);
 
 	void (*sk_write_space)(struct sock *sk);
 	void (*sk_destruct)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);
+	void (*sk_proto_unhash)(struct sock *sk);
+	struct proto *sk_proto;
 
 	int  (*setsockopt)(struct sock *sk, int level,
 			   int optname, char __user *optval,
@@ -292,9 +295,10 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
 		    int offset, size_t size, int flags);
 void tls_sw_close(struct sock *sk, long timeout);
-void tls_sw_free_resources_tx(struct sock *sk);
-void tls_sw_free_resources_rx(struct sock *sk);
-void tls_sw_release_resources_rx(struct sock *sk);
+void tls_sw_free_resources_tx(struct sock *sk,
+			      struct tls_context *ctx, bool locked);
+void tls_sw_free_resources_rx(struct sock *sk, struct tls_context *ctx);
+void tls_sw_release_resources_rx(struct sock *sk, struct tls_context *ctx);
 int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		   int nonblock, int flags, int *addr_len);
 bool tls_sw_stream_read(const struct sock *sk);
@@ -310,7 +314,7 @@  void tls_device_sk_destruct(struct sock *sk);
 void tls_device_free_resources_tx(struct sock *sk);
 void tls_device_init(void);
 void tls_device_cleanup(void);
-int tls_tx_records(struct sock *sk, int flags);
+int tls_tx_records(struct sock *sk, struct tls_context *ctx, int flags);
 
 struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
 				       u32 seq, u64 *p_record_sn);
@@ -416,12 +420,10 @@  static inline struct tls_context *tls_get_ctx(const struct sock *sk)
 }
 
 static inline void tls_advance_record_sn(struct sock *sk,
+					 struct tls_prot_info *prot,
 					 struct cipher_context *ctx,
 					 int version)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_prot_info *prot = &tls_ctx->prot_info;
-
 	if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
 		tls_err_abort(sk, EBADMSG);
 
@@ -493,6 +495,16 @@  static inline void xor_iv_with_seq(int version, char *iv, char *seq)
 	}
 }
 
+static inline void tls_put_ctx(struct sock *sk)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context *ctx = icsk->icsk_ulp_data;
+
+	if (!ctx)
+		return;
+	sk->sk_prot = ctx->sk_proto;
+	icsk->icsk_ulp_data = NULL;
+}
 
 static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
 		const struct tls_context *tls_ctx)
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 14dedb24fa7b..d3cd887f0488 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -281,7 +281,8 @@  static int tls_push_record(struct sock *sk,
 	list_add_tail(&record->list, &offload_ctx->records_list);
 	spin_unlock_irq(&offload_ctx->lock);
 	offload_ctx->open_record = NULL;
-	tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version);
+	tls_advance_record_sn(sk, prot, &ctx->tx,
+			      ctx->crypto_send.info.version);
 
 	for (i = 0; i < record->num_frags; i++) {
 		frag = &record->frags[i];
@@ -548,7 +549,8 @@  struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
 }
 EXPORT_SYMBOL(tls_get_record);
 
-static int tls_device_push_pending_record(struct sock *sk, int flags)
+static int tls_device_push_pending_record(struct sock *sk,
+					  struct tls_context *ctx, int flags)
 {
 	struct iov_iter	msg_iter;
 
@@ -922,7 +924,7 @@  int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 
 free_sw_resources:
 	up_read(&device_offload_lock);
-	tls_sw_free_resources_rx(sk);
+	tls_sw_free_resources_rx(sk, ctx);
 	down_read(&device_offload_lock);
 release_ctx:
 	ctx->priv_ctx_rx = NULL;
@@ -958,7 +960,7 @@  void tls_device_offload_cleanup_rx(struct sock *sk)
 	}
 out:
 	up_read(&device_offload_lock);
-	tls_sw_release_resources_rx(sk);
+	tls_sw_release_resources_rx(sk, tls_ctx);
 }
 
 static int tls_device_down(struct net_device *netdev)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 478603f43964..7f7982361128 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -150,12 +150,11 @@  int tls_push_sg(struct sock *sk,
 	return 0;
 }
 
-static int tls_handle_open_record(struct sock *sk, int flags)
+static int tls_handle_open_record(struct sock *sk,
+				  struct tls_context *ctx, int flags)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
-
 	if (tls_is_pending_open_record(ctx))
-		return ctx->push_pending_record(sk, flags);
+		return ctx->push_pending_record(sk, ctx, flags);
 
 	return 0;
 }
@@ -163,6 +162,7 @@  static int tls_handle_open_record(struct sock *sk, int flags)
 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 		      unsigned char *record_type)
 {
+	struct tls_context *ctx;
 	struct cmsghdr *cmsg;
 	int rc = -EINVAL;
 
@@ -180,7 +180,11 @@  int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 			if (msg->msg_flags & MSG_MORE)
 				return -EINVAL;
 
-			rc = tls_handle_open_record(sk, msg->msg_flags);
+			ctx = tls_get_ctx(sk);
+			if (unlikely(!ctx))
+				return -EBUSY;
+
+			rc = tls_handle_open_record(sk, ctx, msg->msg_flags);
 			if (rc)
 				return rc;
 
@@ -261,32 +265,28 @@  static void tls_ctx_free(struct tls_context *ctx)
 	kfree(ctx);
 }
 
-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static bool tls_sk_proto_destroy(struct sock *sk,
+				 struct tls_context *ctx, bool locked)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
 	long timeo = sock_sndtimeo(sk, 0);
-	void (*sk_proto_close)(struct sock *sk, long timeout);
-	bool free_ctx = false;
-
-	lock_sock(sk);
-	sk_proto_close = ctx->sk_proto_close;
 
 	if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
-		goto skip_tx_cleanup;
+		return false;
 
-	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
-		free_ctx = true;
-		goto skip_tx_cleanup;
-	}
+	if (ctx->tx_conf == TLS_SW && ctx->rx_conf == TLS_SW)
+		tls_put_ctx(sk);
+
+	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)
+		return true;
 
 	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
-		tls_handle_open_record(sk, 0);
+		tls_handle_open_record(sk, ctx, 0);
 
 	/* We need these for tls_sw_fallback handling of other packets */
 	if (ctx->tx_conf == TLS_SW) {
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
-		tls_sw_free_resources_tx(sk);
+		tls_sw_free_resources_tx(sk, ctx, locked);
 #ifdef CONFIG_TLS_DEVICE
 	} else if (ctx->tx_conf == TLS_HW) {
 		tls_device_free_resources_tx(sk);
@@ -294,21 +294,46 @@  static void tls_sk_proto_close(struct sock *sk, long timeout)
 	}
 
 	if (ctx->rx_conf == TLS_SW)
-		tls_sw_free_resources_rx(sk);
+		tls_sw_free_resources_rx(sk, ctx);
 
 #ifdef CONFIG_TLS_DEVICE
 	if (ctx->rx_conf == TLS_HW)
 		tls_device_offload_cleanup_rx(sk);
-
-	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
-#else
-	{
 #endif
+
+	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW)
+		return true;
+	return false;
+}
+
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+	struct tls_context *ctx = tls_get_ctx(sk);
+	void (*sk_proto_unhash)(struct sock *sk);
+	bool free_ctx;
+
+	if (!ctx)
+		return sk->sk_prot->unhash(sk);
+	sk_proto_unhash = ctx->sk_proto_unhash;
+	free_ctx = tls_sk_proto_destroy(sk, ctx, false);
+	if (sk_proto_unhash)
+		sk_proto_unhash(sk);
+	if (free_ctx)
 		tls_ctx_free(ctx);
-		ctx = NULL;
-	}
+}
 
-skip_tx_cleanup:
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+	void (*sk_proto_close)(struct sock *sk, long timeout);
+	struct tls_context *ctx = tls_get_ctx(sk);
+	bool free_ctx;
+
+	if (!ctx)
+		return sk->sk_prot->destroy(sk);
+
+	lock_sock(sk);
+	sk_proto_close = ctx->sk_proto_close;
+	free_ctx = tls_sk_proto_destroy(sk, ctx, true);
 	release_sock(sk);
 	sk_proto_close(sk, timeout);
 	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
@@ -601,6 +626,8 @@  static struct tls_context *create_ctx(struct sock *sk)
 	ctx->setsockopt = sk->sk_prot->setsockopt;
 	ctx->getsockopt = sk->sk_prot->getsockopt;
 	ctx->sk_proto_close = sk->sk_prot->close;
+	ctx->sk_proto_unhash = sk->sk_prot->unhash;
+	ctx->sk_proto = sk->sk_prot;
 	return ctx;
 }
 
@@ -738,6 +765,7 @@  static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
 	prot[TLS_SW][TLS_SW].stream_memory_read	= tls_sw_stream_read;
 	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;
+	prot[TLS_SW][TLS_SW].unhash		= tls_sk_proto_unhash;
 
 #ifdef CONFIG_TLS_DEVICE
 	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 29d6af43dd24..2de433232b99 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -266,9 +266,9 @@  static void tls_trim_both_msgs(struct sock *sk, int target_size)
 	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
 }
 
-static int tls_alloc_encrypted_msg(struct sock *sk, int len)
+static int tls_alloc_encrypted_msg(struct sock *sk,
+				   struct tls_context *tls_ctx,  int len)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec = ctx->open_rec;
 	struct sk_msg *msg_en = &rec->msg_encrypted;
@@ -300,9 +300,8 @@  static int tls_clone_plaintext_msg(struct sock *sk, int required)
 	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
 }
 
-static struct tls_rec *tls_get_rec(struct sock *sk)
+static struct tls_rec *tls_get_rec(struct sock *sk, struct tls_context *tls_ctx)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct sk_msg *msg_pl, *msg_en;
@@ -339,9 +338,8 @@  static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
 	kfree(rec);
 }
 
-static void tls_free_open_rec(struct sock *sk)
+static void tls_free_open_rec(struct sock *sk, struct tls_context *tls_ctx)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec = ctx->open_rec;
 
@@ -351,9 +349,8 @@  static void tls_free_open_rec(struct sock *sk)
 	}
 }
 
-int tls_tx_records(struct sock *sk, int flags)
+int tls_tx_records(struct sock *sk, struct tls_context *tls_ctx, int flags)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec, *tmp;
 	struct sk_msg *msg_en;
@@ -519,12 +516,13 @@  static int tls_do_encryption(struct sock *sk,
 
 	/* Unhook the record from context if encryption is not failure */
 	ctx->open_rec = NULL;
-	tls_advance_record_sn(sk, &tls_ctx->tx, prot->version);
+	tls_advance_record_sn(sk, prot, &tls_ctx->tx, prot->version);
 	return rc;
 }
 
-static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
-				 struct tls_rec **to, struct sk_msg *msg_opl,
+static int tls_split_open_record(struct sock *sk, struct tls_context *tls_ctx,
+				 struct tls_rec *from, struct tls_rec **to,
+				 struct sk_msg *msg_opl,
 				 struct sk_msg *msg_oen, u32 split_point,
 				 u32 tx_overhead_size, u32 *orig_end)
 {
@@ -536,7 +534,7 @@  static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
 	struct tls_rec *new;
 	int ret;
 
-	new = tls_get_rec(sk);
+	new = tls_get_rec(sk, tls_ctx);
 	if (!new)
 		return -ENOMEM;
 	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
@@ -641,10 +639,9 @@  static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
 	kfree(from);
 }
 
-static int tls_push_record(struct sock *sk, int flags,
-			   unsigned char record_type)
+static int tls_push_record(struct sock *sk, struct tls_context *tls_ctx,
+			   int flags, unsigned char record_type)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
@@ -663,7 +660,8 @@  static int tls_push_record(struct sock *sk, int flags,
 	split_point = msg_pl->apply_bytes;
 	split = split_point && split_point < msg_pl->sg.size;
 	if (split) {
-		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+		rc = tls_split_open_record(sk, tls_ctx, rec, &tmp,
+					   msg_pl, msg_en,
 					   split_point, prot->overhead_size,
 					   &orig_end);
 		if (rc < 0)
@@ -732,14 +730,14 @@  static int tls_push_record(struct sock *sk, int flags,
 		ctx->open_rec = tmp;
 	}
 
-	return tls_tx_records(sk, flags);
+	return tls_tx_records(sk, tls_ctx, flags);
 }
 
 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+			       struct tls_context *tls_ctx,
 			       bool full_record, u8 record_type,
 			       size_t *copied, int flags)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct sk_msg msg_redir = { };
 	struct sk_psock *psock;
@@ -752,7 +750,7 @@  static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
 	psock = sk_psock_get(sk);
 	if (!psock || !policy)
-		return tls_push_record(sk, flags, record_type);
+		return tls_push_record(sk, tls_ctx, flags, record_type);
 more_data:
 	enospc = sk_msg_full(msg);
 	if (psock->eval == __SK_NONE) {
@@ -775,10 +773,10 @@  static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 
 	switch (psock->eval) {
 	case __SK_PASS:
-		err = tls_push_record(sk, flags, record_type);
+		err = tls_push_record(sk, tls_ctx, flags, record_type);
 		if (err < 0) {
 			*copied -= sk_msg_free(sk, msg);
-			tls_free_open_rec(sk);
+			tls_free_open_rec(sk, tls_ctx);
 			goto out_err;
 		}
 		break;
@@ -799,7 +797,7 @@  static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 			msg->sg.size = 0;
 		}
 		if (msg->sg.size == 0)
-			tls_free_open_rec(sk);
+			tls_free_open_rec(sk, tls_ctx);
 		break;
 	case __SK_DROP:
 	default:
@@ -809,7 +807,7 @@  static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 		else
 			msg->apply_bytes -= send;
 		if (msg->sg.size == 0)
-			tls_free_open_rec(sk);
+			tls_free_open_rec(sk, tls_ctx);
 		*copied -= (send + delta);
 		err = -EACCES;
 	}
@@ -838,14 +836,15 @@  static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
 	return err;
 }
 
-static int tls_sw_push_pending_record(struct sock *sk, int flags)
+static int tls_sw_push_pending_record(struct sock *sk,
+				      struct tls_context *tls_ctx, int flags)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct tls_rec *rec = ctx->open_rec;
 	struct sk_msg *msg_pl;
+	struct tls_rec *rec;
 	size_t copied;
 
+	rec = ctx->open_rec;
 	if (!rec)
 		return 0;
 
@@ -854,31 +853,39 @@  static int tls_sw_push_pending_record(struct sock *sk, int flags)
 	if (!copied)
 		return 0;
 
-	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+	return bpf_exec_tx_verdict(msg_pl, sk, tls_ctx, true,
+				   TLS_RECORD_TYPE_DATA,
 				   &copied, flags);
 }
 
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_prot_info *prot = &tls_ctx->prot_info;
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	bool async_capable = ctx->async_capable;
 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
 	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
 	bool eor = !(msg->msg_flags & MSG_MORE);
+	bool full_record, async_capable;
 	size_t try_to_copy, copied = 0;
 	struct sk_msg *msg_pl, *msg_en;
+	struct tls_sw_context_tx *ctx;
+	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
 	struct tls_rec *rec;
 	int required_size;
 	int num_async = 0;
-	bool full_record;
 	int record_room;
 	int num_zc = 0;
 	int orig_size;
 	int ret = 0;
 
+	tls_ctx = tls_get_ctx(sk);
+	if (unlikely(!tls_ctx))
+		return -EBUSY;
+
+	prot = &tls_ctx->prot_info;
+	ctx = tls_sw_ctx_tx(tls_ctx);
+	async_capable = ctx->async_capable;
+
 	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
 		return -ENOTSUPP;
 
@@ -910,7 +917,7 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		if (ctx->open_rec)
 			rec = ctx->open_rec;
 		else
-			rec = ctx->open_rec = tls_get_rec(sk);
+			rec = ctx->open_rec = tls_get_rec(sk, tls_ctx);
 		if (!rec) {
 			ret = -ENOMEM;
 			goto send_end;
@@ -935,7 +942,7 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			goto wait_for_sndbuf;
 
 alloc_encrypted:
-		ret = tls_alloc_encrypted_msg(sk, required_size);
+		ret = tls_alloc_encrypted_msg(sk, tls_ctx, required_size);
 		if (ret) {
 			if (ret != -ENOSPC)
 				goto wait_for_memory;
@@ -962,7 +969,8 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			copied += try_to_copy;
 
 			sk_msg_sg_copy_set(msg_pl, first);
-			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+			ret = bpf_exec_tx_verdict(msg_pl, sk, tls_ctx,
+						  full_record,
 						  record_type, &copied,
 						  msg->msg_flags);
 			if (ret) {
@@ -1015,7 +1023,8 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		tls_ctx->pending_open_record_frags = true;
 		copied += try_to_copy;
 		if (full_record || eor) {
-			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+			ret = bpf_exec_tx_verdict(msg_pl, sk, tls_ctx,
+						  full_record,
 						  record_type, &copied,
 						  msg->msg_flags);
 			if (ret) {
@@ -1069,7 +1078,7 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 	/* Transmit if any encryptions have completed */
 	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
 		cancel_delayed_work(&ctx->tx_work.work);
-		tls_tx_records(sk, msg->msg_flags);
+		tls_tx_records(sk, tls_ctx, msg->msg_flags);
 	}
 
 send_end:
@@ -1083,10 +1092,10 @@  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 			      int offset, size_t size, int flags)
 {
 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	struct tls_sw_context_tx *ctx;
+	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
 	struct sk_msg *msg_pl;
 	struct tls_rec *rec;
 	int num_async = 0;
@@ -1096,6 +1105,13 @@  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 	int ret = 0;
 	bool eor;
 
+	tls_ctx = tls_get_ctx(sk);
+	if (unlikely(!tls_ctx))
+		return -EBUSY;
+
+	ctx = tls_sw_ctx_tx(tls_ctx);
+	prot = &tls_ctx->prot_info;
+
 	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 
@@ -1118,7 +1134,7 @@  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 		if (ctx->open_rec)
 			rec = ctx->open_rec;
 		else
-			rec = ctx->open_rec = tls_get_rec(sk);
+			rec = ctx->open_rec = tls_get_rec(sk, tls_ctx);
 		if (!rec) {
 			ret = -ENOMEM;
 			goto sendpage_end;
@@ -1140,7 +1156,7 @@  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
 alloc_payload:
-		ret = tls_alloc_encrypted_msg(sk, required_size);
+		ret = tls_alloc_encrypted_msg(sk, tls_ctx, required_size);
 		if (ret) {
 			if (ret != -ENOSPC)
 				goto wait_for_memory;
@@ -1163,7 +1179,8 @@  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 		tls_ctx->pending_open_record_frags = true;
 		if (full_record || eor || sk_msg_full(msg_pl)) {
 			rec->inplace_crypto = 0;
-			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+			ret = bpf_exec_tx_verdict(msg_pl, sk, tls_ctx,
+						  full_record,
 						  record_type, &copied, flags);
 			if (ret) {
 				if (ret == -EINPROGRESS)
@@ -1194,7 +1211,7 @@  static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
 		/* Transmit if any encryptions have completed */
 		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
 			cancel_delayed_work(&ctx->tx_work.work);
-			tls_tx_records(sk, flags);
+			tls_tx_records(sk, tls_ctx, flags);
 		}
 	}
 sendpage_end:
@@ -1479,7 +1496,8 @@  static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 					       async);
 			if (err < 0) {
 				if (err == -EINPROGRESS)
-					tls_advance_record_sn(sk, &tls_ctx->rx,
+					tls_advance_record_sn(sk, prot,
+							      &tls_ctx->rx,
 							      version);
 
 				return err;
@@ -1491,7 +1509,7 @@  static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 		rxm->full_len -= padding_length(ctx, tls_ctx, skb);
 		rxm->offset += prot->prepend_size;
 		rxm->full_len -= prot->overhead_size;
-		tls_advance_record_sn(sk, &tls_ctx->rx, version);
+		tls_advance_record_sn(sk, prot, &tls_ctx->rx, version);
 		ctx->decrypted = true;
 		ctx->saved_data_ready(sk);
 	} else {
@@ -2031,9 +2049,9 @@  static void tls_data_ready(struct sock *sk)
 	}
 }
 
-void tls_sw_free_resources_tx(struct sock *sk)
+void tls_sw_free_resources_tx(struct sock *sk,
+			      struct tls_context *tls_ctx, bool locked)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	struct tls_rec *rec, *tmp;
 
@@ -2042,12 +2060,14 @@  void tls_sw_free_resources_tx(struct sock *sk)
 	if (atomic_read(&ctx->encrypt_pending))
 		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
 
-	release_sock(sk);
+	if (locked)
+		release_sock(sk);
 	cancel_delayed_work_sync(&ctx->tx_work.work);
-	lock_sock(sk);
+	if (locked)
+		lock_sock(sk);
 
 	/* Tx whatever records we can transmit and abandon the rest */
-	tls_tx_records(sk, -1);
+	tls_tx_records(sk, tls_ctx, -1);
 
 	/* Free up un-sent records in tx_list. First, free
 	 * the partially sent record if any at head of tx_list.
@@ -2067,15 +2087,17 @@  void tls_sw_free_resources_tx(struct sock *sk)
 		kfree(rec);
 	}
 
-	crypto_free_aead(ctx->aead_send);
-	tls_free_open_rec(sk);
+	if (ctx->aead_send) {
+		crypto_free_aead(ctx->aead_send);
+		ctx->aead_send = NULL;
+	}
+	tls_free_open_rec(sk, tls_ctx);
 
 	kfree(ctx);
 }
 
-void tls_sw_release_resources_rx(struct sock *sk)
+void tls_sw_release_resources_rx(struct sock *sk, struct tls_context *tls_ctx)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
 	kfree(tls_ctx->rx.rec_seq);
@@ -2096,13 +2118,11 @@  void tls_sw_release_resources_rx(struct sock *sk)
 	}
 }
 
-void tls_sw_free_resources_rx(struct sock *sk)
+void tls_sw_free_resources_rx(struct sock *sk, struct tls_context *tls_ctx)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
-	tls_sw_release_resources_rx(sk);
-
+	tls_sw_release_resources_rx(sk, tls_ctx);
 	kfree(ctx);
 }
 
@@ -2120,7 +2140,7 @@  static void tx_work_handler(struct work_struct *work)
 		return;
 
 	lock_sock(sk);
-	tls_tx_records(sk, -1);
+	tls_tx_records(sk, tls_ctx, -1);
 	release_sock(sk);
 }