diff mbox series

[bpf-next,4/7] bpf: sockmap: allow UDP sockets

Message ID 20200225135636.5768-5-lmb@cloudflare.com
State Changes Requested
Delegated to: BPF Maintainers
Headers show
Series bpf: sockmap, sockhash: support storing UDP sockets | expand

Commit Message

Lorenz Bauer Feb. 25, 2020, 1:56 p.m. UTC
Add basic psock hooks for UDP sockets. This allows adding and
removing sockets, as well as automatic removal on unhash and close.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
---
 MAINTAINERS         |  1 +
 include/linux/udp.h |  4 ++++
 net/core/sock_map.c | 47 +++++++++++++++++++++++-----------------
 net/ipv4/Makefile   |  1 +
 net/ipv4/udp_bpf.c  | 53 +++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 86 insertions(+), 20 deletions(-)
 create mode 100644 net/ipv4/udp_bpf.c

Comments

Martin KaFai Lau Feb. 26, 2020, 6:47 p.m. UTC | #1
On Tue, Feb 25, 2020 at 01:56:33PM +0000, Lorenz Bauer wrote:
> Add basic psock hooks for UDP sockets. This allows adding and
> removing sockets, as well as automatic removal on unhash and close.
> 
> Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
> ---
>  MAINTAINERS         |  1 +
>  include/linux/udp.h |  4 ++++
>  net/core/sock_map.c | 47 +++++++++++++++++++++++-----------------
>  net/ipv4/Makefile   |  1 +
>  net/ipv4/udp_bpf.c  | 53 +++++++++++++++++++++++++++++++++++++++++++++
>  5 files changed, 86 insertions(+), 20 deletions(-)
>  create mode 100644 net/ipv4/udp_bpf.c
> 
> diff --git a/MAINTAINERS b/MAINTAINERS
> index 2af5fa73155e..495ba52038ad 100644
> --- a/MAINTAINERS
> +++ b/MAINTAINERS
> @@ -9358,6 +9358,7 @@ F:	include/linux/skmsg.h
>  F:	net/core/skmsg.c
>  F:	net/core/sock_map.c
>  F:	net/ipv4/tcp_bpf.c
> +F:	net/ipv4/udp_bpf.c
>  
>  LANTIQ / INTEL Ethernet drivers
>  M:	Hauke Mehrtens <hauke@hauke-m.de>
> diff --git a/include/linux/udp.h b/include/linux/udp.h
> index aa84597bdc33..d90d8fd5f73d 100644
> --- a/include/linux/udp.h
> +++ b/include/linux/udp.h
> @@ -143,4 +143,8 @@ static inline bool udp_unexpected_gso(struct sock *sk, struct sk_buff *skb)
>  
>  #define IS_UDPLITE(__sk) (__sk->sk_protocol == IPPROTO_UDPLITE)
>  
> +#if defined(CONFIG_NET_SOCK_MSG)
> +int udp_bpf_init(struct sock *sk);
> +#endif
> +
>  #endif	/* _LINUX_UDP_H */
> diff --git a/net/core/sock_map.c b/net/core/sock_map.c
> index c84cc9fc7f6b..f998192c425f 100644
> --- a/net/core/sock_map.c
> +++ b/net/core/sock_map.c
> @@ -153,7 +153,7 @@ static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
>  	rcu_read_lock();
>  	psock = sk_psock(sk);
>  	if (psock) {
> -		if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
> +		if (sk->sk_prot->close != sock_map_close) {
>  			psock = ERR_PTR(-EBUSY);
>  			goto out;
>  		}
> @@ -166,6 +166,14 @@ static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
>  	return psock;
>  }
>  
> +static int sock_map_init_hooks(struct sock *sk)
> +{
> +	if (sk->sk_type == SOCK_DGRAM)
> +		return udp_bpf_init(sk);
> +	else
> +		return tcp_bpf_init(sk);
> +}
> +
>  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
>  			 struct sock *sk)
>  {
> @@ -220,7 +228,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
>  	if (msg_parser)
>  		psock_set_prog(&psock->progs.msg_parser, msg_parser);
>  
> -	ret = tcp_bpf_init(sk);
> +	ret = sock_map_init_hooks(sk);
>  	if (ret < 0)
>  		goto out_drop;
>  
> @@ -267,7 +275,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
>  		return -ENOMEM;
>  
>  init:
> -	ret = tcp_bpf_init(sk);
> +	ret = sock_map_init_hooks(sk);
>  	if (ret < 0)
>  		sk_psock_put(sk, psock);
>  	return ret;
> @@ -394,9 +402,14 @@ static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
>  	return 0;
>  }
>  
> +static bool sock_map_sk_is_tcp(const struct sock *sk)
> +{
> +	return sk->sk_type == SOCK_STREAM && sk->sk_protocol == IPPROTO_TCP;
> +}
> +
>  static bool sock_map_redirect_allowed(const struct sock *sk)
>  {
> -	return sk->sk_state != TCP_LISTEN;
> +	return sock_map_sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
>  }
>  
>  static int sock_map_update_common(struct bpf_map *map, u32 idx,
> @@ -466,15 +479,17 @@ static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
>  	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
>  }
>  
> +static bool sock_map_sk_is_udp(const struct sock *sk)
> +{
> +	return sk->sk_type == SOCK_DGRAM && sk->sk_protocol == IPPROTO_UDP;
> +}
> +
>  static bool sock_map_sk_is_suitable(const struct sock *sk)
>  {
> -	return sk->sk_type == SOCK_STREAM &&
> -	       sk->sk_protocol == IPPROTO_TCP;
> -}
> +	const int tcp_flags = TCPF_ESTABLISHED | TCPF_LISTEN | TCPF_SYN_RECV;
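Hmm... I thought this patch only adds UDP support.  However, if I read
it correctly, the set of allowed TCP states also changes: tcp_flags now
includes TCPF_SYN_RECV, whereas the old sock_map_sk_state_allowed() only
accepted TCPF_ESTABLISHED | TCPF_LISTEN.
Please elaborate on this in the commit message.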

>  
> -static bool sock_map_sk_state_allowed(const struct sock *sk)
> -{
> -	return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
> +	return (sock_map_sk_is_udp(sk) && sk_hashed(sk)) ||
> +	       (sock_map_sk_is_tcp(sk) && (1 << sk->sk_state) & tcp_flags);
>  }
>  
>  static int sock_map_update_elem(struct bpf_map *map, void *key,
diff mbox series

Patch

diff --git a/MAINTAINERS b/MAINTAINERS
index 2af5fa73155e..495ba52038ad 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -9358,6 +9358,7 @@  F:	include/linux/skmsg.h
 F:	net/core/skmsg.c
 F:	net/core/sock_map.c
 F:	net/ipv4/tcp_bpf.c
+F:	net/ipv4/udp_bpf.c
 
 LANTIQ / INTEL Ethernet drivers
 M:	Hauke Mehrtens <hauke@hauke-m.de>
diff --git a/include/linux/udp.h b/include/linux/udp.h
index aa84597bdc33..d90d8fd5f73d 100644
--- a/include/linux/udp.h
+++ b/include/linux/udp.h
@@ -143,4 +143,8 @@  static inline bool udp_unexpected_gso(struct sock *sk, struct sk_buff *skb)
 
 #define IS_UDPLITE(__sk) (__sk->sk_protocol == IPPROTO_UDPLITE)
 
+#if defined(CONFIG_NET_SOCK_MSG)
+int udp_bpf_init(struct sock *sk);
+#endif
+
 #endif	/* _LINUX_UDP_H */
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index c84cc9fc7f6b..f998192c425f 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -153,7 +153,7 @@  static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
 	rcu_read_lock();
 	psock = sk_psock(sk);
 	if (psock) {
-		if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
+		if (sk->sk_prot->close != sock_map_close) {
 			psock = ERR_PTR(-EBUSY);
 			goto out;
 		}
@@ -166,6 +166,14 @@  static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
 	return psock;
 }
 
+static int sock_map_init_hooks(struct sock *sk)
+{
+	if (sk->sk_type == SOCK_DGRAM)
+		return udp_bpf_init(sk);
+	else
+		return tcp_bpf_init(sk);
+}
+
 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 			 struct sock *sk)
 {
@@ -220,7 +228,7 @@  static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 	if (msg_parser)
 		psock_set_prog(&psock->progs.msg_parser, msg_parser);
 
-	ret = tcp_bpf_init(sk);
+	ret = sock_map_init_hooks(sk);
 	if (ret < 0)
 		goto out_drop;
 
@@ -267,7 +275,7 @@  static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
 		return -ENOMEM;
 
 init:
-	ret = tcp_bpf_init(sk);
+	ret = sock_map_init_hooks(sk);
 	if (ret < 0)
 		sk_psock_put(sk, psock);
 	return ret;
@@ -394,9 +402,14 @@  static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
 	return 0;
 }
 
+static bool sock_map_sk_is_tcp(const struct sock *sk)
+{
+	return sk->sk_type == SOCK_STREAM && sk->sk_protocol == IPPROTO_TCP;
+}
+
 static bool sock_map_redirect_allowed(const struct sock *sk)
 {
-	return sk->sk_state != TCP_LISTEN;
+	return sock_map_sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
 }
 
 static int sock_map_update_common(struct bpf_map *map, u32 idx,
@@ -466,15 +479,17 @@  static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
 	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
 }
 
+static bool sock_map_sk_is_udp(const struct sock *sk)
+{
+	return sk->sk_type == SOCK_DGRAM && sk->sk_protocol == IPPROTO_UDP;
+}
+
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-	return sk->sk_type == SOCK_STREAM &&
-	       sk->sk_protocol == IPPROTO_TCP;
-}
+	const int tcp_flags = TCPF_ESTABLISHED | TCPF_LISTEN | TCPF_SYN_RECV;
 
-static bool sock_map_sk_state_allowed(const struct sock *sk)
-{
-	return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
+	return (sock_map_sk_is_udp(sk) && sk_hashed(sk)) ||
+	       (sock_map_sk_is_tcp(sk) && (1 << sk->sk_state) & tcp_flags);
 }
 
 static int sock_map_update_elem(struct bpf_map *map, void *key,
@@ -501,13 +516,9 @@  static int sock_map_update_elem(struct bpf_map *map, void *key,
 		ret = -EINVAL;
 		goto out;
 	}
-	if (!sock_map_sk_is_suitable(sk)) {
-		ret = -EOPNOTSUPP;
-		goto out;
-	}
 
 	sock_map_sk_acquire(sk);
-	if (!sock_map_sk_state_allowed(sk))
+	if (!sock_map_sk_is_suitable(sk))
 		ret = -EOPNOTSUPP;
 	else
 		ret = sock_map_update_common(map, idx, sk, flags);
@@ -849,13 +860,9 @@  static int sock_hash_update_elem(struct bpf_map *map, void *key,
 		ret = -EINVAL;
 		goto out;
 	}
-	if (!sock_map_sk_is_suitable(sk)) {
-		ret = -EOPNOTSUPP;
-		goto out;
-	}
 
 	sock_map_sk_acquire(sk);
-	if (!sock_map_sk_state_allowed(sk))
+	if (!sock_map_sk_is_suitable(sk))
 		ret = -EOPNOTSUPP;
 	else
 		ret = sock_hash_update_common(map, key, sk, flags);
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index 9d97bace13c8..48cc05d365e4 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -61,6 +61,7 @@  obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
 obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
 obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
 obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o
+obj-$(CONFIG_NET_SOCK_MSG) += udp_bpf.o
 obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
 
 obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
new file mode 100644
index 000000000000..e085a0648a94
--- /dev/null
+++ b/net/ipv4/udp_bpf.c
@@ -0,0 +1,53 @@ 
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
+
+#include <linux/bpf.h>
+#include <linux/filter.h>
+#include <linux/init.h>
+#include <linux/skmsg.h>
+#include <linux/wait.h>
+#include <net/udp.h>
+
+#include <net/inet_common.h>
+
+static int udp_bpf_rebuild_protos(struct proto *prot, struct proto *base)
+{
+	*prot        = *base;
+	prot->unhash = sock_map_unhash;
+	prot->close  = sock_map_close;
+	return 0;
+}
+
+static struct proto *udp_bpf_choose_proto(struct proto prot[],
+					  struct sk_psock *psock)
+{
+	return prot;
+}
+
+static struct proto udpv4_proto;
+static struct proto udpv6_proto;
+
+static struct sk_psock_hooks udp_psock_proto __read_mostly = {
+	.ipv4 = &udpv4_proto,
+	.ipv6 = &udpv6_proto,
+	.rebuild_proto = udp_bpf_rebuild_protos,
+	.choose_proto = udp_bpf_choose_proto,
+};
+
+static int __init udp_bpf_init_psock_hooks(void)
+{
+	return sk_psock_hooks_init(&udp_psock_proto, &udp_prot);
+}
+core_initcall(udp_bpf_init_psock_hooks);
+
+int udp_bpf_init(struct sock *sk)
+{
+	int ret;
+
+	sock_owned_by_me(sk);
+
+	rcu_read_lock();
+	ret = sk_psock_hooks_install(&udp_psock_proto, sk);
+	rcu_read_unlock();
+	return ret;
+}