[bpf-next,5/9] bpf: Enable BPF_PROG_TYPE_SK_REUSEPORT bpf prog in reuseport selection

Message ID 20180808080126.3014031-1-kafai@fb.com
State Accepted
Delegated to: BPF Maintainers
Headers show
Series
  • Introduce BPF_MAP_TYPE_REUSEPORT_SOCKARRAY and BPF_PROG_TYPE_SK_REUSEPORT
Related show

Commit Message

Martin KaFai Lau Aug. 8, 2018, 8:01 a.m.
This patch allows a BPF_PROG_TYPE_SK_REUSEPORT bpf prog to select a
SO_REUSEPORT sk from a BPF_MAP_TYPE_REUSEPORT_ARRAY introduced in
the earlier patch.  "bpf_run_sk_reuseport()" will return -ECONNREFUSED
when the BPF_PROG_TYPE_SK_REUSEPORT prog returns SK_DROP.
The callers, in inet[6]_hashtable.c and ipv[46]/udp.c, are modified to
handle this case and return NULL immediately instead of continuing the
sk search from its hashtable.

It re-uses the existing SO_ATTACH_REUSEPORT_EBPF setsockopt to attach
BPF_PROG_TYPE_SK_REUSEPORT.  The "sk_reuseport_attach_bpf()" will check
if the attaching bpf prog is in the new SK_REUSEPORT or the existing
SOCKET_FILTER type and then check different things accordingly.

One level of "__reuseport_attach_prog()" call is removed.  The
"sk_unhashed() && ..." and "sk->sk_reuseport_cb" tests are pushed
back to "reuseport_attach_prog()" in sock_reuseport.c.  sock_reuseport.c
seems to have more knowledge on those test requirements than filter.c.
In "reuseport_attach_prog()", after new_prog is attached to reuse->prog,
the old_prog (if any) is also directly freed instead of returning the
old_prog to the caller and asking the caller to free.

The sysctl_optmem_max check is moved back to the
"sk_reuseport_attach_filter()" and "sk_reuseport_attach_bpf()".
As of other bpf prog types, the new BPF_PROG_TYPE_SK_REUSEPORT is only
bounded by the usual "bpf_prog_charge_memlock()" during load time
instead of bounded by both bpf_prog_charge_memlock and sysctl_optmem_max.

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Acked-by: Alexei Starovoitov <ast@kernel.org>
---
 include/linux/filter.h       |  1 +
 include/net/sock_reuseport.h |  3 +-
 net/core/filter.c            | 87 +++++++++++++++++++++---------------
 net/core/sock_reuseport.c    | 36 ++++++++++-----
 net/ipv4/inet_hashtables.c   | 14 +++---
 net/ipv4/udp.c               |  4 ++
 net/ipv6/inet6_hashtables.c  | 14 +++---
 net/ipv6/udp.c               |  4 ++
 8 files changed, 106 insertions(+), 57 deletions(-)

Patch

diff --git a/include/linux/filter.h b/include/linux/filter.h
index 29577c6f3289..e44c531f2002 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -739,6 +739,7 @@  int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk);
 int sk_attach_bpf(u32 ufd, struct sock *sk);
 int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk);
 int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk);
+void sk_reuseport_prog_free(struct bpf_prog *prog);
 int sk_detach_filter(struct sock *sk);
 int sk_get_filter(struct sock *sk, struct sock_filter __user *filter,
 		  unsigned int len);
diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
index 73b569556be6..8a5f70c7cdf2 100644
--- a/include/net/sock_reuseport.h
+++ b/include/net/sock_reuseport.h
@@ -34,8 +34,7 @@  extern struct sock *reuseport_select_sock(struct sock *sk,
 					  u32 hash,
 					  struct sk_buff *skb,
 					  int hdr_len);
-extern struct bpf_prog *reuseport_attach_prog(struct sock *sk,
-					      struct bpf_prog *prog);
+extern int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog);
 int reuseport_get_id(struct sock_reuseport *reuse);
 
 #endif  /* _SOCK_REUSEPORT_H */
diff --git a/net/core/filter.c b/net/core/filter.c
index f4c928709756..315db0a478f0 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1453,30 +1453,6 @@  static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk)
 	return 0;
 }
 
-static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk)
-{
-	struct bpf_prog *old_prog;
-	int err;
-
-	if (bpf_prog_size(prog->len) > sysctl_optmem_max)
-		return -ENOMEM;
-
-	if (sk_unhashed(sk) && sk->sk_reuseport) {
-		err = reuseport_alloc(sk, false);
-		if (err)
-			return err;
-	} else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
-		/* The socket wasn't bound with SO_REUSEPORT */
-		return -EINVAL;
-	}
-
-	old_prog = reuseport_attach_prog(sk, prog);
-	if (old_prog)
-		bpf_prog_destroy(old_prog);
-
-	return 0;
-}
-
 static
 struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk)
 {
@@ -1550,13 +1526,15 @@  int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk)
 	if (IS_ERR(prog))
 		return PTR_ERR(prog);
 
