@@ -101,7 +101,7 @@ void udp_v6_rehash(struct sock *sk)
static int compute_score(struct sock *sk, struct net *net,
const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, unsigned short hnum,
- int dif, int sdif)
+ int dif, int sdif, unsigned char state)
{
int score;
struct inet_sock *inet;
@@ -112,6 +112,9 @@ static int compute_score(struct sock *sk, struct net *net,
sk->sk_family != PF_INET6)
return -1;
+ if (state && sk->sk_state != state)
+ return -1;
+
if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
return -1;
@@ -146,7 +149,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, unsigned int hnum,
int dif, int sdif, struct udp_hslot *hslot2,
- struct sk_buff *skb)
+ struct sk_buff *skb, unsigned char state)
{
struct sock *sk, *result;
int score, badness;
@@ -156,7 +159,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
badness = -1;
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
score = compute_score(sk, net, saddr, sport,
- daddr, hnum, dif, sdif);
+ daddr, hnum, dif, sdif, state);
if (score > badness) {
if (sk->sk_reuseport) {
hash = udp6_ehashfn(net, daddr, hnum,
@@ -190,19 +193,34 @@ struct sock *__udp6_lib_lookup(struct net *net,
slot2 = hash2 & udptable->mask;
hslot2 = &udptable->hash2[slot2];
+ /* Lookup connected sockets */
result = udp6_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, sdif,
- hslot2, skb);
- if (!result) {
- hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
- slot2 = hash2 & udptable->mask;
+ hslot2, skb, TCP_ESTABLISHED);
+ if (result)
+ goto done;
- hslot2 = &udptable->hash2[slot2];
+ /* Lookup redirect from BPF */
+ result = inet6_lookup_run_bpf(net, udptable->protocol,
+ saddr, sport, daddr, hnum);
+ if (result)
+ goto done;
- result = udp6_lib_lookup2(net, saddr, sport,
- &in6addr_any, hnum, dif, sdif,
- hslot2, skb);
- }
+ /* Lookup bound sockets */
+ result = udp6_lib_lookup2(net, saddr, sport,
+ daddr, hnum, dif, sdif,
+ hslot2, skb, 0);
+ if (result)
+ goto done;
+
+ hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
+ slot2 = hash2 & udptable->mask;
+ hslot2 = &udptable->hash2[slot2];
+
+ result = udp6_lib_lookup2(net, saddr, sport,
+ &in6addr_any, hnum, dif, sdif,
+ hslot2, skb, 0);
+done:
if (IS_ERR(result))
return NULL;
return result;