diff mbox series

[RFC,bpf-next,3/7] ipv6: Run inet_lookup bpf program on socket lookup

Message ID 20190618130050.8344-4-jakub@cloudflare.com
State RFC
Delegated to: BPF Maintainers
Headers show
Series Programming socket lookup with BPF | expand

Commit Message

Jakub Sitnicki June 18, 2019, 1 p.m. UTC
Following the ipv4 changes, run a BPF program attached to netns in context
of which we're doing the socket lookup so that it can rewrite the
destination IP and port we use as keys for the lookup.

The program is called before the listening socket lookup for TCP, and
before connected or not-connected socket lookup for UDP.

Suggested-by: Marek Majkowski <marek@cloudflare.com>
Signed-off-by: Jakub Sitnicki <jakub@cloudflare.com>
---
 include/net/inet6_hashtables.h | 39 ++++++++++++++++++++++++++++++++++
 net/ipv6/inet6_hashtables.c    | 11 ++++++----
 net/ipv6/udp.c                 |  6 ++++--
 3 files changed, 50 insertions(+), 6 deletions(-)
diff mbox series

Patch

diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
index 9db98af46985..ab06961d33a9 100644
--- a/include/net/inet6_hashtables.h
+++ b/include/net/inet6_hashtables.h
@@ -108,6 +108,45 @@  struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
 			  const int dif);
 
 int inet6_hash(struct sock *sk);
+
+#ifdef CONFIG_BPF_SYSCALL
+static inline void inet6_lookup_run_bpf(struct net *net,
+					const struct in6_addr *saddr,
+					const __be16 sport,
+					struct in6_addr *daddr,
+					unsigned short *hnum)
+{
+	struct bpf_inet_lookup_kern ctx = {
+		.family	= AF_INET6,
+		.saddr6	= *saddr,
+		.sport	= sport,
+		.daddr6	= *daddr,
+		.hnum	= *hnum,
+	};
+	struct bpf_prog *prog;
+	int ret = BPF_OK;
+
+	rcu_read_lock();
+	prog = rcu_dereference(net->inet_lookup_prog);
+	if (prog)
+		ret = BPF_PROG_RUN(prog, &ctx);
+	rcu_read_unlock();
+
+	if (ret == BPF_REDIRECT) {
+		*daddr = ctx.daddr6;
+		*hnum = ctx.hnum;
+	}
+}
+#else
+static inline void inet6_lookup_run_bpf(struct net *net,
+					const struct in6_addr *saddr,
+					const __be16 sport,
+					struct in6_addr *daddr,
+					unsigned short *hnum)
+{
+}
+#endif /* CONFIG_BPF_SYSCALL */
+
 #endif /* IS_ENABLED(CONFIG_IPV6) */
 
 #define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif, __sdif) \
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index f3515ebe9b3a..280a9b8bf914 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -158,24 +158,27 @@  struct sock *inet6_lookup_listener(struct net *net,
 		const unsigned short hnum, const int dif, const int sdif)
 {
 	struct inet_listen_hashbucket *ilb2;
+	struct in6_addr daddr2 = *daddr;
+	unsigned short hnum2 = hnum;
 	struct sock *result = NULL;
 	unsigned int hash2;
 
-	hash2 = ipv6_portaddr_hash(net, daddr, hnum);
+	inet6_lookup_run_bpf(net, saddr, sport, &daddr2, &hnum2);
+	hash2 = ipv6_portaddr_hash(net, &daddr2, hnum2);
 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
 
 	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
-				     saddr, sport, daddr, hnum,
+				     saddr, sport, &daddr2, hnum2,
 				     dif, sdif);
 	if (result)
 		goto done;
 
 	/* Lookup lhash2 with in6addr_any */
-	hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
+	hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum2);
 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
 
 	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
-				     saddr, sport, &in6addr_any, hnum,
+				     saddr, sport, &in6addr_any, hnum2,
 				     dif, sdif);
 done:
 	if (unlikely(IS_ERR(result)))
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 07fa579dfb96..6c0030ba83c6 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -196,17 +196,19 @@  struct sock *__udp6_lib_lookup(struct net *net,
 			       struct sk_buff *skb)
 {
 	unsigned short hnum = ntohs(dport);
+	struct in6_addr daddr2 = *daddr;
 	unsigned int hash2, slot2;
 	struct udp_hslot *hslot2;
 	struct sock *result;
 	bool exact_dif = udp6_lib_exact_dif_match(net, skb);
 
-	hash2 = ipv6_portaddr_hash(net, daddr, hnum);
+	inet6_lookup_run_bpf(net, saddr, sport, &daddr2, &hnum);
+	hash2 = ipv6_portaddr_hash(net, &daddr2, hnum);
 	slot2 = hash2 & udptable->mask;
 	hslot2 = &udptable->hash2[slot2];
 
 	result = udp6_lib_lookup2(net, saddr, sport,
-				  daddr, hnum, dif, sdif, exact_dif,
+				  &daddr2, hnum, dif, sdif, exact_dif,
 				  hslot2, skb);
 	if (!result) {
 		hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);