-	err = __reuseport_attach_prog(prog, sk);
-	if (err < 0) {
+	if (bpf_prog_size(prog->len) > sysctl_optmem_max)
+		err = -ENOMEM;
+	else
+		err = reuseport_attach_prog(sk, prog);
+
+	if (err)
 		__bpf_prog_release(prog);
-		return err;
-	}
 
-	return 0;
+	return err;
 }
 
 static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk)
@@ -1586,19 +1564,58 @@  int sk_attach_bpf(u32 ufd, struct sock *sk)
 
 int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk)
 {
-	struct bpf_prog *prog = __get_bpf(ufd, sk);
+	struct bpf_prog *prog;
 	int err;
 
+	if (sock_flag(sk, SOCK_FILTER_LOCKED))
+		return -EPERM;
+
+	prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER);
+	if (IS_ERR(prog) && PTR_ERR(prog) == -EINVAL)
+		prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SK_REUSEPORT);
 	if (IS_ERR(prog))
 		return PTR_ERR(prog);
 
-	err = __reuseport_attach_prog(prog, sk);
-	if (err < 0) {
-		bpf_prog_put(prog);
-		return err;
+	if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) {
+		/* Like other non BPF_PROG_TYPE_SOCKET_FILTER
+		 * bpf prog (e.g. sockmap).  It depends on the
+		 * limitation imposed by bpf_prog_load().
+		 * Hence, sysctl_optmem_max is not checked.
+		 */
+		if ((sk->sk_type != SOCK_STREAM &&
+		     sk->sk_type != SOCK_DGRAM) ||
+		    (sk->sk_protocol != IPPROTO_UDP &&
+		     sk->sk_protocol != IPPROTO_TCP) ||
+		    (sk->sk_family != AF_INET &&
+		     sk->sk_family != AF_INET6)) {
+			err = -ENOTSUPP;
+			goto err_prog_put;
+		}
+	} else {
+		/* BPF_PROG_TYPE_SOCKET_FILTER */
+		if (bpf_prog_size(prog->len) > sysctl_optmem_max) {
+			err = -ENOMEM;
+			goto err_prog_put;
+		}
 	}
 
-	return 0;
+	err = reuseport_attach_prog(sk, prog);
+err_prog_put:
+	if (err)
+		bpf_prog_put(prog);
+
+	return err;
+}
+
+void sk_reuseport_prog_free(struct bpf_prog *prog)
+{
+	if (!prog)
+		return;
+
+	if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
+		bpf_prog_put(prog);
+	else
+		bpf_prog_destroy(prog);
 }
 
 struct bpf_scratchpad {
diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
index d260167f5f77..ba5cba56f574 100644
--- a/net/core/sock_reuseport.c
+++ b/net/core/sock_reuseport.c
@@ -9,6 +9,7 @@ 
 #include <net/sock_reuseport.h>
 #include <linux/bpf.h>
 #include <linux/idr.h>
+#include <linux/filter.h>
 #include <linux/rcupdate.h>
 
 #define INIT_SOCKS 128
@@ -133,8 +134,7 @@  static void reuseport_free_rcu(struct rcu_head *head)
 	struct sock_reuseport *reuse;
 
 	reuse = container_of(head, struct sock_reuseport, rcu);
-	if (reuse->prog)
-		bpf_prog_destroy(reuse->prog);
+	sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1));
 	if (reuse->reuseport_id)
 		ida_simple_remove(&reuseport_ida, reuse->reuseport_id);
 	kfree(reuse);
@@ -219,9 +219,9 @@  void reuseport_detach_sock(struct sock *sk)
 }
 EXPORT_SYMBOL(reuseport_detach_sock);
 
-static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks,
-			    struct bpf_prog *prog, struct sk_buff *skb,
-			    int hdr_len)
+static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks,
+				   struct bpf_prog *prog, struct sk_buff *skb,
+				   int hdr_len)
 {
 	struct sk_buff *nskb = NULL;
 	u32 index;
@@ -282,9 +282,15 @@  struct sock *reuseport_select_sock(struct sock *sk,
 		/* paired with smp_wmb() in reuseport_add_sock() */
 		smp_rmb();
 
-		if (prog && skb)
-			sk2 = run_bpf(reuse, socks, prog, skb, hdr_len);
+		if (!prog || !skb)
+			goto select_by_hash;
+
+		if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
+			sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, hash);
+		else
+			sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len);
 
+select_by_hash:
 		/* no bpf or invalid bpf result: fall back to hash usage */
 		if (!sk2)
 			sk2 = reuse->socks[reciprocal_scale(hash, socks)];
@@ -296,12 +302,21 @@  struct sock *reuseport_select_sock(struct sock *sk,
 }
 EXPORT_SYMBOL(reuseport_select_sock);
 
