diff mbox series

[bpf-next,4/8] bpf, sockmap: Don't let child socket inherit psock or its ops on copy

Message ID 20191123110751.6729-5-jakub@cloudflare.com
State Changes Requested
Delegated to: BPF Maintainers
Headers show
Series Extend SOCKMAP to store listening sockets | expand

Commit Message

Jakub Sitnicki Nov. 23, 2019, 11:07 a.m. UTC
Sockets cloned from the listening sockets that belong to a SOCKMAP must
not inherit the psock state. Otherwise child sockets unintentionally share
the SOCKMAP entry with the listening socket, which would lead to
use-after-free bugs.

Restore the child socket psock state and its callbacks at the earliest
possible moment, that is right after the child socket gets created. This
ensures that neither children that get accept()'ed, nor those that are left
in the accept queue and will get orphaned, inadvertently inherit the
parent's psock.

Signed-off-by: Jakub Sitnicki <jakub@cloudflare.com>
---
 include/linux/skmsg.h | 17 +++++++++--
 net/ipv4/tcp_bpf.c    | 66 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 78 insertions(+), 5 deletions(-)

Comments

John Fastabend Nov. 24, 2019, 5:56 a.m. UTC | #1
Jakub Sitnicki wrote:
> Sockets cloned from the listening sockets that belongs to a SOCKMAP must
> not inherit the psock state. Otherwise child sockets unintentionally share
> the SOCKMAP entry with the listening socket, which would lead to
> use-after-free bugs.
> 
> Restore the child socket psock state and its callbacks at the earliest
> possible moment, that is right after the child socket gets created. This
> ensures that neither children that get accept()'ed, nor those that are left
> in accept queue and will get orphaned, don't inadvertently inherit parent's
> psock.
> 
> Signed-off-by: Jakub Sitnicki <jakub@cloudflare.com>
> ---

Acked-by: John Fastabend <john.fastabend@gmail.com>
Martin KaFai Lau Nov. 25, 2019, 10:38 p.m. UTC | #2
On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
[ ... ]

> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
>  			sk->sk_prot = psock->sk_proto;
>  		psock->sk_proto = NULL;
>  	}
> +
> +	if (psock->icsk_af_ops) {
> +		icsk->icsk_af_ops = psock->icsk_af_ops;
> +		psock->icsk_af_ops = NULL;
> +	}
>  }

[ ... ]

> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
> +					  struct sk_buff *skb,
> +					  struct request_sock *req,
> +					  struct dst_entry *dst,
> +					  struct request_sock *req_unhash,
> +					  bool *own_req)
> +{
> +	const struct inet_connection_sock_af_ops *ops;
> +	void (*write_space)(struct sock *sk);
> +	struct sk_psock *psock;
> +	struct proto *proto;
> +	struct sock *child;
> +
> +	rcu_read_lock();
> +	psock = sk_psock(sk);
> +	if (likely(psock)) {
> +		proto = psock->sk_proto;
> +		write_space = psock->saved_write_space;
> +		ops = psock->icsk_af_ops;
It is not immediately clear to me what ensures
ops is not NULL here.

It is likely I missed something.  A short comment would
be very useful here.

> +	} else {
> +		ops = inet_csk(sk)->icsk_af_ops;
> +	}
> +	rcu_read_unlock();
> +
> +	child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
> +
> +	/* Child must not inherit psock or its ops. */
> +	if (child && psock) {
> +		rcu_assign_sk_user_data(child, NULL);
> +		child->sk_prot = proto;
> +		child->sk_write_space = write_space;
> +
> +		/* v4-mapped sockets don't inherit parent ops. Don't restore. */
> +		if (inet_csk(child)->icsk_af_ops == inet_csk(sk)->icsk_af_ops)
> +			inet_csk(child)->icsk_af_ops = ops;
> +	}
> +	return child;
> +}
> +
>  enum {
>  	TCP_BPF_IPV4,
>  	TCP_BPF_IPV6,
> @@ -597,6 +642,7 @@ enum {
>  static struct proto *tcpv6_prot_saved __read_mostly;
>  static DEFINE_SPINLOCK(tcpv6_prot_lock);
>  static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
> +static struct inet_connection_sock_af_ops tcp_bpf_af_ops[TCP_BPF_NUM_PROTS];
>  
>  static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
>  				   struct proto *base)
> @@ -612,13 +658,23 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
>  	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
>  }
>  
> -static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
> +static void tcp_bpf_rebuild_af_ops(struct inet_connection_sock_af_ops *ops,
> +				   const struct inet_connection_sock_af_ops *base)
> +{
> +	*ops = *base;
> +	ops->syn_recv_sock = tcp_bpf_syn_recv_sock;
> +}
> +
> +static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops,
> +					   const struct inet_connection_sock_af_ops *af_ops)
>  {
>  	if (sk->sk_family == AF_INET6 &&
>  	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
>  		spin_lock_bh(&tcpv6_prot_lock);
>  		if (likely(ops != tcpv6_prot_saved)) {
>  			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
> +			tcp_bpf_rebuild_af_ops(&tcp_bpf_af_ops[TCP_BPF_IPV6],
> +					       af_ops);
>  			smp_store_release(&tcpv6_prot_saved, ops);
>  		}
>  		spin_unlock_bh(&tcpv6_prot_lock);
> @@ -628,6 +684,8 @@ static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
>  static int __init tcp_bpf_v4_build_proto(void)
>  {
>  	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
> +	tcp_bpf_rebuild_af_ops(&tcp_bpf_af_ops[TCP_BPF_IPV4], &ipv4_specific);
> +
>  	return 0;
>  }
>  core_initcall(tcp_bpf_v4_build_proto);
> @@ -637,7 +695,8 @@ static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
>  	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
>  	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
>  
> -	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
> +	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config],
> +			      &tcp_bpf_af_ops[family]);
>  }
>  
>  static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
> @@ -677,6 +736,7 @@ void tcp_bpf_reinit(struct sock *sk)
>  
>  int tcp_bpf_init(struct sock *sk)
>  {
> +	struct inet_connection_sock *icsk = inet_csk(sk);
>  	struct proto *ops = READ_ONCE(sk->sk_prot);
>  	struct sk_psock *psock;
>  
> @@ -689,7 +749,7 @@ int tcp_bpf_init(struct sock *sk)
>  		rcu_read_unlock();
>  		return -EINVAL;
>  	}
> -	tcp_bpf_check_v6_needs_rebuild(sk, ops);
> +	tcp_bpf_check_v6_needs_rebuild(sk, ops, icsk->icsk_af_ops);
>  	tcp_bpf_update_sk_prot(sk, psock);
>  	rcu_read_unlock();
>  	return 0;
> -- 
> 2.20.1
>
Jakub Sitnicki Nov. 26, 2019, 3:54 p.m. UTC | #3
On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
> On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
> [ ... ]
>
>> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
>>  			sk->sk_prot = psock->sk_proto;
>>  		psock->sk_proto = NULL;
>>  	}
>> +
>> +	if (psock->icsk_af_ops) {
>> +		icsk->icsk_af_ops = psock->icsk_af_ops;
>> +		psock->icsk_af_ops = NULL;
>> +	}
>>  }
>
> [ ... ]
>
>> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
>> +					  struct sk_buff *skb,
>> +					  struct request_sock *req,
>> +					  struct dst_entry *dst,
>> +					  struct request_sock *req_unhash,
>> +					  bool *own_req)
>> +{
>> +	const struct inet_connection_sock_af_ops *ops;
>> +	void (*write_space)(struct sock *sk);
>> +	struct sk_psock *psock;
>> +	struct proto *proto;
>> +	struct sock *child;
>> +
>> +	rcu_read_lock();
>> +	psock = sk_psock(sk);
>> +	if (likely(psock)) {
>> +		proto = psock->sk_proto;
>> +		write_space = psock->saved_write_space;
>> +		ops = psock->icsk_af_ops;
> It is not immediately clear to me what ensure
> ops is not NULL here.
>
> It is likely I missed something.  A short comment would
> be very useful here.

I can see the readability problem. Looking at it now, perhaps it should
be rewritten, to the same effect, as:

static struct sock *tcp_bpf_syn_recv_sock(...)
{
	const struct inet_connection_sock_af_ops *ops = NULL;
        ...

        rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		proto = psock->sk_proto;
		write_space = psock->saved_write_space;
		ops = psock->icsk_af_ops;
	}
	rcu_read_unlock();

        if (!ops)
		ops = inet_csk(sk)->icsk_af_ops;
        child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);

If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
properly. To double check what happens here:

In sock_map_link we do a setup dance where we first create the psock and
later initialize the socket callbacks (tcp_bpf_init).

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
        ...
	if (psock) {
                ...
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}
        ...
        if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

The "if (sk_psock_new)" branch triggers the call chain that leads to
saving & overriding socket callbacks.

tcp_bpf_init -> tcp_bpf_update_sk_prot -> sk_psock_update_proto

Among them, icsk_af_ops.

static inline void sk_psock_update_proto(...)
{
        ...
	psock->icsk_af_ops = icsk->icsk_af_ops;
	icsk->icsk_af_ops = af_ops;
}

Goes without saying that a comment is needed.

Thanks for the feedback,
Jakub
Martin KaFai Lau Nov. 26, 2019, 5:16 p.m. UTC | #4
On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
> On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
> > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
> > [ ... ]
> >
> >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
> >>  			sk->sk_prot = psock->sk_proto;
> >>  		psock->sk_proto = NULL;
> >>  	}
> >> +
> >> +	if (psock->icsk_af_ops) {
> >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
> >> +		psock->icsk_af_ops = NULL;
> >> +	}
> >>  }
> >
> > [ ... ]
> >
> >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
> >> +					  struct sk_buff *skb,
> >> +					  struct request_sock *req,
> >> +					  struct dst_entry *dst,
> >> +					  struct request_sock *req_unhash,
> >> +					  bool *own_req)
> >> +{
> >> +	const struct inet_connection_sock_af_ops *ops;
> >> +	void (*write_space)(struct sock *sk);
> >> +	struct sk_psock *psock;
> >> +	struct proto *proto;
> >> +	struct sock *child;
> >> +
> >> +	rcu_read_lock();
> >> +	psock = sk_psock(sk);
> >> +	if (likely(psock)) {
> >> +		proto = psock->sk_proto;
> >> +		write_space = psock->saved_write_space;
> >> +		ops = psock->icsk_af_ops;
> > It is not immediately clear to me what ensure
> > ops is not NULL here.
> >
> > It is likely I missed something.  A short comment would
> > be very useful here.
> 
> I can see the readability problem. Looking at it now, perhaps it should
> be rewritten, to the same effect, as:
> 
> static struct sock *tcp_bpf_syn_recv_sock(...)
> {
> 	const struct inet_connection_sock_af_ops *ops = NULL;
>         ...
> 
>         rcu_read_lock();
> 	psock = sk_psock(sk);
> 	if (likely(psock)) {
> 		proto = psock->sk_proto;
> 		write_space = psock->saved_write_space;
> 		ops = psock->icsk_af_ops;
> 	}
> 	rcu_read_unlock();
> 
>         if (!ops)
> 		ops = inet_csk(sk)->icsk_af_ops;
>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
> 
> If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
> properly. To double check what happens here:
I did not mean the init path.  The init path is fine since it inits
everything on psock before publishing the sk to the sock_map.

I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
to me what prevents the earlier pasted sk_psock_restore_proto() which sets
psock->icsk_af_ops to NULL from running in parallel with
tcp_bpf_syn_recv_sock()?  An explanation would be useful.

> 
> In sock_map_link we do a setup dance where we first create the psock and
> later initialize the socket callbacks (tcp_bpf_init).
> 
> static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
> 			 struct sock *sk)
> {
>         ...
> 	if (psock) {
>                 ...
> 	} else {
> 		psock = sk_psock_init(sk, map->numa_node);
> 		if (!psock) {
> 			ret = -ENOMEM;
> 			goto out_progs;
> 		}
> 		sk_psock_is_new = true;
> 	}
>         ...
>         if (sk_psock_is_new) {
> 		ret = tcp_bpf_init(sk);
> 		if (ret < 0)
> 			goto out_drop;
> 	} else {
> 		tcp_bpf_reinit(sk);
> 	}
> 
> The "if (sk_psock_new)" branch triggers the call chain that leads to
> saving & overriding socket callbacks.
> 
> tcp_bpf_init -> tcp_bpf_update_sk_prot -> sk_psock_update_proto
> 
> Among them, icsk_af_ops.
> 
> static inline void sk_psock_update_proto(...)
> {
>         ...
> 	psock->icsk_af_ops = icsk->icsk_af_ops;
> 	icsk->icsk_af_ops = af_ops;
> }
> 
> Goes without saying that a comment is needed.
> 
> Thanks for the feedback,
> Jakub
Jakub Sitnicki Nov. 26, 2019, 6:36 p.m. UTC | #5
On Tue, Nov 26, 2019 at 06:16 PM CET, Martin Lau wrote:
> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
>> On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
>> > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
>> > [ ... ]
>> >
>> >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
>> >>  			sk->sk_prot = psock->sk_proto;
>> >>  		psock->sk_proto = NULL;
>> >>  	}
>> >> +
>> >> +	if (psock->icsk_af_ops) {
>> >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
>> >> +		psock->icsk_af_ops = NULL;
>> >> +	}
>> >>  }
>> >
>> > [ ... ]
>> >
>> >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
>> >> +					  struct sk_buff *skb,
>> >> +					  struct request_sock *req,
>> >> +					  struct dst_entry *dst,
>> >> +					  struct request_sock *req_unhash,
>> >> +					  bool *own_req)
>> >> +{
>> >> +	const struct inet_connection_sock_af_ops *ops;
>> >> +	void (*write_space)(struct sock *sk);
>> >> +	struct sk_psock *psock;
>> >> +	struct proto *proto;
>> >> +	struct sock *child;
>> >> +
>> >> +	rcu_read_lock();
>> >> +	psock = sk_psock(sk);
>> >> +	if (likely(psock)) {
>> >> +		proto = psock->sk_proto;
>> >> +		write_space = psock->saved_write_space;
>> >> +		ops = psock->icsk_af_ops;
>> > It is not immediately clear to me what ensure
>> > ops is not NULL here.
>> >
>> > It is likely I missed something.  A short comment would
>> > be very useful here.
>>
>> I can see the readability problem. Looking at it now, perhaps it should
>> be rewritten, to the same effect, as:
>>
>> static struct sock *tcp_bpf_syn_recv_sock(...)
>> {
>> 	const struct inet_connection_sock_af_ops *ops = NULL;
>>         ...
>>
>>         rcu_read_lock();
>> 	psock = sk_psock(sk);
>> 	if (likely(psock)) {
>> 		proto = psock->sk_proto;
>> 		write_space = psock->saved_write_space;
>> 		ops = psock->icsk_af_ops;
>> 	}
>> 	rcu_read_unlock();
>>
>>         if (!ops)
>> 		ops = inet_csk(sk)->icsk_af_ops;
>>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
>>
>> If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
>> properly. To double check what happens here:
> I did not mean the init path.  The init path is fine since it init
> eveything on psock before publishing the sk to the sock_map.
>
> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
> psock->icsk_af_ops to NULL from running in parallel with
> tcp_bpf_syn_recv_sock()?  An explanation would be useful.

Ah, I misunderstood. Nothing prevents the race, AFAIK.

Setting psock->icsk_af_ops to null on restore and not checking for it
here was a bad move on my side.  Also I need to revisit what to do about
psock->sk_proto so the child socket doesn't end up with null sk_proto.

This race should be easy enough to trigger. Will give it a shot.

Thank you for bringing this up,
Jakub

>
>>
>> In sock_map_link we do a setup dance where we first create the psock and
>> later initialize the socket callbacks (tcp_bpf_init).
>>
>> static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
>> 			 struct sock *sk)
>> {
>>         ...
>> 	if (psock) {
>>                 ...
>> 	} else {
>> 		psock = sk_psock_init(sk, map->numa_node);
>> 		if (!psock) {
>> 			ret = -ENOMEM;
>> 			goto out_progs;
>> 		}
>> 		sk_psock_is_new = true;
>> 	}
>>         ...
>>         if (sk_psock_is_new) {
>> 		ret = tcp_bpf_init(sk);
>> 		if (ret < 0)
>> 			goto out_drop;
>> 	} else {
>> 		tcp_bpf_reinit(sk);
>> 	}
>>
>> The "if (sk_psock_new)" branch triggers the call chain that leads to
>> saving & overriding socket callbacks.
>>
>> tcp_bpf_init -> tcp_bpf_update_sk_prot -> sk_psock_update_proto
>>
>> Among them, icsk_af_ops.
>>
>> static inline void sk_psock_update_proto(...)
>> {
>>         ...
>> 	psock->icsk_af_ops = icsk->icsk_af_ops;
>> 	icsk->icsk_af_ops = af_ops;
>> }
>>
>> Goes without saying that a comment is needed.
>>
>> Thanks for the feedback,
>> Jakub
John Fastabend Nov. 26, 2019, 6:43 p.m. UTC | #6
Martin Lau wrote:
> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
> > On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
> > > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
> > > [ ... ]
> > >
> > >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
> > >>  			sk->sk_prot = psock->sk_proto;
> > >>  		psock->sk_proto = NULL;
> > >>  	}
> > >> +
> > >> +	if (psock->icsk_af_ops) {
> > >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
> > >> +		psock->icsk_af_ops = NULL;
> > >> +	}
> > >>  }
> > >
> > > [ ... ]
> > >
> > >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
> > >> +					  struct sk_buff *skb,
> > >> +					  struct request_sock *req,
> > >> +					  struct dst_entry *dst,
> > >> +					  struct request_sock *req_unhash,
> > >> +					  bool *own_req)
> > >> +{
> > >> +	const struct inet_connection_sock_af_ops *ops;
> > >> +	void (*write_space)(struct sock *sk);
> > >> +	struct sk_psock *psock;
> > >> +	struct proto *proto;
> > >> +	struct sock *child;
> > >> +
> > >> +	rcu_read_lock();
> > >> +	psock = sk_psock(sk);
> > >> +	if (likely(psock)) {
> > >> +		proto = psock->sk_proto;
> > >> +		write_space = psock->saved_write_space;
> > >> +		ops = psock->icsk_af_ops;
> > > It is not immediately clear to me what ensure
> > > ops is not NULL here.
> > >
> > > It is likely I missed something.  A short comment would
> > > be very useful here.
> > 
> > I can see the readability problem. Looking at it now, perhaps it should
> > be rewritten, to the same effect, as:
> > 
> > static struct sock *tcp_bpf_syn_recv_sock(...)
> > {
> > 	const struct inet_connection_sock_af_ops *ops = NULL;
> >         ...
> > 
> >     rcu_read_lock();
> > 	psock = sk_psock(sk);
> > 	if (likely(psock)) {
> > 		proto = psock->sk_proto;
> > 		write_space = psock->saved_write_space;
> > 		ops = psock->icsk_af_ops;
> > 	}
> > 	rcu_read_unlock();
> > 
> >         if (!ops)
> > 		ops = inet_csk(sk)->icsk_af_ops;
> >         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
> > 
> > If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
> > properly. To double check what happens here:
> I did not mean the init path.  The init path is fine since it init
> eveything on psock before publishing the sk to the sock_map.
> 
> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
> psock->icsk_af_ops to NULL from running in parallel with
> tcp_bpf_syn_recv_sock()?  An explanation would be useful.
> 

I'll answer. Updates are protected via sk_callback_lock so we don't have
parallel updates in-flight causing write_space and sk_proto to be out
of sync. However access should be OK because its a pointer write we
never update the pointer in place, e.g.

static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
+       struct inet_connection_sock *icsk = inet_csk(sk);
+
	sk->sk_write_space = psock->saved_write_space;

	if (psock->sk_proto) {
		struct inet_connection_sock *icsk = inet_csk(sk);
		bool has_ulp = !!icsk->icsk_ulp_data;

		if (has_ulp)
			tcp_update_ulp(sk, psock->sk_proto);
		else
			sk->sk_prot = psock->sk_proto;
		psock->sk_proto = NULL;
	}

+
+       if (psock->icsk_af_ops) {
+               icsk->icsk_af_ops = psock->icsk_af_ops;
+               psock->icsk_af_ops = NULL;
+       }
}

In restore case either psock->icsk_af_ops is null or not. If its
null below code catches it. If its not null (from init path) then
we have a valid pointer.

        rcu_read_lock();
	psock = sk_psock(sk);
 	if (likely(psock)) {
 		proto = psock->sk_proto;
 		write_space = psock->saved_write_space;
 		ops = psock->icsk_af_ops;
 	}
 	rcu_read_unlock();
 
        if (!ops)
		ops = inet_csk(sk)->icsk_af_ops;
        child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);


We should do this with proper READ_ONCE/WRITE_ONCE to make it clear
what is going on and to stop compiler from breaking these assumptions. I
was going to generate that patch after this series but can do it before
as well. I didn't mention it here because it seems a bit out of scope
for this series because its mostly a fix to older code.

Also I started to think that write_space might be out of sync with ops but
it seems we never actually remove psock_write_space until after
rcu grace period so that should be OK as well and always point to the
previous write_space.

Finally I wondered if we could remove the ops and then add it back
quickly which seems at least in theory possible, but that would get
hit with a grace period because we can't have conflicting psock
definitions on the same sock. So expanding the rcu block to include
the ops = inet_csk(sk)->icsk_af_ops would fix that case.

So in summary I think we should expand the rcu lock here to include the
ops = inet_csk(sk)->icsk_af_ops to ensure we dont race with tear
down and create. I'll push the necessary update with WRITE_ONCE and
READ_ONCE to fix that up. Seeing we have to wait until the merge
window opens most likely anyways I'll send those out sooner rather
than later and this series can add the proper annotations as well.
Jakub Sitnicki Nov. 27, 2019, 10:18 p.m. UTC | #7
On Tue, Nov 26, 2019 at 07:43 PM CET, John Fastabend wrote:
> Martin Lau wrote:
>> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
>> > On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
>> > > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
>> > > [ ... ]
>> > >
>> > >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
>> > >>  			sk->sk_prot = psock->sk_proto;
>> > >>  		psock->sk_proto = NULL;
>> > >>  	}
>> > >> +
>> > >> +	if (psock->icsk_af_ops) {
>> > >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
>> > >> +		psock->icsk_af_ops = NULL;
>> > >> +	}
>> > >>  }
>> > >
>> > > [ ... ]
>> > >
>> > >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
>> > >> +					  struct sk_buff *skb,
>> > >> +					  struct request_sock *req,
>> > >> +					  struct dst_entry *dst,
>> > >> +					  struct request_sock *req_unhash,
>> > >> +					  bool *own_req)
>> > >> +{
>> > >> +	const struct inet_connection_sock_af_ops *ops;
>> > >> +	void (*write_space)(struct sock *sk);
>> > >> +	struct sk_psock *psock;
>> > >> +	struct proto *proto;
>> > >> +	struct sock *child;
>> > >> +
>> > >> +	rcu_read_lock();
>> > >> +	psock = sk_psock(sk);
>> > >> +	if (likely(psock)) {
>> > >> +		proto = psock->sk_proto;
>> > >> +		write_space = psock->saved_write_space;
>> > >> +		ops = psock->icsk_af_ops;
>> > > It is not immediately clear to me what ensure
>> > > ops is not NULL here.
>> > >
>> > > It is likely I missed something.  A short comment would
>> > > be very useful here.
>> >
>> > I can see the readability problem. Looking at it now, perhaps it should
>> > be rewritten, to the same effect, as:
>> >
>> > static struct sock *tcp_bpf_syn_recv_sock(...)
>> > {
>> > 	const struct inet_connection_sock_af_ops *ops = NULL;
>> >         ...
>> >
>> >     rcu_read_lock();
>> > 	psock = sk_psock(sk);
>> > 	if (likely(psock)) {
>> > 		proto = psock->sk_proto;
>> > 		write_space = psock->saved_write_space;
>> > 		ops = psock->icsk_af_ops;
>> > 	}
>> > 	rcu_read_unlock();
>> >
>> >         if (!ops)
>> > 		ops = inet_csk(sk)->icsk_af_ops;
>> >         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
>> >
>> > If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
>> > properly. To double check what happens here:
>> I did not mean the init path.  The init path is fine since it init
>> eveything on psock before publishing the sk to the sock_map.
>>
>> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
>> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
>> psock->icsk_af_ops to NULL from running in parallel with
>> tcp_bpf_syn_recv_sock()?  An explanation would be useful.
>>
>
> I'll answer. Updates are protected via sk_callback_lock so we don't have
> parrallel updates in-flight causing write_space and sk_proto to be out
> of sync. However access should be OK because its a pointer write we
> never update the pointer in place, e.g.
>
> static inline void sk_psock_restore_proto(struct sock *sk,
> 					  struct sk_psock *psock)
> {
> +       struct inet_connection_sock *icsk = inet_csk(sk);
> +
> 	sk->sk_write_space = psock->saved_write_space;
>
> 	if (psock->sk_proto) {
> 		struct inet_connection_sock *icsk = inet_csk(sk);
> 		bool has_ulp = !!icsk->icsk_ulp_data;
>
> 		if (has_ulp)
> 			tcp_update_ulp(sk, psock->sk_proto);
> 		else
> 			sk->sk_prot = psock->sk_proto;
> 		psock->sk_proto = NULL;
> 	}
>
> +
> +       if (psock->icsk_af_ops) {
> +               icsk->icsk_af_ops = psock->icsk_af_ops;
> +               psock->icsk_af_ops = NULL;
> +       }
> }
>
> In restore case either psock->icsk_af_ops is null or not. If its
> null below code catches it. If its not null (from init path) then
> we have a valid pointer.
>
>         rcu_read_lock();
> 	psock = sk_psock(sk);
>  	if (likely(psock)) {
>  		proto = psock->sk_proto;
>  		write_space = psock->saved_write_space;
>  		ops = psock->icsk_af_ops;
>  	}
>  	rcu_read_unlock();
>
>         if (!ops)
> 		ops = inet_csk(sk)->icsk_af_ops;
>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
>
>
> We should do this with proper READ_ONCE/WRITE_ONCE to make it clear
> what is going on and to stop compiler from breaking these assumptions. I
> was going to generate that patch after this series but can do it before
> as well. I didn't mention it here because it seems a bit out of scope
> for this series because its mostly a fix to older code.

+1, looking forward to your patch. Also, as I've recently learned, that
should enable KTSAN to reason about the psock code [0].

> Also I started to think that write_space might be out of sync with ops but
> it seems we never actually remove psock_write_space until after
> rcu grace period so that should be OK as well and always point to the
> previous write_space.
>
> Finally I wondered if we could remove the ops and then add it back
> quickly which seems at least in theory possible, but that would get
> hit with a grace period because we can't have conflicting psock
> definitions on the same sock. So expanding the rcu block to include
> the ops = inet_csk(sk)->icsk_af_ops would fix that case.

I see, ops = inet_csk(sk)->icsk_af_ops might read out a re-overwritten
ops after sock_map_unlink, followed by sock_map_link. Ouch.

> So in summary I think we should expand the rcu lock here to include the
> ops = inet_csk(sk)->icsk_af_ops to ensure we dont race with tear
> down and create. I'll push the necessary update with WRITE_ONCE and
> READ_ONCE to fix that up. Seeing we have to wait until the merge
> window opens most likely anyways I'll send those out sooner rather
> then later and this series can add the proper annotations as well.

Or I could leave psock->icsk_af_ops set in restore_proto, like we do for
write_space as you noted. Restoring it twice doesn't seem harmful, it
has no side-effects. Fewer state changes to think about?

I'll still have to apply what you suggest for saving psock->sk_proto,
though.

Thanks,
Jakub

[0] https://github.com/google/ktsan/wiki/READ_ONCE-and-WRITE_ONCE
Martin KaFai Lau Dec. 11, 2019, 5:20 p.m. UTC | #8
On Tue, Dec 10, 2019 at 03:45:37PM +0100, Jakub Sitnicki wrote:
> John, Martin,
> 
> On Tue, Nov 26, 2019 at 07:36 PM CET, Jakub Sitnicki wrote:
> > On Tue, Nov 26, 2019 at 06:16 PM CET, Martin Lau wrote:
> >> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
> >>> On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
> >>> > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
> >>> > [ ... ]
> >>> >
> >>> >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
> >>> >>  			sk->sk_prot = psock->sk_proto;
> >>> >>  		psock->sk_proto = NULL;
> >>> >>  	}
> >>> >> +
> >>> >> +	if (psock->icsk_af_ops) {
> >>> >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
> >>> >> +		psock->icsk_af_ops = NULL;
> >>> >> +	}
> >>> >>  }
> >>> >
> >>> > [ ... ]
> >>> >
> >>> >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
> >>> >> +					  struct sk_buff *skb,
> >>> >> +					  struct request_sock *req,
> >>> >> +					  struct dst_entry *dst,
> >>> >> +					  struct request_sock *req_unhash,
> >>> >> +					  bool *own_req)
> >>> >> +{
> >>> >> +	const struct inet_connection_sock_af_ops *ops;
> >>> >> +	void (*write_space)(struct sock *sk);
> >>> >> +	struct sk_psock *psock;
> >>> >> +	struct proto *proto;
> >>> >> +	struct sock *child;
> >>> >> +
> >>> >> +	rcu_read_lock();
> >>> >> +	psock = sk_psock(sk);
> >>> >> +	if (likely(psock)) {
> >>> >> +		proto = psock->sk_proto;
> >>> >> +		write_space = psock->saved_write_space;
> >>> >> +		ops = psock->icsk_af_ops;
> >>> > It is not immediately clear to me what ensure
> >>> > ops is not NULL here.
> >>> >
> >>> > It is likely I missed something.  A short comment would
> >>> > be very useful here.
> >>>
> >>> I can see the readability problem. Looking at it now, perhaps it should
> >>> be rewritten, to the same effect, as:
> >>>
> >>> static struct sock *tcp_bpf_syn_recv_sock(...)
> >>> {
> >>> 	const struct inet_connection_sock_af_ops *ops = NULL;
> >>>         ...
> >>>
> >>>         rcu_read_lock();
> >>> 	psock = sk_psock(sk);
> >>> 	if (likely(psock)) {
> >>> 		proto = psock->sk_proto;
> >>> 		write_space = psock->saved_write_space;
> >>> 		ops = psock->icsk_af_ops;
> >>> 	}
> >>> 	rcu_read_unlock();
> >>>
> >>>         if (!ops)
> >>> 		ops = inet_csk(sk)->icsk_af_ops;
> >>>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
> >>>
> >>> If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
> >>> properly. To double check what happens here:
> >> I did not mean the init path.  The init path is fine since it init
> >> eveything on psock before publishing the sk to the sock_map.
> >>
> >> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
> >> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
> >> psock->icsk_af_ops to NULL from running in parallel with
> >> tcp_bpf_syn_recv_sock()?  An explanation would be useful.
> >
> > Ah, I misunderstood. Nothing prevents the race, AFAIK.
> >
> > Setting psock->icsk_af_ops to null on restore and not checking for it
> > here was a bad move on my side.  Also I need to revisit what to do about
> > psock->sk_proto so the child socket doesn't end up with null sk_proto.
> >
> > This race should be easy enough to trigger. Will give it a shot.
> 
> I've convinced myself that this approach is racy beyond repair.
> 
> Once syn_recv_sock() has returned it is too late to reset the child
> sk_user_data and restore its callbacks. It has been already inserted
> into ehash and ingress path can invoke its callbacks.
> 
> The race can be triggered with with a reproducer where:
> 
> thread-1:
> 
>         p = accept(s, ...);
>         close(p);
> 
> thread-2:
> 
> 	bpf_map_update_elem(mapfd, &key, &s, BPF_NOEXIST);
> 	bpf_map_delete_elem(mapfd, &key);
> 
> This a dead-end because we can't have the parent and the child share the
> psock state. Even though psock itself is refcounted, and potentially we
> could grab a reference before cloning the parent, link into the map that
> psock holds is not.
> 
> Two ways out come to mind. Both involve touching TCP code, which I was
> hoping to avoid:
> 
> 1) reset sk_user_data when initializing the child
> 
>    This is problematic because tcp_bpf callbacks are not designed to
>    handle sockets with no psock _and_ with overridden sk_prot
>    callbacks. (Although, I think they could if the fallback was directly
>    on {tcp,tcpv6}_prot based on socket domain.)
> 
>    Also, there are other sk_user_data users like DRBD which rely on
>    sharing the sk_user_data pointer between parent and child, if I read
>    the code correctly [0]. If anything, clearing the sk_user_data on
>    clone would have to be guarded by a flag.
Can the copy/not-to-copy sk_user_data decision be made in
sk_clone_lock()?

> 
> 2) Restore sk_prot callbacks on clone to {tcp,tcpv6}_prot
> 
>    The simpler way out. tcp_bpf callbacks never get invoked on the child
>    socket so the copied psock reference is no longer a problem. We can
>    clear the pointer on accept().
> 
>    So far I wasn't able poke any holes in it and it comes down to
>    patching tcp_create_openreq_child() with:
> 
> 	/* sk_msg and ULP frameworks can override the callbacks into
> 	 * protocol. We don't assume they are intended to be inherited
> 	 * by the child. Frameworks can re-install the callbacks on
> 	 * accept() if needed.
> 	 */
> 	WRITE_ONCE(newsk->sk_prot, sk->sk_prot_creator);
> 
>    That's what I'm going with for v2.
> 
> Open to suggestions.
> 
> Thanks,
> Jakub
> 
> BTW. Reading into kTLS code, I noticed it has been limited down to just
> established sockets due to the same problem I'm struggling with here:
> 
> static int tls_init(struct sock *sk)
> {
> ...
> 	/* The TLS ulp is currently supported only for TCP sockets
> 	 * in ESTABLISHED state.
> 	 * Supporting sockets in LISTEN state will require us
> 	 * to modify the accept implementation to clone rather then
> 	 * share the ulp context.
> 	 */
> 	if (sk->sk_state != TCP_ESTABLISHED)
> 		return -ENOTCONN;
> 
> [0] https://urldefense.proofpoint.com/v2/url?u=https-3A__elixir.bootlin.com_linux_v5.5-2Drc1_source_drivers_block_drbd_drbd-5Freceiver.c-23L682&d=DwIBAg&c=5VD0RTtNlTh3ycd41b3MUw&r=VQnoQ7LvghIj0gVEaiQSUw&m=z2Cz1gEcqiw-8YqVOluxlUHh_CBs6PJWQN2vgirOyFk&s=WAiM0asZN0OkqrW02xm2mCMIzWhKQCc3KiY7pzMKNg4&e=
Jakub Sitnicki Dec. 12, 2019, 11:27 a.m. UTC | #9
On Wed, Dec 11, 2019 at 06:20 PM CET, Martin Lau wrote:
> On Tue, Dec 10, 2019 at 03:45:37PM +0100, Jakub Sitnicki wrote:
>> John, Martin,
>>
>> On Tue, Nov 26, 2019 at 07:36 PM CET, Jakub Sitnicki wrote:
>> > On Tue, Nov 26, 2019 at 06:16 PM CET, Martin Lau wrote:
>> >> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
>> >>> On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
>> >>> > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
>> >>> > [ ... ]
>> >>> >
>> >>> >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
>> >>> >>  			sk->sk_prot = psock->sk_proto;
>> >>> >>  		psock->sk_proto = NULL;
>> >>> >>  	}
>> >>> >> +
>> >>> >> +	if (psock->icsk_af_ops) {
>> >>> >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
>> >>> >> +		psock->icsk_af_ops = NULL;
>> >>> >> +	}
>> >>> >>  }
>> >>> >
>> >>> > [ ... ]
>> >>> >
>> >>> >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
>> >>> >> +					  struct sk_buff *skb,
>> >>> >> +					  struct request_sock *req,
>> >>> >> +					  struct dst_entry *dst,
>> >>> >> +					  struct request_sock *req_unhash,
>> >>> >> +					  bool *own_req)
>> >>> >> +{
>> >>> >> +	const struct inet_connection_sock_af_ops *ops;
>> >>> >> +	void (*write_space)(struct sock *sk);
>> >>> >> +	struct sk_psock *psock;
>> >>> >> +	struct proto *proto;
>> >>> >> +	struct sock *child;
>> >>> >> +
>> >>> >> +	rcu_read_lock();
>> >>> >> +	psock = sk_psock(sk);
>> >>> >> +	if (likely(psock)) {
>> >>> >> +		proto = psock->sk_proto;
>> >>> >> +		write_space = psock->saved_write_space;
>> >>> >> +		ops = psock->icsk_af_ops;
>> >>> > It is not immediately clear to me what ensure
>> >>> > ops is not NULL here.
>> >>> >
>> >>> > It is likely I missed something.  A short comment would
>> >>> > be very useful here.
>> >>>
>> >>> I can see the readability problem. Looking at it now, perhaps it should
>> >>> be rewritten, to the same effect, as:
>> >>>
>> >>> static struct sock *tcp_bpf_syn_recv_sock(...)
>> >>> {
>> >>> 	const struct inet_connection_sock_af_ops *ops = NULL;
>> >>>         ...
>> >>>
>> >>>         rcu_read_lock();
>> >>> 	psock = sk_psock(sk);
>> >>> 	if (likely(psock)) {
>> >>> 		proto = psock->sk_proto;
>> >>> 		write_space = psock->saved_write_space;
>> >>> 		ops = psock->icsk_af_ops;
>> >>> 	}
>> >>> 	rcu_read_unlock();
>> >>>
>> >>>         if (!ops)
>> >>> 		ops = inet_csk(sk)->icsk_af_ops;
>> >>>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
>> >>>
>> >>> If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
>> >>> properly. To double check what happens here:
>> >> I did not mean the init path.  The init path is fine since it init
>> >> eveything on psock before publishing the sk to the sock_map.
>> >>
>> >> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
>> >> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
>> >> psock->icsk_af_ops to NULL from running in parallel with
>> >> tcp_bpf_syn_recv_sock()?  An explanation would be useful.
>> >
>> > Ah, I misunderstood. Nothing prevents the race, AFAIK.
>> >
>> > Setting psock->icsk_af_ops to null on restore and not checking for it
>> > here was a bad move on my side.  Also I need to revisit what to do about
>> > psock->sk_proto so the child socket doesn't end up with null sk_proto.
>> >
>> > This race should be easy enough to trigger. Will give it a shot.
>>
>> I've convinced myself that this approach is racy beyond repair.
>>
>> Once syn_recv_sock() has returned it is too late to reset the child
>> sk_user_data and restore its callbacks. It has been already inserted
>> into ehash and ingress path can invoke its callbacks.
>>
>> The race can be triggered with with a reproducer where:
>>
>> thread-1:
>>
>>         p = accept(s, ...);
>>         close(p);
>>
>> thread-2:
>>
>> 	bpf_map_update_elem(mapfd, &key, &s, BPF_NOEXIST);
>> 	bpf_map_delete_elem(mapfd, &key);
>>
>> This a dead-end because we can't have the parent and the child share the
>> psock state. Even though psock itself is refcounted, and potentially we
>> could grab a reference before cloning the parent, link into the map that
>> psock holds is not.
>>
>> Two ways out come to mind. Both involve touching TCP code, which I was
>> hoping to avoid:
>>
>> 1) reset sk_user_data when initializing the child
>>
>>    This is problematic because tcp_bpf callbacks are not designed to
>>    handle sockets with no psock _and_ with overridden sk_prot
>>    callbacks. (Although, I think they could if the fallback was directly
>>    on {tcp,tcpv6}_prot based on socket domain.)
>>
>>    Also, there are other sk_user_data users like DRBD which rely on
>>    sharing the sk_user_data pointer between parent and child, if I read
>>    the code correctly [0]. If anything, clearing the sk_user_data on
>>    clone would have to be guarded by a flag.
> Can the copy/not-to-copy sk_user_data decision be made in
> sk_clone_lock()?

Yes, this could be pushed down to sk_clone_lock(), where we do similar
work (reset sk_reuseport_cb and clone bpf_sk_storage):

	/* User data can hold reference. Child must not
	 * inherit the pointer without acquiring a reference.
	 */
	if (sock_flag(sk, SOCK_OWNS_USER_DATA)) {
		sock_reset_flag(newsk, SOCK_OWNS_USER_DATA);
		RCU_INIT_POINTER(newsk->sk_user_data, NULL);
	}

I believe this would still need to be guarded by a flag.  Do you see
value in clearing child sk_user_data on clone as opposed to delaying
that work until accept() time?

-Jakub

>
>>
>> 2) Restore sk_prot callbacks on clone to {tcp,tcpv6}_prot
>>
>>    The simpler way out. tcp_bpf callbacks never get invoked on the child
>>    socket so the copied psock reference is no longer a problem. We can
>>    clear the pointer on accept().
>>
>>    So far I wasn't able poke any holes in it and it comes down to
>>    patching tcp_create_openreq_child() with:
>>
>> 	/* sk_msg and ULP frameworks can override the callbacks into
>> 	 * protocol. We don't assume they are intended to be inherited
>> 	 * by the child. Frameworks can re-install the callbacks on
>> 	 * accept() if needed.
>> 	 */
>> 	WRITE_ONCE(newsk->sk_prot, sk->sk_prot_creator);
>>
>>    That's what I'm going with for v2.
>>
>> Open to suggestions.
>>
>> Thanks,
>> Jakub
>>
>> BTW. Reading into kTLS code, I noticed it has been limited down to just
>> established sockets due to the same problem I'm struggling with here:
>>
>> static int tls_init(struct sock *sk)
>> {
>> ...
>> 	/* The TLS ulp is currently supported only for TCP sockets
>> 	 * in ESTABLISHED state.
>> 	 * Supporting sockets in LISTEN state will require us
>> 	 * to modify the accept implementation to clone rather then
>> 	 * share the ulp context.
>> 	 */
>> 	if (sk->sk_state != TCP_ESTABLISHED)
>> 		return -ENOTCONN;
>>
>> [0] https://urldefense.proofpoint.com/v2/url?u=https-3A__elixir.bootlin.com_linux_v5.5-2Drc1_source_drivers_block_drbd_drbd-5Freceiver.c-23L682&d=DwIBAg&c=5VD0RTtNlTh3ycd41b3MUw&r=VQnoQ7LvghIj0gVEaiQSUw&m=z2Cz1gEcqiw-8YqVOluxlUHh_CBs6PJWQN2vgirOyFk&s=WAiM0asZN0OkqrW02xm2mCMIzWhKQCc3KiY7pzMKNg4&e=
Martin KaFai Lau Dec. 12, 2019, 7:23 p.m. UTC | #10
On Thu, Dec 12, 2019 at 12:27:19PM +0100, Jakub Sitnicki wrote:
> On Wed, Dec 11, 2019 at 06:20 PM CET, Martin Lau wrote:
> > On Tue, Dec 10, 2019 at 03:45:37PM +0100, Jakub Sitnicki wrote:
> >> John, Martin,
> >>
> >> On Tue, Nov 26, 2019 at 07:36 PM CET, Jakub Sitnicki wrote:
> >> > On Tue, Nov 26, 2019 at 06:16 PM CET, Martin Lau wrote:
> >> >> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
> >> >>> On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
> >> >>> > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
> >> >>> > [ ... ]
> >> >>> >
> >> >>> >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
> >> >>> >>  			sk->sk_prot = psock->sk_proto;
> >> >>> >>  		psock->sk_proto = NULL;
> >> >>> >>  	}
> >> >>> >> +
> >> >>> >> +	if (psock->icsk_af_ops) {
> >> >>> >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
> >> >>> >> +		psock->icsk_af_ops = NULL;
> >> >>> >> +	}
> >> >>> >>  }
> >> >>> >
> >> >>> > [ ... ]
> >> >>> >
> >> >>> >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
> >> >>> >> +					  struct sk_buff *skb,
> >> >>> >> +					  struct request_sock *req,
> >> >>> >> +					  struct dst_entry *dst,
> >> >>> >> +					  struct request_sock *req_unhash,
> >> >>> >> +					  bool *own_req)
> >> >>> >> +{
> >> >>> >> +	const struct inet_connection_sock_af_ops *ops;
> >> >>> >> +	void (*write_space)(struct sock *sk);
> >> >>> >> +	struct sk_psock *psock;
> >> >>> >> +	struct proto *proto;
> >> >>> >> +	struct sock *child;
> >> >>> >> +
> >> >>> >> +	rcu_read_lock();
> >> >>> >> +	psock = sk_psock(sk);
> >> >>> >> +	if (likely(psock)) {
> >> >>> >> +		proto = psock->sk_proto;
> >> >>> >> +		write_space = psock->saved_write_space;
> >> >>> >> +		ops = psock->icsk_af_ops;
> >> >>> > It is not immediately clear to me what ensure
> >> >>> > ops is not NULL here.
> >> >>> >
> >> >>> > It is likely I missed something.  A short comment would
> >> >>> > be very useful here.
> >> >>>
> >> >>> I can see the readability problem. Looking at it now, perhaps it should
> >> >>> be rewritten, to the same effect, as:
> >> >>>
> >> >>> static struct sock *tcp_bpf_syn_recv_sock(...)
> >> >>> {
> >> >>> 	const struct inet_connection_sock_af_ops *ops = NULL;
> >> >>>         ...
> >> >>>
> >> >>>         rcu_read_lock();
> >> >>> 	psock = sk_psock(sk);
> >> >>> 	if (likely(psock)) {
> >> >>> 		proto = psock->sk_proto;
> >> >>> 		write_space = psock->saved_write_space;
> >> >>> 		ops = psock->icsk_af_ops;
> >> >>> 	}
> >> >>> 	rcu_read_unlock();
> >> >>>
> >> >>>         if (!ops)
> >> >>> 		ops = inet_csk(sk)->icsk_af_ops;
> >> >>>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
> >> >>>
> >> >>> If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
> >> >>> properly. To double check what happens here:
> >> >> I did not mean the init path.  The init path is fine since it init
> >> >> eveything on psock before publishing the sk to the sock_map.
> >> >>
> >> >> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
> >> >> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
> >> >> psock->icsk_af_ops to NULL from running in parallel with
> >> >> tcp_bpf_syn_recv_sock()?  An explanation would be useful.
> >> >
> >> > Ah, I misunderstood. Nothing prevents the race, AFAIK.
> >> >
> >> > Setting psock->icsk_af_ops to null on restore and not checking for it
> >> > here was a bad move on my side.  Also I need to revisit what to do about
> >> > psock->sk_proto so the child socket doesn't end up with null sk_proto.
> >> >
> >> > This race should be easy enough to trigger. Will give it a shot.
> >>
> >> I've convinced myself that this approach is racy beyond repair.
> >>
> >> Once syn_recv_sock() has returned it is too late to reset the child
> >> sk_user_data and restore its callbacks. It has been already inserted
> >> into ehash and ingress path can invoke its callbacks.
> >>
> >> The race can be triggered with with a reproducer where:
> >>
> >> thread-1:
> >>
> >>         p = accept(s, ...);
> >>         close(p);
> >>
> >> thread-2:
> >>
> >> 	bpf_map_update_elem(mapfd, &key, &s, BPF_NOEXIST);
> >> 	bpf_map_delete_elem(mapfd, &key);
> >>
> >> This a dead-end because we can't have the parent and the child share the
> >> psock state. Even though psock itself is refcounted, and potentially we
> >> could grab a reference before cloning the parent, link into the map that
> >> psock holds is not.
> >>
> >> Two ways out come to mind. Both involve touching TCP code, which I was
> >> hoping to avoid:
> >>
> >> 1) reset sk_user_data when initializing the child
> >>
> >>    This is problematic because tcp_bpf callbacks are not designed to
> >>    handle sockets with no psock _and_ with overridden sk_prot
> >>    callbacks. (Although, I think they could if the fallback was directly
> >>    on {tcp,tcpv6}_prot based on socket domain.)
> >>
> >>    Also, there are other sk_user_data users like DRBD which rely on
> >>    sharing the sk_user_data pointer between parent and child, if I read
> >>    the code correctly [0]. If anything, clearing the sk_user_data on
> >>    clone would have to be guarded by a flag.
> > Can the copy/not-to-copy sk_user_data decision be made in
> > sk_clone_lock()?
> 
> Yes, this could be pushed down to sk_clone_lock(), where we do similar
> work (reset sk_reuseport_cb and clone bpf_sk_storage):
aha.  I missed your earlier "clearing the sk_user_data on clone would have
to be guarded by a flag..." part.  It turns out we were talking the same
thing on (1).  sock_flag works better if there is still a bit left (and it
seems there is one),  although I was thinking more like adding
something (e.g. a func ptr) to 'struct proto' to mangle sk_user_data
before returning newsk....but not sure this kind of logic
belongs to 'struct proto'

> 
> 	/* User data can hold reference. Child must not
> 	 * inherit the pointer without acquiring a reference.
> 	 */
> 	if (sock_flag(sk, SOCK_OWNS_USER_DATA)) {
> 		sock_reset_flag(newsk, SOCK_OWNS_USER_DATA);
> 		RCU_INIT_POINTER(newsk->sk_user_data, NULL);
> 	}
> 
> I belive this would still need to be guarded by a flag.  Do you see
> value in clearing child sk_user_data on clone as opposed to dealying
> that work until accept() time?
It seems to me clearing things up front at the very beginning is more
straightforward, such that we do not have to worry about the
sk_user_data being used in a wrong way before it gets a chance
to be cleared in accept().

Just something to consider, if it is obvious that there is no hole in
clearing it in accept(), it is fine too.

> >>
> >> 2) Restore sk_prot callbacks on clone to {tcp,tcpv6}_prot
> >>
> >>    The simpler way out. tcp_bpf callbacks never get invoked on the child
> >>    socket so the copied psock reference is no longer a problem. We can
> >>    clear the pointer on accept().
> >>
> >>    So far I wasn't able poke any holes in it and it comes down to
> >>    patching tcp_create_openreq_child() with:
> >>
> >> 	/* sk_msg and ULP frameworks can override the callbacks into
> >> 	 * protocol. We don't assume they are intended to be inherited
> >> 	 * by the child. Frameworks can re-install the callbacks on
> >> 	 * accept() if needed.
> >> 	 */
> >> 	WRITE_ONCE(newsk->sk_prot, sk->sk_prot_creator);
> >>
> >>    That's what I'm going with for v2.
> >>
> >> Open to suggestions.
> >>
> >> Thanks,
> >> Jakub
> >>
> >> BTW. Reading into kTLS code, I noticed it has been limited down to just
> >> established sockets due to the same problem I'm struggling with here:
> >>
> >> static int tls_init(struct sock *sk)
> >> {
> >> ...
> >> 	/* The TLS ulp is currently supported only for TCP sockets
> >> 	 * in ESTABLISHED state.
> >> 	 * Supporting sockets in LISTEN state will require us
> >> 	 * to modify the accept implementation to clone rather then
> >> 	 * share the ulp context.
> >> 	 */
> >> 	if (sk->sk_state != TCP_ESTABLISHED)
> >> 		return -ENOTCONN;
> >>
> >> [0] https://urldefense.proofpoint.com/v2/url?u=https-3A__elixir.bootlin.com_linux_v5.5-2Drc1_source_drivers_block_drbd_drbd-5Freceiver.c-23L682&d=DwIBAg&c=5VD0RTtNlTh3ycd41b3MUw&r=VQnoQ7LvghIj0gVEaiQSUw&m=z2Cz1gEcqiw-8YqVOluxlUHh_CBs6PJWQN2vgirOyFk&s=WAiM0asZN0OkqrW02xm2mCMIzWhKQCc3KiY7pzMKNg4&e=
Jakub Sitnicki Dec. 17, 2019, 3:06 p.m. UTC | #11
On Thu, Dec 12, 2019 at 08:23 PM CET, Martin Lau wrote:
> On Thu, Dec 12, 2019 at 12:27:19PM +0100, Jakub Sitnicki wrote:
>> On Wed, Dec 11, 2019 at 06:20 PM CET, Martin Lau wrote:
>> > On Tue, Dec 10, 2019 at 03:45:37PM +0100, Jakub Sitnicki wrote:
>> >> John, Martin,
>> >>
>> >> On Tue, Nov 26, 2019 at 07:36 PM CET, Jakub Sitnicki wrote:
>> >> > On Tue, Nov 26, 2019 at 06:16 PM CET, Martin Lau wrote:
>> >> >> On Tue, Nov 26, 2019 at 04:54:33PM +0100, Jakub Sitnicki wrote:
>> >> >>> On Mon, Nov 25, 2019 at 11:38 PM CET, Martin Lau wrote:
>> >> >>> > On Sat, Nov 23, 2019 at 12:07:47PM +0100, Jakub Sitnicki wrote:
>> >> >>> > [ ... ]
>> >> >>> >
>> >> >>> >> @@ -370,6 +378,11 @@ static inline void sk_psock_restore_proto(struct sock *sk,
>> >> >>> >>  			sk->sk_prot = psock->sk_proto;
>> >> >>> >>  		psock->sk_proto = NULL;
>> >> >>> >>  	}
>> >> >>> >> +
>> >> >>> >> +	if (psock->icsk_af_ops) {
>> >> >>> >> +		icsk->icsk_af_ops = psock->icsk_af_ops;
>> >> >>> >> +		psock->icsk_af_ops = NULL;
>> >> >>> >> +	}
>> >> >>> >>  }
>> >> >>> >
>> >> >>> > [ ... ]
>> >> >>> >
>> >> >>> >> +static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
>> >> >>> >> +					  struct sk_buff *skb,
>> >> >>> >> +					  struct request_sock *req,
>> >> >>> >> +					  struct dst_entry *dst,
>> >> >>> >> +					  struct request_sock *req_unhash,
>> >> >>> >> +					  bool *own_req)
>> >> >>> >> +{
>> >> >>> >> +	const struct inet_connection_sock_af_ops *ops;
>> >> >>> >> +	void (*write_space)(struct sock *sk);
>> >> >>> >> +	struct sk_psock *psock;
>> >> >>> >> +	struct proto *proto;
>> >> >>> >> +	struct sock *child;
>> >> >>> >> +
>> >> >>> >> +	rcu_read_lock();
>> >> >>> >> +	psock = sk_psock(sk);
>> >> >>> >> +	if (likely(psock)) {
>> >> >>> >> +		proto = psock->sk_proto;
>> >> >>> >> +		write_space = psock->saved_write_space;
>> >> >>> >> +		ops = psock->icsk_af_ops;
>> >> >>> > It is not immediately clear to me what ensure
>> >> >>> > ops is not NULL here.
>> >> >>> >
>> >> >>> > It is likely I missed something.  A short comment would
>> >> >>> > be very useful here.
>> >> >>>
>> >> >>> I can see the readability problem. Looking at it now, perhaps it should
>> >> >>> be rewritten, to the same effect, as:
>> >> >>>
>> >> >>> static struct sock *tcp_bpf_syn_recv_sock(...)
>> >> >>> {
>> >> >>> 	const struct inet_connection_sock_af_ops *ops = NULL;
>> >> >>>         ...
>> >> >>>
>> >> >>>         rcu_read_lock();
>> >> >>> 	psock = sk_psock(sk);
>> >> >>> 	if (likely(psock)) {
>> >> >>> 		proto = psock->sk_proto;
>> >> >>> 		write_space = psock->saved_write_space;
>> >> >>> 		ops = psock->icsk_af_ops;
>> >> >>> 	}
>> >> >>> 	rcu_read_unlock();
>> >> >>>
>> >> >>>         if (!ops)
>> >> >>> 		ops = inet_csk(sk)->icsk_af_ops;
>> >> >>>         child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
>> >> >>>
>> >> >>> If psock->icsk_af_ops were NULL, it would mean we haven't initialized it
>> >> >>> properly. To double check what happens here:
>> >> >> I did not mean the init path.  The init path is fine since it init
>> >> >> eveything on psock before publishing the sk to the sock_map.
>> >> >>
>> >> >> I was thinking the delete path (e.g. sock_map_delete_elem).  It is not clear
>> >> >> to me what prevent the earlier pasted sk_psock_restore_proto() which sets
>> >> >> psock->icsk_af_ops to NULL from running in parallel with
>> >> >> tcp_bpf_syn_recv_sock()?  An explanation would be useful.
>> >> >
>> >> > Ah, I misunderstood. Nothing prevents the race, AFAIK.
>> >> >
>> >> > Setting psock->icsk_af_ops to null on restore and not checking for it
>> >> > here was a bad move on my side.  Also I need to revisit what to do about
>> >> > psock->sk_proto so the child socket doesn't end up with null sk_proto.
>> >> >
>> >> > This race should be easy enough to trigger. Will give it a shot.
>> >>
>> >> I've convinced myself that this approach is racy beyond repair.
>> >>
>> >> Once syn_recv_sock() has returned it is too late to reset the child
>> >> sk_user_data and restore its callbacks. It has been already inserted
>> >> into ehash and ingress path can invoke its callbacks.
>> >>
>> >> The race can be triggered with with a reproducer where:
>> >>
>> >> thread-1:
>> >>
>> >>         p = accept(s, ...);
>> >>         close(p);
>> >>
>> >> thread-2:
>> >>
>> >> 	bpf_map_update_elem(mapfd, &key, &s, BPF_NOEXIST);
>> >> 	bpf_map_delete_elem(mapfd, &key);
>> >>
>> >> This a dead-end because we can't have the parent and the child share the
>> >> psock state. Even though psock itself is refcounted, and potentially we
>> >> could grab a reference before cloning the parent, link into the map that
>> >> psock holds is not.
>> >>
>> >> Two ways out come to mind. Both involve touching TCP code, which I was
>> >> hoping to avoid:
>> >>
>> >> 1) reset sk_user_data when initializing the child
>> >>
>> >>    This is problematic because tcp_bpf callbacks are not designed to
>> >>    handle sockets with no psock _and_ with overridden sk_prot
>> >>    callbacks. (Although, I think they could if the fallback was directly
>> >>    on {tcp,tcpv6}_prot based on socket domain.)
>> >>
>> >>    Also, there are other sk_user_data users like DRBD which rely on
>> >>    sharing the sk_user_data pointer between parent and child, if I read
>> >>    the code correctly [0]. If anything, clearing the sk_user_data on
>> >>    clone would have to be guarded by a flag.
>> > Can the copy/not-to-copy sk_user_data decision be made in
>> > sk_clone_lock()?
>>
>> Yes, this could be pushed down to sk_clone_lock(), where we do similar
>> work (reset sk_reuseport_cb and clone bpf_sk_storage):
> aha.  I missed your eariler "clearing the sk_user_data on clone would have
> to be guarded by a flag..." part.  It turns out we were talking the same
> thing on (1).  sock_flag works better if there is still bit left (and it
> seems there is one),  although I was thinking more like adding
> something (e.g. a func ptr) to 'struct proto' to mangle sk_user_data
> before returning newsk....but not sure this kind of logic
> belongs to 'struct proto'

Sorry for late reply.

We have 4 bits left by my count. The multi-line comment for SOCK_NOFCS
is getting in the way of counting them line-for-bit.

A callback invoked on socket clone is something I was considering too.
I'm not sure either where it belongs. At risk of being too use-case
specific, perhaps it could live together with sk_user_data and sk_prot,
which it would mangle on sk_clone_lock():

struct sock {
        ...
	void			*sk_user_data;
	void			(*sk_clone)(struct sock *sk,
					    struct sock *newsk);
        ...
}

But, I feel adding a new sock field just for this wouldn't be justified.
I can get by with a sock flag. Unless we have other uses for it?

>
>>
>> 	/* User data can hold reference. Child must not
>> 	 * inherit the pointer without acquiring a reference.
>> 	 */
>> 	if (sock_flag(sk, SOCK_OWNS_USER_DATA)) {
>> 		sock_reset_flag(newsk, SOCK_OWNS_USER_DATA);
>> 		RCU_INIT_POINTER(newsk->sk_user_data, NULL);
>> 	}
>>
>> I belive this would still need to be guarded by a flag.  Do you see
>> value in clearing child sk_user_data on clone as opposed to dealying
>> that work until accept() time?
> It seems to me clearing things up front at the very beginning is more
> straight forward, such that it does not have to worry about the
> sk_user_data may be used in a wrong way before it gets a chance
> to be cleared in accept().
>
> Just something to consider, if it is obvious that there is no hole in
> clearing it in accept(), it is fine too.

Just when I thought I could get away with lazily clearing the
sk_user_data at accept() time, it occurred to me that it is not enough.

The listening socket could get deleted from sockmap before a child socket
that inherited a copy of the sk_user_data pointer gets accept()'ed. In such
a scenario the pointer would not get NULL'ed on accept(), because the
listening socket would have its sk_prot->accept restored by then.

I will need that flag after all...

-jkbs

>
>> >>
>> >> 2) Restore sk_prot callbacks on clone to {tcp,tcpv6}_prot
>> >>
>> >>    The simpler way out. tcp_bpf callbacks never get invoked on the child
>> >>    socket so the copied psock reference is no longer a problem. We can
>> >>    clear the pointer on accept().
>> >>
>> >>    So far I wasn't able poke any holes in it and it comes down to
>> >>    patching tcp_create_openreq_child() with:
>> >>
>> >> 	/* sk_msg and ULP frameworks can override the callbacks into
>> >> 	 * protocol. We don't assume they are intended to be inherited
>> >> 	 * by the child. Frameworks can re-install the callbacks on
>> >> 	 * accept() if needed.
>> >> 	 */
>> >> 	WRITE_ONCE(newsk->sk_prot, sk->sk_prot_creator);
>> >>
>> >>    That's what I'm going with for v2.
>> >>
>> >> Open to suggestions.
>> >>
>> >> Thanks,
>> >> Jakub
>> >>
>> >> BTW. Reading into kTLS code, I noticed it has been limited down to just
>> >> established sockets due to the same problem I'm struggling with here:
>> >>
>> >> static int tls_init(struct sock *sk)
>> >> {
>> >> ...
>> >> 	/* The TLS ulp is currently supported only for TCP sockets
>> >> 	 * in ESTABLISHED state.
>> >> 	 * Supporting sockets in LISTEN state will require us
>> >> 	 * to modify the accept implementation to clone rather then
>> >> 	 * share the ulp context.
>> >> 	 */
>> >> 	if (sk->sk_state != TCP_ESTABLISHED)
>> >> 		return -ENOTCONN;
>> >>
>> >> [0] https://urldefense.proofpoint.com/v2/url?u=https-3A__elixir.bootlin.com_linux_v5.5-2Drc1_source_drivers_block_drbd_drbd-5Freceiver.c-23L682&d=DwIBAg&c=5VD0RTtNlTh3ycd41b3MUw&r=VQnoQ7LvghIj0gVEaiQSUw&m=z2Cz1gEcqiw-8YqVOluxlUHh_CBs6PJWQN2vgirOyFk&s=WAiM0asZN0OkqrW02xm2mCMIzWhKQCc3KiY7pzMKNg4&e=
diff mbox series

Patch

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 6cb077b646a5..b5ade8dac69d 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -98,6 +98,7 @@  struct sk_psock {
 	void (*saved_close)(struct sock *sk, long timeout);
 	void (*saved_write_space)(struct sock *sk);
 	struct proto			*sk_proto;
+	const struct inet_connection_sock_af_ops *icsk_af_ops;
 	struct sk_psock_work_state	work_state;
 	struct work_struct		work;
 	union {
@@ -345,23 +346,30 @@  static inline void sk_psock_cork_free(struct sk_psock *psock)
 
 static inline void sk_psock_update_proto(struct sock *sk,
 					 struct sk_psock *psock,
-					 struct proto *ops)
+					 struct proto *ops,
+					 struct inet_connection_sock_af_ops *af_ops)
 {
+	struct inet_connection_sock *icsk = inet_csk(sk);
+
 	psock->saved_unhash = sk->sk_prot->unhash;
 	psock->saved_close = sk->sk_prot->close;
 	psock->saved_write_space = sk->sk_write_space;
 
 	psock->sk_proto = sk->sk_prot;
 	sk->sk_prot = ops;
+
+	psock->icsk_af_ops = icsk->icsk_af_ops;
+	icsk->icsk_af_ops = af_ops;
 }
 
 static inline void sk_psock_restore_proto(struct sock *sk,
 					  struct sk_psock *psock)
 {
+	struct inet_connection_sock *icsk = inet_csk(sk);
+
 	sk->sk_write_space = psock->saved_write_space;
 
 	if (psock->sk_proto) {
-		struct inet_connection_sock *icsk = inet_csk(sk);
 		bool has_ulp = !!icsk->icsk_ulp_data;
 
 		if (has_ulp)
@@ -370,6 +378,11 @@  static inline void sk_psock_restore_proto(struct sock *sk,
 			sk->sk_prot = psock->sk_proto;
 		psock->sk_proto = NULL;
 	}
+
+	if (psock->icsk_af_ops) {
+		icsk->icsk_af_ops = psock->icsk_af_ops;
+		psock->icsk_af_ops = NULL;
+	}
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 8a56e09cfb0e..dc709949c8e5 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -10,6 +10,8 @@ 
 #include <net/inet_common.h>
 #include <net/tls.h>
 
+extern const struct inet_connection_sock_af_ops ipv4_specific;
+
 static bool tcp_bpf_stream_read(const struct sock *sk)
 {
 	struct sk_psock *psock;
@@ -535,6 +537,10 @@  static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
 {
 	struct sk_psock_link *link;
 
+	/* Did a child socket inadvertently inherit parent's psock? */
+	if (WARN_ON(sk != psock->sk))
+		return;
+
 	while ((link = sk_psock_link_pop(psock))) {
 		sk_psock_unlink(sk, link);
 		sk_psock_free_link(link);
@@ -582,6 +588,45 @@  static void tcp_bpf_close(struct sock *sk, long timeout)
 	saved_close(sk, timeout);
 }
 
+static struct sock *tcp_bpf_syn_recv_sock(const struct sock *sk,
+					  struct sk_buff *skb,
+					  struct request_sock *req,
+					  struct dst_entry *dst,
+					  struct request_sock *req_unhash,
+					  bool *own_req)
+{
+	const struct inet_connection_sock_af_ops *ops;
+	void (*write_space)(struct sock *sk);
+	struct sk_psock *psock;
+	struct proto *proto;
+	struct sock *child;
+
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (likely(psock)) {
+		proto = psock->sk_proto;
+		write_space = psock->saved_write_space;
+		ops = psock->icsk_af_ops;
+	} else {
+		ops = inet_csk(sk)->icsk_af_ops;
+	}
+	rcu_read_unlock();
+
+	child = ops->syn_recv_sock(sk, skb, req, dst, req_unhash, own_req);
+
+	/* Child must not inherit psock or its ops. */
+	if (child && psock) {
+		rcu_assign_sk_user_data(child, NULL);
+		child->sk_prot = proto;
+		child->sk_write_space = write_space;
+
+		/* v4-mapped sockets don't inherit parent ops. Don't restore. */
+		if (inet_csk(child)->icsk_af_ops == inet_csk(sk)->icsk_af_ops)
+			inet_csk(child)->icsk_af_ops = ops;
+	}
+	return child;
+}
+
 enum {
 	TCP_BPF_IPV4,
 	TCP_BPF_IPV6,
@@ -597,6 +642,7 @@  enum {
 static struct proto *tcpv6_prot_saved __read_mostly;
 static DEFINE_SPINLOCK(tcpv6_prot_lock);
 static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
+static struct inet_connection_sock_af_ops tcp_bpf_af_ops[TCP_BPF_NUM_PROTS];
 
 static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 				   struct proto *base)
@@ -612,13 +658,23 @@  static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
 }
 
-static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+/* Build a tcp_bpf af_ops table: copy 'base' wholesale into 'ops' and
+ * override only syn_recv_sock with the tcp_bpf variant, which strips
+ * inherited psock state from newly cloned children.
+ */
+static void tcp_bpf_rebuild_af_ops(struct inet_connection_sock_af_ops *ops,
+				   const struct inet_connection_sock_af_ops *base)
+{
+	*ops = *base;
+	ops->syn_recv_sock = tcp_bpf_syn_recv_sock;
+}
+
+static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops,
+					   const struct inet_connection_sock_af_ops *af_ops)
 {
 	if (sk->sk_family == AF_INET6 &&
 	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
 		spin_lock_bh(&tcpv6_prot_lock);
 		if (likely(ops != tcpv6_prot_saved)) {
 			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
+			tcp_bpf_rebuild_af_ops(&tcp_bpf_af_ops[TCP_BPF_IPV6],
+					       af_ops);
 			smp_store_release(&tcpv6_prot_saved, ops);
 		}
 		spin_unlock_bh(&tcpv6_prot_lock);
@@ -628,6 +684,8 @@  static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
 static int __init tcp_bpf_v4_build_proto(void)
 {
 	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
+	tcp_bpf_rebuild_af_ops(&tcp_bpf_af_ops[TCP_BPF_IPV4], &ipv4_specific);
+
 	return 0;
 }
 core_initcall(tcp_bpf_v4_build_proto);
@@ -637,7 +695,8 @@  static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
-	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
+	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config],
+			      &tcp_bpf_af_ops[family]);
 }
 
 static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
@@ -677,6 +736,7 @@  void tcp_bpf_reinit(struct sock *sk)
 
 int tcp_bpf_init(struct sock *sk)
 {
+	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct proto *ops = READ_ONCE(sk->sk_prot);
 	struct sk_psock *psock;
 
@@ -689,7 +749,7 @@  int tcp_bpf_init(struct sock *sk)
 		rcu_read_unlock();
 		return -EINVAL;
 	}
-	tcp_bpf_check_v6_needs_rebuild(sk, ops);
+	tcp_bpf_check_v6_needs_rebuild(sk, ops, icsk->icsk_af_ops);
 	tcp_bpf_update_sk_prot(sk, psock);
 	rcu_read_unlock();
 	return 0;