[bpf-next,v3,03/12] bpf: tcp: move assertions into tcp_bpf_get_proto

Message ID 20200304101318.5225-4-lmb@cloudflare.com
State Changes Requested
Delegated to: BPF Maintainers
Series bpf: sockmap, sockhash: support storing UDP sockets

Commit Message

Lorenz Bauer March 4, 2020, 10:13 a.m. UTC
We need to ensure that sk->sk_prot uses certain callbacks, so that
code that directly calls e.g. tcp_sendmsg in certain corner cases
works. To avoid spurious asserts, we must do this only if
sk_psock_update_proto has not yet been called. The same invariants
apply for tcp_bpf_check_v6_needs_rebuild, so move the call as well.

Doing so allows us to merge tcp_bpf_init and tcp_bpf_reinit.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
---
 include/net/tcp.h   |  1 -
 net/core/sock_map.c | 21 +++++++--------------
 net/ipv4/tcp_bpf.c  | 42 ++++++++++++++++++++++--------------------
 3 files changed, 29 insertions(+), 35 deletions(-)
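
The change hinges on the invariant called out above: the proto-ops assertion
and the v6 rebuild check only make sense while psock->sk_proto is still
unset, i.e. before sk_psock_update_proto has installed the bpf protos. A
condensed sketch of that guard, using only names that appear in the
tcp_bpf_get_proto hunk of the diff below:

	/* Assert on the original proto only once, before
	 * sk_psock_update_proto() swaps in tcp_bpf_prots.
	 */
	if (!psock->sk_proto) {
		struct proto *ops = READ_ONCE(sk->sk_prot);

		if (tcp_bpf_assert_proto_ops(ops))
			return ERR_PTR(-EINVAL);

		tcp_bpf_check_v6_needs_rebuild(sk, ops);
	}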

Comments

Jakub Sitnicki March 5, 2020, 12:10 p.m. UTC | #1
On Wed, Mar 04, 2020 at 11:13 AM CET, Lorenz Bauer wrote:
> We need to ensure that sk->sk_prot uses certain callbacks, so that
> code that directly calls e.g. tcp_sendmsg in certain corner cases
> works. To avoid spurious asserts, we must do this only if
> sk_psock_update_proto has not yet been called. The same invariants
> apply for tcp_bpf_check_v6_needs_rebuild, so move the call as well.
>
> Doing so allows us to merge tcp_bpf_init and tcp_bpf_reinit.
>
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
> ---

Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>

[...]
John Fastabend March 6, 2020, 3:25 p.m. UTC | #2
Lorenz Bauer wrote:
> We need to ensure that sk->sk_prot uses certain callbacks, so that
> code that directly calls e.g. tcp_sendmsg in certain corner cases
> works. To avoid spurious asserts, we must do this only if
> sk_psock_update_proto has not yet been called. The same invariants
> apply for tcp_bpf_check_v6_needs_rebuild, so move the call as well.
> 
> Doing so allows us to merge tcp_bpf_init and tcp_bpf_reinit.
> 
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>

Small nit: if you update it, just carry the acks through.

Acked-by: John Fastabend <john.fastabend@gmail.com>
 
>  	skb_verdict = READ_ONCE(progs->skb_verdict);
> @@ -191,18 +191,14 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
>  			ret = -ENOMEM;
>  			goto out_progs;
>  		}
> -		sk_psock_is_new = true;
>  	}
>  
>  	if (msg_parser)
>  		psock_set_prog(&psock->progs.msg_parser, msg_parser);
> -	if (sk_psock_is_new) {
> -		ret = tcp_bpf_init(sk);
> -		if (ret < 0)
> -			goto out_drop;
> -	} else {
> -		tcp_bpf_reinit(sk);
> -	}
> +
> +	ret = tcp_bpf_init(sk);
> +	if (ret < 0)
> +		goto out_drop;
>  
>  	write_lock_bh(&sk->sk_callback_lock);
>  	if (skb_progs && !psock->parser.enabled) {
> @@ -239,12 +235,9 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
>  	if (IS_ERR(psock))
>  		return PTR_ERR(psock);
>  
> -	if (psock) {
> -		tcp_bpf_reinit(sk);
> -		return 0;
> -	}
> +	if (!psock)
> +		psock = sk_psock_init(sk, map->numa_node);
>  
> -	psock = sk_psock_init(sk, map->numa_node);
>  	if (!psock)
>  		return -ENOMEM;

Also a small nit: this reads,

 if (!psock)
    psock = ...
 if (!psock)
    return -ENOMEM

how about,

 if (!psock) {
   psock = ...
   if (!psock) return -ENOMEM;
 }
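
For reference, with that nesting the relevant part of sock_map_link_no_progs
would read roughly as below. This is just a sketch of the suggestion; the
psock lookup and the surrounding error handling are exactly as in the hunk
further down, only the allocation branch is reshuffled:

	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock)
			return -ENOMEM;
	}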

Patch

diff --git a/include/net/tcp.h b/include/net/tcp.h
index 07f947cc80e6..ccf39d80b695 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2196,7 +2196,6 @@  struct sk_msg;
 struct sk_psock;
 
 int tcp_bpf_init(struct sock *sk);
-void tcp_bpf_reinit(struct sock *sk);
 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
 			  int flags);
 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index cb8f740f7949..bca560a79b5b 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -145,8 +145,8 @@  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 			 struct sock *sk)
 {
 	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
-	bool skb_progs, sk_psock_is_new = false;
 	struct sk_psock *psock;
+	bool skb_progs;
 	int ret;
 
 	skb_verdict = READ_ONCE(progs->skb_verdict);
@@ -191,18 +191,14 @@  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 			ret = -ENOMEM;
 			goto out_progs;
 		}
-		sk_psock_is_new = true;
 	}
 
 	if (msg_parser)
 		psock_set_prog(&psock->progs.msg_parser, msg_parser);
-	if (sk_psock_is_new) {
-		ret = tcp_bpf_init(sk);
-		if (ret < 0)
-			goto out_drop;
-	} else {
-		tcp_bpf_reinit(sk);
-	}
+
+	ret = tcp_bpf_init(sk);
+	if (ret < 0)
+		goto out_drop;
 
 	write_lock_bh(&sk->sk_callback_lock);
 	if (skb_progs && !psock->parser.enabled) {
@@ -239,12 +235,9 @@  static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
 	if (IS_ERR(psock))
 		return PTR_ERR(psock);
 
-	if (psock) {
-		tcp_bpf_reinit(sk);
-		return 0;
-	}
+	if (!psock)
+		psock = sk_psock_init(sk, map->numa_node);
 
-	psock = sk_psock_init(sk, map->numa_node);
 	if (!psock)
 		return -ENOMEM;
 
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 3327afa05c3d..ed8a8f3c9afe 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -629,14 +629,6 @@  static int __init tcp_bpf_v4_build_proto(void)
 }
 core_initcall(tcp_bpf_v4_build_proto);
 
-static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
-	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
-
-	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
-}
-
 static int tcp_bpf_assert_proto_ops(struct proto *ops)
 {
 	/* In order to avoid retpoline, we make assumptions when we call
@@ -648,34 +640,44 @@  static int tcp_bpf_assert_proto_ops(struct proto *ops)
 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-void tcp_bpf_reinit(struct sock *sk)
+static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
-	struct sk_psock *psock;
+	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
-	sock_owned_by_me(sk);
+	if (!psock->sk_proto) {
+		struct proto *ops = READ_ONCE(sk->sk_prot);
 
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	tcp_bpf_update_sk_prot(sk, psock);
-	rcu_read_unlock();
+		if (tcp_bpf_assert_proto_ops(ops))
+			return ERR_PTR(-EINVAL);
+
+		tcp_bpf_check_v6_needs_rebuild(sk, ops);
+	}
+
+	return &tcp_bpf_prots[family][config];
 }
 
 int tcp_bpf_init(struct sock *sk)
 {
-	struct proto *ops = READ_ONCE(sk->sk_prot);
 	struct sk_psock *psock;
+	struct proto *prot;
 
 	sock_owned_by_me(sk);
 
 	rcu_read_lock();
 	psock = sk_psock(sk);
-	if (unlikely(!psock || psock->sk_proto ||
-		     tcp_bpf_assert_proto_ops(ops))) {
+	if (unlikely(!psock)) {
 		rcu_read_unlock();
 		return -EINVAL;
 	}
-	tcp_bpf_check_v6_needs_rebuild(sk, ops);
-	tcp_bpf_update_sk_prot(sk, psock);
+
+	prot = tcp_bpf_get_proto(sk, psock);
+	if (IS_ERR(prot)) {
+		rcu_read_unlock();
+		return PTR_ERR(prot);
+	}
+
+	sk_psock_update_proto(sk, psock, prot);
 	rcu_read_unlock();
 	return 0;
 }