-struct bpf_prog *
-reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
+int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
 {
 	struct sock_reuseport *reuse;
 	struct bpf_prog *old_prog;
 
+	if (sk_unhashed(sk) && sk->sk_reuseport) {
+		int err = reuseport_alloc(sk, false);
+
+		if (err)
+			return err;
+	} else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
+		/* The socket wasn't bound with SO_REUSEPORT */
+		return -EINVAL;
+	}
+
 	spin_lock_bh(&reuseport_lock);
 	reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
 					  lockdep_is_held(&reuseport_lock));
@@ -310,6 +325,7 @@  reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
 	rcu_assign_pointer(reuse->prog, prog);
 	spin_unlock_bh(&reuseport_lock);
 
-	return old_prog;
+	sk_reuseport_prog_free(old_prog);
+	return 0;
 }
 EXPORT_SYMBOL(reuseport_attach_prog);
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 370e24463fb7..f5c9ef2586de 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -328,7 +328,7 @@  struct sock *__inet_lookup_listener(struct net *net,
 				    saddr, sport, daddr, hnum,
 				    dif, sdif);
 	if (result)
-		return result;
+		goto done;
 
 	/* Lookup lhash2 with INADDR_ANY */
 
@@ -337,9 +337,10 @@  struct sock *__inet_lookup_listener(struct net *net,
 	if (ilb2->count > ilb->count)
 		goto port_lookup;
 
-	return inet_lhash2_lookup(net, ilb2, skb, doff,
-				  saddr, sport, daddr, hnum,
-				  dif, sdif);
+	result = inet_lhash2_lookup(net, ilb2, skb, doff,
+				    saddr, sport, daddr, hnum,
+				    dif, sdif);
+	goto done;
 
 port_lookup:
 	sk_for_each_rcu(sk, &ilb->head) {
@@ -352,12 +353,15 @@  struct sock *__inet_lookup_listener(struct net *net,
 				result = reuseport_select_sock(sk, phash,
 							       skb, doff);
 				if (result)
-					return result;
+					goto done;
 			}
 			result = sk;
 			hiscore = score;
 		}
 	}
+done:
+	if (unlikely(IS_ERR(result)))
+		return NULL;
 	return result;
 }
 EXPORT_SYMBOL_GPL(__inet_lookup_listener);
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 038dd7909051..f4e35b2ff8b8 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -499,6 +499,8 @@  struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 						  daddr, hnum, dif, sdif,
 						  exact_dif, hslot2, skb);
 		}
+		if (unlikely(IS_ERR(result)))
+			return NULL;
 		return result;
 	}
 begin:
@@ -513,6 +515,8 @@  struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 						   saddr, sport);
 				result = reuseport_select_sock(sk, hash, skb,
 							sizeof(struct udphdr));
+				if (unlikely(IS_ERR(result)))
+					return NULL;
 				if (result)
 					return result;
 			}
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index 595ad408dba0..3d7c7460a0c5 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -191,7 +191,7 @@  struct sock *inet6_lookup_listener(struct net *net,
 				     saddr, sport, daddr, hnum,
 				     dif, sdif);
 	if (result)
-		return result;
+		goto done;
 
 	/* Lookup lhash2 with in6addr_any */
 
@@ -200,9 +200,10 @@  struct sock *inet6_lookup_listener(struct net *net,
 	if (ilb2->count > ilb->count)
 		goto port_lookup;
 
-	return inet6_lhash2_lookup(net, ilb2, skb, doff,
-				   saddr, sport, daddr, hnum,
-				   dif, sdif);
+	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
+				     saddr, sport, daddr, hnum,
+				     dif, sdif);
+	goto done;
 
 port_lookup:
 	sk_for_each(sk, &ilb->head) {
@@ -214,12 +215,15 @@  struct sock *inet6_lookup_listener(struct net *net,
 				result = reuseport_select_sock(sk, phash,
 							       skb, doff);
 				if (result)
-					return result;
+					goto done;
 			}
 			result = sk;
 			hiscore = score;
 		}
 	}
+done:
+	if (unlikely(IS_ERR(result)))
+		return NULL;
 	return result;
 }
 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index f6b96956a8ed..83f4c77c79d8 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -235,6 +235,8 @@  struct sock *__udp6_lib_lookup(struct net *net,
 						  exact_dif, hslot2,
 						  skb);
 		}
+		if (unlikely(IS_ERR(result)))
+			return NULL;
 		return result;
 	}
 begin:
@@ -249,6 +251,8 @@  struct sock *__udp6_lib_lookup(struct net *net,
 						    saddr, sport);
 				result = reuseport_select_sock(sk, hash, skb,
 							sizeof(struct udphdr));
+				if (unlikely(IS_ERR(result)))
+					return NULL;
 				if (result)
 					return result;
 			}