diff mbox series

[RFCv2,bpf-next,08/12] udp: Run inet_lookup bpf program on socket lookup

Message ID 20190828072250.29828-9-jakub@cloudflare.com
State RFC
Delegated to: BPF Maintainers
Headers show
Series Programming socket lookup with BPF | expand

Commit Message

Jakub Sitnicki Aug. 28, 2019, 7:22 a.m. UTC
Following the TCP socket lookup changes, allow selecting the receiving
socket from BPF before searching for bound socket by destination address
and port.

As connected and bound but non-connected socket lookup currently happens in
one step, we split the lookup in two phases to run BPF only after a lookup
for a connected socket was a miss. Hence making sure connected UDP sockets
continue to work as expected in presence of a BPF inet_lookup program.

Suggested-by: Marek Majkowski <marek@cloudflare.com>
Reviewed-by: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Jakub Sitnicki <jakub@cloudflare.com>
---
 net/ipv4/udp.c | 44 ++++++++++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 12 deletions(-)
diff mbox series

Patch

diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 9fffe9e9eec6..3a4b98f89249 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -353,7 +353,7 @@  int udp_v4_get_port(struct sock *sk, unsigned short snum)
 static int compute_score(struct sock *sk, struct net *net,
 			 __be32 saddr, __be16 sport,
 			 __be32 daddr, unsigned short hnum,
-			 int dif, int sdif)
+			 int dif, int sdif, unsigned char state)
 {
 	int score;
 	struct inet_sock *inet;
@@ -364,6 +364,9 @@  static int compute_score(struct sock *sk, struct net *net,
 	    ipv6_only_sock(sk))
 		return -1;
 
+	if (state && sk->sk_state != state)
+		return -1;
+
 	if (sk->sk_rcv_saddr != daddr)
 		return -1;
 
@@ -411,7 +414,8 @@  static struct sock *udp4_lib_lookup2(struct net *net,
 				     __be32 daddr, unsigned int hnum,
 				     int dif, int sdif,
 				     struct udp_hslot *hslot2,
-				     struct sk_buff *skb)
+				     struct sk_buff *skb,
+				     unsigned char state)
 {
 	struct sock *sk, *result;
 	int score, badness;
@@ -421,7 +425,7 @@  static struct sock *udp4_lib_lookup2(struct net *net,
 	badness = 0;
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, sdif);
+				      daddr, hnum, dif, sdif, state);
 		if (score > badness) {
 			if (sk->sk_reuseport) {
 				hash = udp_ehashfn(net, daddr, hnum,
@@ -454,18 +458,34 @@  struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 	slot2 = hash2 & udptable->mask;
 	hslot2 = &udptable->hash2[slot2];
 
+	/* Lookup connected sockets */
 	result = udp4_lib_lookup2(net, saddr, sport,
 				  daddr, hnum, dif, sdif,
-				  hslot2, skb);
-	if (!result) {
-		hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
-		slot2 = hash2 & udptable->mask;
-		hslot2 = &udptable->hash2[slot2];
+				  hslot2, skb, TCP_ESTABLISHED);
+	if (result)
+		goto done;
 
-		result = udp4_lib_lookup2(net, saddr, sport,
-					  htonl(INADDR_ANY), hnum, dif, sdif,
-					  hslot2, skb);
-	}
+	/* Lookup redirect from BPF */
+	result = inet_lookup_run_bpf(net, udptable->protocol,
+				     saddr, sport, daddr, hnum);
+	if (result)
+		goto done;
+
+	/* Lookup bound sockets */
+	result = udp4_lib_lookup2(net, saddr, sport,
+				  daddr, hnum, dif, sdif,
+				  hslot2, skb, 0);
+	if (result)
+		goto done;
+
+	hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
+	slot2 = hash2 & udptable->mask;
+	hslot2 = &udptable->hash2[slot2];
+
+	result = udp4_lib_lookup2(net, saddr, sport,
+				  htonl(INADDR_ANY), hnum, dif, sdif,
+				  hslot2, skb, 0);
+done:
 	if (IS_ERR(result))
 		return NULL;
 	return result;