[bpf-next,3/7] skmsg: introduce sk_psock_hooks

Message ID: 20200225135636.5768-4-lmb@cloudflare.com
State: Changes Requested
Delegated to: BPF Maintainers
Series: bpf: sockmap, sockhash: support storing UDP sockets

Commit Message

Lorenz Bauer Feb. 25, 2020, 1:56 p.m. UTC
The sockmap works by overriding some of the callbacks in sk->sk_prot, while
leaving others untouched. This means that we need access to the struct proto
for any protocol we want to support. For IPv4 this is trivial, since both
TCP and UDP are always compiled in. IPv6 may be disabled or compiled as a
module, so the existing TCP sockmap hooks use some trickery to lazily
initialize the modified struct proto for TCPv6.

Pull this logic into a standalone struct sk_psock_hooks, so that it can be
re-used by UDP sockmap.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
---
 include/linux/skmsg.h |  36 ++++++++-----
 include/net/tcp.h     |   1 -
 net/core/skmsg.c      |  52 +++++++++++++++++++
 net/core/sock_map.c   |  24 ++++-----
 net/ipv4/tcp_bpf.c    | 114 ++++++++++++------------------------------
 5 files changed, 116 insertions(+), 111 deletions(-)
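
For orientation, the sk->sk_prot override pattern that this series
generalizes looks roughly like the following, based on the
sk_psock_update_proto helper removed further down (ops being the
modified struct proto):

	/* Save the original callbacks so sk_psock_restore_proto can
	 * put them back later, then swap in the modified proto.
	 */
	psock->saved_unhash = sk->sk_prot->unhash;
	psock->saved_close = sk->sk_prot->close;
	psock->saved_write_space = sk->sk_write_space;

	psock->sk_proto = sk->sk_prot;
	/* Pairs with lockless read in sk_clone_lock() */
	WRITE_ONCE(sk->sk_prot, ops);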

Comments

Jakub Sitnicki Feb. 26, 2020, 2:57 p.m. UTC | #1
On Tue, Feb 25, 2020 at 02:56 PM CET, Lorenz Bauer wrote:
> The sockmap works by overriding some of the callbacks in sk->sk_prot, while
> leaving others untouched. This means that we need access to the struct proto
> for any protocol we want to support. For IPv4 this is trivial, since both
> TCP and UDP are always compiled in. IPv6 may be disabled or compiled as a
> module, so the existing TCP sockmap hooks use some trickery to lazily
> initialize the modified struct proto for TCPv6.
>
> Pull this logic into a standalone struct sk_psock_hooks, so that it can be
> re-used by UDP sockmap.
>
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
> ---
>  include/linux/skmsg.h |  36 ++++++++-----
>  include/net/tcp.h     |   1 -
>  net/core/skmsg.c      |  52 +++++++++++++++++++
>  net/core/sock_map.c   |  24 ++++-----
>  net/ipv4/tcp_bpf.c    | 114 ++++++++++++------------------------------
>  5 files changed, 116 insertions(+), 111 deletions(-)
>
> diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
> index c881094387db..70d65ab10b5c 100644
> --- a/include/linux/skmsg.h
> +++ b/include/linux/skmsg.h

[...]

> @@ -424,4 +425,13 @@ static inline void psock_progs_drop(struct sk_psock_progs *progs)
>  	psock_set_prog(&progs->skb_verdict, NULL);
>  }
>  
> +static inline int sk_psock_hooks_init(struct sk_psock_hooks *hooks,
> +				       struct proto *ipv4_base)
> +{
> +	hooks->ipv6_lock = __SPIN_LOCK_UNLOCKED();

We will need spin_lock_init instead to play nice with CONFIG_DEBUG_SPINLOCK.
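
A minimal sketch of that change, assuming the helper otherwise stays
as posted:

	static inline int sk_psock_hooks_init(struct sk_psock_hooks *hooks,
					      struct proto *ipv4_base)
	{
		/* Runtime init plays nice with CONFIG_DEBUG_SPINLOCK,
		 * unlike the bare __SPIN_LOCK_UNLOCKED() above.
		 */
		spin_lock_init(&hooks->ipv6_lock);
		return hooks->rebuild_proto(hooks->ipv4, ipv4_base);
	}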

> +	return hooks->rebuild_proto(hooks->ipv4, ipv4_base);
> +}
> +
> +int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk);
> +
>  #endif /* _LINUX_SKMSG_H */

[...]
Martin KaFai Lau Feb. 26, 2020, 6:37 p.m. UTC | #2
On Tue, Feb 25, 2020 at 01:56:32PM +0000, Lorenz Bauer wrote:
> The sockmap works by overriding some of the callbacks in sk->sk_prot, while
> leaving others untouched. This means that we need access to the struct proto
> for any protocol we want to support. For IPv4 this is trivial, since both
> TCP and UDP are always compiled in. IPv6 may be disabled or compiled as a
> module, so the existing TCP sockmap hooks use some trickery to lazily
> initialize the modified struct proto for TCPv6.
> 
> Pull this logic into a standalone struct sk_psock_hooks, so that it can be
> re-used by UDP sockmap.
> 
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
> ---
>  include/linux/skmsg.h |  36 ++++++++-----
>  include/net/tcp.h     |   1 -
>  net/core/skmsg.c      |  52 +++++++++++++++++++
>  net/core/sock_map.c   |  24 ++++-----
>  net/ipv4/tcp_bpf.c    | 114 ++++++++++++------------------------------
>  5 files changed, 116 insertions(+), 111 deletions(-)
> 
> diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
> index c881094387db..70d65ab10b5c 100644
> --- a/include/linux/skmsg.h
> +++ b/include/linux/skmsg.h
> @@ -109,6 +109,16 @@ struct sk_psock {
>  	};
>  };
>  
> +struct sk_psock_hooks {
> +	struct proto *base_ipv6;
> +	struct proto *ipv4;
> +	struct proto *ipv6;
> +	spinlock_t ipv6_lock;
> +	int (*rebuild_proto)(struct proto prot[], struct proto *base);
> +	struct proto *(*choose_proto)(struct proto prot[],
> +				      struct sk_psock *psock);
> +};
> +
>  int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
>  		 int elem_first_coalesce);
>  int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
> @@ -335,23 +345,14 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
>  	}
>  }
>  
> -static inline void sk_psock_update_proto(struct sock *sk,
> -					 struct sk_psock *psock,
> -					 struct proto *ops)
> -{
> -	psock->saved_unhash = sk->sk_prot->unhash;
> -	psock->saved_close = sk->sk_prot->close;
> -	psock->saved_write_space = sk->sk_write_space;
> -
> -	psock->sk_proto = sk->sk_prot;
> -	/* Pairs with lockless read in sk_clone_lock() */
> -	WRITE_ONCE(sk->sk_prot, ops);
> -}
> -
>  static inline void sk_psock_restore_proto(struct sock *sk,
>  					  struct sk_psock *psock)
>  {
> +	if (!psock->sk_proto)
> +		return;
> +
>  	sk->sk_prot->unhash = psock->saved_unhash;
> +
>  	if (inet_sk(sk)->is_icsk) {
>  		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
>  	} else {
> @@ -424,4 +425,13 @@ static inline void psock_progs_drop(struct sk_psock_progs *progs)
>  	psock_set_prog(&progs->skb_verdict, NULL);
>  }
>  
> +static inline int sk_psock_hooks_init(struct sk_psock_hooks *hooks,
> +				       struct proto *ipv4_base)
> +{
> +	hooks->ipv6_lock = __SPIN_LOCK_UNLOCKED();
> +	return hooks->rebuild_proto(hooks->ipv4, ipv4_base);
> +}
> +
> +int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk);
> +
>  #endif /* _LINUX_SKMSG_H */
> diff --git a/include/net/tcp.h b/include/net/tcp.h
> index 07f947cc80e6..ccf39d80b695 100644
> --- a/include/net/tcp.h
> +++ b/include/net/tcp.h
> @@ -2196,7 +2196,6 @@ struct sk_msg;
>  struct sk_psock;
>  
>  int tcp_bpf_init(struct sock *sk);
> -void tcp_bpf_reinit(struct sock *sk);
>  int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
>  			  int flags);
>  int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
> diff --git a/net/core/skmsg.c b/net/core/skmsg.c
> index eeb28cb85664..a9bdf02c2539 100644
> --- a/net/core/skmsg.c
> +++ b/net/core/skmsg.c
> @@ -844,3 +844,55 @@ void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
>  	strp_stop(&parser->strp);
>  	parser->enabled = false;
>  }
> +
> +static inline int sk_psock_hooks_init_ipv6(struct sk_psock_hooks *hooks,
> +					    struct proto *base)
> +{
> +	int ret = 0;
> +
> +	if (likely(base == smp_load_acquire(&hooks->base_ipv6)))
> +		return 0;
> +
> +	spin_lock_bh(&hooks->ipv6_lock);
> +	if (likely(base != hooks->base_ipv6)) {
> +		ret = hooks->rebuild_proto(hooks->ipv6, base);
> +		if (!ret)
> +			smp_store_release(&hooks->base_ipv6, base);
> +	}
> +	spin_unlock_bh(&hooks->ipv6_lock);
> +	return ret;
> +}
> +
> +int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk)
> +{
> +	struct sk_psock *psock = sk_psock(sk);
> +	struct proto *prot_base;
> +
> +	WARN_ON_ONCE(!rcu_read_lock_held());
Is this only for the earlier sk_psock(sk)?

> +
> +	if (unlikely(!psock))
When will this happen?

> +		return -EINVAL;
> +
> +	/* Initialize saved callbacks and original proto only once.
> +	 * Since we've not installed the hooks, psock is not yet in use and
> +	 * we can initialize it without synchronization.
> +	 */
> +	if (!psock->sk_proto) {
If I read it correctly, this is to replace the tcp_bpf_reinit_sk_prot()?

I think some of the current reinit comment is useful to keep also:

/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed ... */

> +		struct proto *prot = READ_ONCE(sk->sk_prot);
> +
> +		if (sk->sk_family == AF_INET6 &&
> +		    sk_psock_hooks_init_ipv6(hooks, prot))
> +			return -EINVAL;
> +
> +		psock->saved_unhash = prot->unhash;
> +		psock->saved_close = prot->close;
> +		psock->saved_write_space = sk->sk_write_space;
> +
> +		psock->sk_proto = prot;
> +	}
> +
> +	/* Pairs with lockless read in sk_clone_lock() */
> +	prot_base = sk->sk_family == AF_INET ? hooks->ipv4 : hooks->ipv6;
> +	WRITE_ONCE(sk->sk_prot, hooks->choose_proto(prot_base, psock));
> +	return 0;
> +}
Jakub Sitnicki Feb. 27, 2020, 9:27 a.m. UTC | #3
On Tue, Feb 25, 2020 at 02:56 PM CET, Lorenz Bauer wrote:
> The sockmap works by overriding some of the callbacks in sk->sk_prot, while
> leaving others untouched. This means that we need access to the struct proto
> for any protocol we want to support. For IPv4 this is trivial, since both
> TCP and UDP are always compiled in. IPv6 may be disabled or compiled as a
> module, so the existing TCP sockmap hooks use some trickery to lazily
> initialize the modified struct proto for TCPv6.
>
> Pull this logic into a standalone struct sk_psock_hooks, so that it can be
> re-used by UDP sockmap.
>
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
> ---
>  include/linux/skmsg.h |  36 ++++++++-----
>  include/net/tcp.h     |   1 -
>  net/core/skmsg.c      |  52 +++++++++++++++++++
>  net/core/sock_map.c   |  24 ++++-----
>  net/ipv4/tcp_bpf.c    | 114 ++++++++++++------------------------------
>  5 files changed, 116 insertions(+), 111 deletions(-)
>

[...]

> diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
> index 90955c96a9a8..81c0431a8dbd 100644
> --- a/net/ipv4/tcp_bpf.c
> +++ b/net/ipv4/tcp_bpf.c

[...]

> @@ -650,9 +599,10 @@ int tcp_bpf_init(struct sock *sk)
>   */
>  void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
>  {
> -	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
> -	struct proto *prot = newsk->sk_prot;
> +	struct proto *prot = READ_ONCE(sk->sk_prot);

For the sake of keeping the review open: we've identified a regression
here. sk->sk_prot can change between the moment we clone the listener
sock and the moment we get here. We need to stick to checking
newsk->sk_prot, as sketched after the quoted hunk below.

>
> -	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
> +	/* TCP_LISTEN can only use TCP_BPF_BASE, so this is safe */
> +	if (unlikely(prot == &tcp_bpf_ipv4[TCP_BPF_BASE] ||
> +	             prot == &tcp_bpf_ipv6[TCP_BPF_BASE]))
>  		newsk->sk_prot = sk->sk_prot_creator;
>  }
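
A minimal sketch of the fix being described, i.e. testing the child
socket's proto (assuming the rest of the hunk stays as posted):

	void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
	{
		/* The parent's sk_prot may have changed since the clone
		 * was made; check the snapshot the child inherited.
		 */
		struct proto *prot = newsk->sk_prot;

		/* TCP_LISTEN can only use TCP_BPF_BASE, so this is safe */
		if (unlikely(prot == &tcp_bpf_ipv4[TCP_BPF_BASE] ||
			     prot == &tcp_bpf_ipv6[TCP_BPF_BASE]))
			newsk->sk_prot = sk->sk_prot_creator;
	}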
Jakub Sitnicki Feb. 27, 2020, 9:40 a.m. UTC | #4
On Tue, Feb 25, 2020 at 02:56 PM CET, Lorenz Bauer wrote:
> The sockmap works by overriding some of the callbacks in sk->sk_prot, while
> leaving others untouched. This means that we need access to the struct proto
> for any protocol we want to support. For IPv4 this is trivial, since both
> TCP and UDP are always compiled in. IPv6 may be disabled or compiled as a
> module, so the existing TCP sockmap hooks use some trickery to lazily
> initialize the modified struct proto for TCPv6.
>
> Pull this logic into a standalone struct sk_psock_hooks, so that it can be
> re-used by UDP sockmap.
>
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
> ---

I've been looking at how to simplify this a bit. One thing that seems
like an easy win is to fold sk_psock_hooks_init into its callers. Then
we can go back to using spinlock initializer macros. Patch below.

This highlights some inconsistency in how instances of
sk_psock_hooks are named: tcp_bpf_hooks vs. udp_psock_proto.

---
 include/linux/skmsg.h | 7 -------
 net/ipv4/tcp_bpf.c    | 3 ++-
 net/ipv4/udp_bpf.c    | 3 ++-
 3 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 174c76c725fb..4566724dc0c9 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -425,13 +425,6 @@ static inline void psock_progs_drop(struct sk_psock_progs *progs)
 	psock_set_prog(&progs->skb_verdict, NULL);
 }

-static inline int sk_psock_hooks_init(struct sk_psock_hooks *hooks,
-				       struct proto *ipv4_base)
-{
-	spin_lock_init(&hooks->ipv6_lock);
-	return hooks->rebuild_proto(hooks->ipv4, ipv4_base);
-}
-
 int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk);

 #endif /* _LINUX_SKMSG_H */
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index fa7e474b981b..5cb9a0724cf6 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -570,13 +570,14 @@ static struct proto tcp_bpf_ipv6[TCP_BPF_NUM_CFGS];
 static struct sk_psock_hooks tcp_bpf_hooks __read_mostly = {
 	.ipv4 = &tcp_bpf_ipv4[0],
 	.ipv6 = &tcp_bpf_ipv6[0],
+	.ipv6_lock = __SPIN_LOCK_UNLOCKED(tcp_bpf_hooks.ipv6_lock),
 	.rebuild_proto = tcp_bpf_rebuild_proto,
 	.choose_proto = tcp_bpf_choose_proto,
 };

 static int __init tcp_bpf_init_psock_hooks(void)
 {
-	return sk_psock_hooks_init(&tcp_bpf_hooks, &tcp_prot);
+	return tcp_bpf_rebuild_proto(tcp_bpf_ipv4, &tcp_prot);
 }
 core_initcall(tcp_bpf_init_psock_hooks);

diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index e085a0648a94..da5eb1d2265d 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -30,13 +30,14 @@ static struct proto udpv6_proto;
 static struct sk_psock_hooks udp_psock_proto __read_mostly = {
 	.ipv4 = &udpv4_proto,
 	.ipv6 = &udpv6_proto,
+	.ipv6_lock = __SPIN_LOCK_UNLOCKED(udp_psock_proto.ipv6_lock),
 	.rebuild_proto = udp_bpf_rebuild_protos,
 	.choose_proto = udp_bpf_choose_proto,
 };

 static int __init udp_bpf_init_psock_hooks(void)
 {
-	return sk_psock_hooks_init(&udp_psock_proto, &udp_prot);
+	return udp_bpf_rebuild_protos(&udpv4_proto, &udp_prot);
 }
 core_initcall(udp_bpf_init_psock_hooks);
Lorenz Bauer Feb. 28, 2020, 10:48 a.m. UTC | #5
On Wed, 26 Feb 2020 at 18:37, Martin KaFai Lau <kafai@fb.com> wrote:
>
> > +int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk)
> > +{
> > +     struct sk_psock *psock = sk_psock(sk);
> > +     struct proto *prot_base;
> > +
> > +     WARN_ON_ONCE(!rcu_read_lock_held());
> Is this only for the earlier sk_psock(sk)?

The function is an amalgamation of tcp_bpf_reinit and tcp_bpf_init,
which both take the RCU read lock. I figured it would make sense to
assert that behaviour in sk_psock_hooks_install.
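
Both callers end up with the same pattern, e.g. tcp_bpf_init further
down in this patch:

	rcu_read_lock();
	ret = sk_psock_hooks_install(&tcp_bpf_hooks, sk);
	rcu_read_unlock();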

>
> > +
> > +     if (unlikely(!psock))
> When will this happen?

I don't know, to be honest; this is adapted from tcp_bpf_init:

       psock = sk_psock(sk);
       if (unlikely(!psock || psock->sk_proto ||
                    tcp_bpf_assert_proto_ops(ops))) {
               rcu_read_unlock();
               return -EINVAL;
       }

>
> > +             return -EINVAL;
> > +
> > +     /* Initialize saved callbacks and original proto only once.
> > +      * Since we've not installed the hooks, psock is not yet in use and
> > +      * we can initialize it without synchronization.
> > +      */
> > +     if (!psock->sk_proto) {
> If I read it correctly, this is to replace the tcp_bpf_reinit_sk_prot()?
>
> I think some of the current reinit comment is useful to keep also:
>
> /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed ... */

Ack, I will elaborate.
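
For reference, the comment in question, from the
tcp_bpf_reinit_sk_prot this patch removes:

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 *
	 * Pairs with lockless read in sk_clone_lock().
	 */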

Patch

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index c881094387db..70d65ab10b5c 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -109,6 +109,16 @@  struct sk_psock {
 	};
 };
 
+struct sk_psock_hooks {
+	struct proto *base_ipv6;
+	struct proto *ipv4;
+	struct proto *ipv6;
+	spinlock_t ipv6_lock;
+	int (*rebuild_proto)(struct proto prot[], struct proto *base);
+	struct proto *(*choose_proto)(struct proto prot[],
+				      struct sk_psock *psock);
+};
+
 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
 		 int elem_first_coalesce);
 int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
@@ -335,23 +345,14 @@  static inline void sk_psock_cork_free(struct sk_psock *psock)
 	}
 }
 
-static inline void sk_psock_update_proto(struct sock *sk,
-					 struct sk_psock *psock,
-					 struct proto *ops)
-{
-	psock->saved_unhash = sk->sk_prot->unhash;
-	psock->saved_close = sk->sk_prot->close;
-	psock->saved_write_space = sk->sk_write_space;
-
-	psock->sk_proto = sk->sk_prot;
-	/* Pairs with lockless read in sk_clone_lock() */
-	WRITE_ONCE(sk->sk_prot, ops);
-}
-
 static inline void sk_psock_restore_proto(struct sock *sk,
 					  struct sk_psock *psock)
 {
+	if (!psock->sk_proto)
+		return;
+
 	sk->sk_prot->unhash = psock->saved_unhash;
+
 	if (inet_sk(sk)->is_icsk) {
 		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
 	} else {
@@ -424,4 +425,13 @@  static inline void psock_progs_drop(struct sk_psock_progs *progs)
 	psock_set_prog(&progs->skb_verdict, NULL);
 }
 
+static inline int sk_psock_hooks_init(struct sk_psock_hooks *hooks,
+				       struct proto *ipv4_base)
+{
+	hooks->ipv6_lock = __SPIN_LOCK_UNLOCKED();
+	return hooks->rebuild_proto(hooks->ipv4, ipv4_base);
+}
+
+int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk);
+
 #endif /* _LINUX_SKMSG_H */
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 07f947cc80e6..ccf39d80b695 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2196,7 +2196,6 @@  struct sk_msg;
 struct sk_psock;
 
 int tcp_bpf_init(struct sock *sk);
-void tcp_bpf_reinit(struct sock *sk);
 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
 			  int flags);
 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index eeb28cb85664..a9bdf02c2539 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -844,3 +844,55 @@  void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
 	strp_stop(&parser->strp);
 	parser->enabled = false;
 }
+
+static inline int sk_psock_hooks_init_ipv6(struct sk_psock_hooks *hooks,
+					    struct proto *base)
+{
+	int ret = 0;
+
+	if (likely(base == smp_load_acquire(&hooks->base_ipv6)))
+		return 0;
+
+	spin_lock_bh(&hooks->ipv6_lock);
+	if (likely(base != hooks->base_ipv6)) {
+		ret = hooks->rebuild_proto(hooks->ipv6, base);
+		if (!ret)
+			smp_store_release(&hooks->base_ipv6, base);
+	}
+	spin_unlock_bh(&hooks->ipv6_lock);
+	return ret;
+}
+
+int sk_psock_hooks_install(struct sk_psock_hooks *hooks, struct sock *sk)
+{
+	struct sk_psock *psock = sk_psock(sk);
+	struct proto *prot_base;
+
+	WARN_ON_ONCE(!rcu_read_lock_held());
+
+	if (unlikely(!psock))
+		return -EINVAL;
+
+	/* Initialize saved callbacks and original proto only once.
+	 * Since we've not installed the hooks, psock is not yet in use and
+	 * we can initialize it without synchronization.
+	 */
+	if (!psock->sk_proto) {
+		struct proto *prot = READ_ONCE(sk->sk_prot);
+
+		if (sk->sk_family == AF_INET6 &&
+		    sk_psock_hooks_init_ipv6(hooks, prot))
+			return -EINVAL;
+
+		psock->saved_unhash = prot->unhash;
+		psock->saved_close = prot->close;
+		psock->saved_write_space = sk->sk_write_space;
+
+		psock->sk_proto = prot;
+	}
+
+	/* Pairs with lockless read in sk_clone_lock() */
+	prot_base = sk->sk_family == AF_INET ? hooks->ipv4 : hooks->ipv6;
+	WRITE_ONCE(sk->sk_prot, hooks->choose_proto(prot_base, psock));
+	return 0;
+}
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 459b3ba16023..c84cc9fc7f6b 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -170,8 +170,8 @@  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 			 struct sock *sk)
 {
 	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
-	bool skb_progs, sk_psock_is_new = false;
 	struct sk_psock *psock;
+	bool skb_progs;
 	int ret;
 
 	skb_verdict = READ_ONCE(progs->skb_verdict);
@@ -206,9 +206,8 @@  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 	if (psock) {
 		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
 		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
-			sk_psock_put(sk, psock);
 			ret = -EBUSY;
-			goto out_progs;
+			goto out_drop;
 		}
 	} else {
 		psock = sk_psock_init(sk, map->numa_node);
@@ -216,18 +215,14 @@  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 			ret = -ENOMEM;
 			goto out_progs;
 		}
-		sk_psock_is_new = true;
 	}
 
 	if (msg_parser)
 		psock_set_prog(&psock->progs.msg_parser, msg_parser);
-	if (sk_psock_is_new) {
-		ret = tcp_bpf_init(sk);
-		if (ret < 0)
-			goto out_drop;
-	} else {
-		tcp_bpf_reinit(sk);
-	}
+
+	ret = tcp_bpf_init(sk);
+	if (ret < 0)
+		goto out_drop;
 
 	write_lock_bh(&sk->sk_callback_lock);
 	if (skb_progs && !psock->parser.enabled) {
@@ -264,15 +259,14 @@  static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
 	if (IS_ERR(psock))
 		return PTR_ERR(psock);
 
-	if (psock) {
-		tcp_bpf_reinit(sk);
-		return 0;
-	}
+	if (psock)
+		goto init;
 
 	psock = sk_psock_init(sk, map->numa_node);
 	if (!psock)
 		return -ENOMEM;
 
+init:
 	ret = tcp_bpf_init(sk);
 	if (ret < 0)
 		sk_psock_put(sk, psock);
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 90955c96a9a8..81c0431a8dbd 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -528,25 +528,23 @@  static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
 	return copied ? copied : err;
 }
 
-enum {
-	TCP_BPF_IPV4,
-	TCP_BPF_IPV6,
-	TCP_BPF_NUM_PROTS,
-};
-
 enum {
 	TCP_BPF_BASE,
 	TCP_BPF_TX,
 	TCP_BPF_NUM_CFGS,
 };
 
-static struct proto *tcpv6_prot_saved __read_mostly;
-static DEFINE_SPINLOCK(tcpv6_prot_lock);
-static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
-
-static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
-				   struct proto *base)
+static int tcp_bpf_rebuild_proto(struct proto prot[], struct proto *base)
 {
+	/* In order to avoid retpoline, we make assumptions when we call
+	 * into ops if e.g. a psock is not present. Make sure they are
+	 * indeed valid assumptions.
+	 */
+	if (base->recvmsg  != tcp_recvmsg ||
+	    base->sendmsg  != tcp_sendmsg ||
+	    base->sendpage != tcp_sendpage)
+		return -ENOTSUPP;
+
 	prot[TCP_BPF_BASE]			= *base;
 	prot[TCP_BPF_BASE].unhash		= sock_map_unhash;
 	prot[TCP_BPF_BASE].close		= sock_map_close;
@@ -556,91 +554,42 @@  static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
 	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
 	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
-}
-
-static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
-{
-	if (sk->sk_family == AF_INET6 &&
-	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
-		spin_lock_bh(&tcpv6_prot_lock);
-		if (likely(ops != tcpv6_prot_saved)) {
-			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
-			smp_store_release(&tcpv6_prot_saved, ops);
-		}
-		spin_unlock_bh(&tcpv6_prot_lock);
-	}
-}
-
-static int __init tcp_bpf_v4_build_proto(void)
-{
-	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
 	return 0;
 }
-core_initcall(tcp_bpf_v4_build_proto);
 
-static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
+static struct proto *tcp_bpf_choose_proto(struct proto prot[],
+					  struct sk_psock *psock)
 {
-	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
+	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
 
-	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
+	return &prot[config];
 }
 
-static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
-	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
-
-	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
-	 * or added requiring sk_prot hook updates. We keep original saved
-	 * hooks in this case.
-	 *
-	 * Pairs with lockless read in sk_clone_lock().
-	 */
-	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
-}
-
-static int tcp_bpf_assert_proto_ops(struct proto *ops)
-{
-	/* In order to avoid retpoline, we make assumptions when we call
-	 * into ops if e.g. a psock is not present. Make sure they are
-	 * indeed valid assumptions.
-	 */
-	return ops->recvmsg  == tcp_recvmsg &&
-	       ops->sendmsg  == tcp_sendmsg &&
-	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
-}
+static struct proto tcp_bpf_ipv4[TCP_BPF_NUM_CFGS];
+static struct proto tcp_bpf_ipv6[TCP_BPF_NUM_CFGS];
+static struct sk_psock_hooks tcp_bpf_hooks __read_mostly = {
+	.ipv4 = &tcp_bpf_ipv4[0],
+	.ipv6 = &tcp_bpf_ipv6[0],
+	.rebuild_proto = tcp_bpf_rebuild_proto,
+	.choose_proto = tcp_bpf_choose_proto,
+};
 
-void tcp_bpf_reinit(struct sock *sk)
+static int __init tcp_bpf_init_psock_hooks(void)
 {
-	struct sk_psock *psock;
-
-	sock_owned_by_me(sk);
-
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	tcp_bpf_reinit_sk_prot(sk, psock);
-	rcu_read_unlock();
+	return sk_psock_hooks_init(&tcp_bpf_hooks, &tcp_prot);
 }
+core_initcall(tcp_bpf_init_psock_hooks);
 
 int tcp_bpf_init(struct sock *sk)
 {
-	struct proto *ops = READ_ONCE(sk->sk_prot);
-	struct sk_psock *psock;
+	int ret;
 
 	sock_owned_by_me(sk);
 
 	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (unlikely(!psock || psock->sk_proto ||
-		     tcp_bpf_assert_proto_ops(ops))) {
-		rcu_read_unlock();
-		return -EINVAL;
-	}
-	tcp_bpf_check_v6_needs_rebuild(sk, ops);
-	tcp_bpf_update_sk_prot(sk, psock);
+	ret = sk_psock_hooks_install(&tcp_bpf_hooks, sk);
 	rcu_read_unlock();
-	return 0;
+	return ret;
 }
 
 /* If a child got cloned from a listening socket that had tcp_bpf
@@ -650,9 +599,10 @@  int tcp_bpf_init(struct sock *sk)
  */
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
 {
-	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-	struct proto *prot = newsk->sk_prot;
+	struct proto *prot = READ_ONCE(sk->sk_prot);
 
-	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
+	/* TCP_LISTEN can only use TCP_BPF_BASE, so this is safe */
+	if (unlikely(prot == &tcp_bpf_ipv4[TCP_BPF_BASE] ||
+	             prot == &tcp_bpf_ipv6[TCP_BPF_BASE]))
 		newsk->sk_prot = sk->sk_prot_creator;
 }