@@ -353,7 +353,7 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
static int compute_score(struct sock *sk, struct net *net,
__be32 saddr, __be16 sport,
__be32 daddr, unsigned short hnum,
- int dif, int sdif)
+ int dif, int sdif, unsigned char state)
{
int score;
struct inet_sock *inet;
@@ -364,6 +364,9 @@ static int compute_score(struct sock *sk, struct net *net,
ipv6_only_sock(sk))
return -1;
+ if (state && sk->sk_state != state)
+ return -1;
+
if (sk->sk_rcv_saddr != daddr)
return -1;
@@ -411,7 +414,8 @@ static struct sock *udp4_lib_lookup2(struct net *net,
__be32 daddr, unsigned int hnum,
int dif, int sdif,
struct udp_hslot *hslot2,
- struct sk_buff *skb)
+ struct sk_buff *skb,
+ unsigned char state)
{
struct sock *sk, *result;
int score, badness;
@@ -421,7 +425,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
badness = 0;
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
score = compute_score(sk, net, saddr, sport,
- daddr, hnum, dif, sdif);
+ daddr, hnum, dif, sdif, state);
if (score > badness) {
if (sk->sk_reuseport) {
hash = udp_ehashfn(net, daddr, hnum,
@@ -454,18 +458,34 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
slot2 = hash2 & udptable->mask;
hslot2 = &udptable->hash2[slot2];
+ /* Lookup connected sockets */
result = udp4_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, sdif,
- hslot2, skb);
- if (!result) {
- hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
- slot2 = hash2 & udptable->mask;
- hslot2 = &udptable->hash2[slot2];
+ hslot2, skb, TCP_ESTABLISHED);
+ if (result)
+ goto done;
- result = udp4_lib_lookup2(net, saddr, sport,
- htonl(INADDR_ANY), hnum, dif, sdif,
- hslot2, skb);
- }
+ /* Lookup redirect from BPF */
+ result = inet_lookup_run_bpf(net, udptable->protocol,
+ saddr, sport, daddr, hnum);
+ if (result)
+ goto done;
+
+ /* Lookup bound sockets */
+ result = udp4_lib_lookup2(net, saddr, sport,
+ daddr, hnum, dif, sdif,
+ hslot2, skb, 0);
+ if (result)
+ goto done;
+
+ hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
+ slot2 = hash2 & udptable->mask;
+ hslot2 = &udptable->hash2[slot2];
+
+ result = udp4_lib_lookup2(net, saddr, sport,
+ htonl(INADDR_ANY), hnum, dif, sdif,
+ hslot2, skb, 0);
+done:
if (IS_ERR(result))
return NULL;
return result;