diff mbox series

[v1,bpf-next,01/11] tcp: Keep TCP_CLOSE sockets in the reuseport group.

Message ID 20201201144418.35045-2-kuniyu@amazon.co.jp
State Superseded
Series Socket migration for SO_REUSEPORT.

Commit Message

Kuniyuki Iwashima Dec. 1, 2020, 2:44 p.m. UTC
This is a preparatory patch for migrating incoming connections in later
commits. It adds a field (num_closed_socks) to struct sock_reuseport so
that TCP_CLOSE sockets can stay in the reuseport group.

When we close a listening socket, to migrate its connections to another
listener in the same reuseport group, we have to handle two kinds of child
sockets: those that the listening socket holds a reference to, and those
that it does not.

The former are TCP_ESTABLISHED/TCP_SYN_RECV sockets, which are in the
accept queue of their listening socket, so we can pop them out and push
them into another listener's queue at the close() or shutdown() syscall.
The latter are TCP_NEW_SYN_RECV sockets, which are still in the three-way
handshake and not yet in the accept queue, so we cannot access them at
close() or shutdown(). Accordingly, we have to migrate such immature
sockets after their listening socket has been closed.

Currently, if their listening socket has been closed, TCP_NEW_SYN_RECV
sockets are freed when the final ACK is received or when the SYN+ACK is
retransmitted. If we could select a new listener from the same reuseport
group at that time, no connection would be aborted. However, this is
impossible because reuseport_detach_sock() sets sk_reuseport_cb to NULL and
so forbids closed sockets from accessing the reuseport group.

This patch allows TCP_CLOSE sockets to remain in the reuseport group and to
access it while any child socket still references them. The point is that
reuseport_detach_sock() is now called twice, first from inet_unhash() and
then from sk_destruct(). The first call moves the socket to the back of
socks[] and increments num_closed_socks. The second call, made once all
migrated connections have been accepted, removes the socket from socks[],
decrements num_closed_socks, and sets sk_reuseport_cb to NULL.
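
The two-phase layout described above can be modeled in userspace. The
sketch below is not kernel code: struct model_group, MAX_SOCKS, and the
model_* helpers are hypothetical names, plain pointers stand in for
struct sock, and locking and RCU are left out. It assumes only what the
patch shows, namely that listeners occupy the front of socks[] and closed
sockets the back.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_SOCKS 8	/* hypothetical stand-in for reuse->max_socks */

/* Hypothetical model of the socks[] layout; not the kernel struct. */
struct model_group {
	int num_socks;		/* listening entries at the front */
	int num_closed_socks;	/* closed entries at the back */
	const void *socks[MAX_SOCKS];
};

/* Linear search over the open or the closed region; -1 if not found. */
static int model_index(const struct model_group *g, const void *sk, int closed)
{
	int left = closed ? MAX_SOCKS - g->num_closed_socks : 0;
	int right = closed ? MAX_SOCKS : g->num_socks;

	for (; left < right; left++)
		if (g->socks[left] == sk)
			return left;
	return -1;
}

/* First detach: move a listener from the front region to the back. */
static int model_close(struct model_group *g, const void *sk)
{
	int i = model_index(g, sk, 0);

	if (i < 0)
		return -1;

	/* Fill the hole with the last listener, then shrink the region. */
	g->num_socks--;
	g->socks[i] = g->socks[g->num_socks];

	/* Grow the closed region downwards by one slot. */
	g->num_closed_socks++;
	g->socks[MAX_SOCKS - g->num_closed_socks] = sk;
	return 0;
}

/* Second detach: drop a closed socket from the back region. */
static int model_release(struct model_group *g, const void *sk)
{
	int i = model_index(g, sk, 1);

	if (i < 0)
		return -1;

	/* Fill the hole with the lowest closed entry, then shrink. */
	g->socks[i] = g->socks[MAX_SOCKS - g->num_closed_socks];
	g->num_closed_socks--;
	return 0;
}
```

In this model, calling model_close() and then model_release() on the same
pointer mirrors the inet_unhash() and sk_destruct() calls; the group would
be freed once both counters reach zero.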

With this change, closed sockets keep sk_reuseport_cb until all child
requests have been freed or accepted. Consequently, calling listen() after
shutdown() can fail with EADDRINUSE or EBUSY in reuseport_add_sock() or
inet_csk_bind_conflict(), which expect such sockets not to belong to a
reuseport group. Therefore, this patch also loosens these validation rules
so that the socket can listen again if it shares its reuseport group with
the other listening sockets.
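
The loosened rule can be isolated as a small predicate. The sketch below
covers only the sk_reuseport_cb term of the condition in
inet_csk_bind_conflict(); the remaining terms (sk_reuseport flags, uid,
TCP_TIME_WAIT) are omitted. struct model_group and both function names are
hypothetical and exist only for illustration.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical stand-in for struct sock_reuseport. */
struct model_group {
	int id;
};

/* Before the patch: sk passes this term only if it has no group at all. */
static bool old_rule_allows(const struct model_group *sk_cb)
{
	return sk_cb == NULL;
}

/*
 * After the patch: sk also passes if it still points at the same group
 * as the conflicting socket sk2, which is the case for a closed listener
 * that calls listen() again.
 */
static bool new_rule_allows(const struct model_group *sk_cb,
			    const struct model_group *sk2_cb)
{
	return sk_cb == NULL || sk_cb == sk2_cb;
}
```

A socket that kept its group after shutdown() fails old_rule_allows() but
passes new_rule_allows() when sk2 belongs to that same group.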

Reviewed-by: Benjamin Herrenschmidt <benh@amazon.com>
Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.co.jp>
---
 include/net/sock_reuseport.h    |  5 ++-
 net/core/sock_reuseport.c       | 79 +++++++++++++++++++++++++++------
 net/ipv4/inet_connection_sock.c |  7 ++-
 3 files changed, 74 insertions(+), 17 deletions(-)

Comments

Martin KaFai Lau Dec. 5, 2020, 1:31 a.m. UTC | #1
On Tue, Dec 01, 2020 at 11:44:08PM +0900, Kuniyuki Iwashima wrote:
> diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
> index 505f1e18e9bf..0e558ca7afbf 100644
> --- a/include/net/sock_reuseport.h
> +++ b/include/net/sock_reuseport.h
> @@ -13,8 +13,9 @@ extern spinlock_t reuseport_lock;
>  struct sock_reuseport {
>  	struct rcu_head		rcu;
>  
> -	u16			max_socks;	/* length of socks */
> -	u16			num_socks;	/* elements in socks */
> +	u16			max_socks;		/* length of socks */
> +	u16			num_socks;		/* elements in socks */
> +	u16			num_closed_socks;	/* closed elements in socks */
>  	/* The last synq overflow event timestamp of this
>  	 * reuse->socks[] group.
>  	 */
> diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
> index bbdd3c7b6cb5..fd133516ac0e 100644
> --- a/net/core/sock_reuseport.c
> +++ b/net/core/sock_reuseport.c
> @@ -98,16 +98,21 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
>  		return NULL;
>  
>  	more_reuse->num_socks = reuse->num_socks;
> +	more_reuse->num_closed_socks = reuse->num_closed_socks;
>  	more_reuse->prog = reuse->prog;
>  	more_reuse->reuseport_id = reuse->reuseport_id;
>  	more_reuse->bind_inany = reuse->bind_inany;
>  	more_reuse->has_conns = reuse->has_conns;
> +	more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
>  
>  	memcpy(more_reuse->socks, reuse->socks,
>  	       reuse->num_socks * sizeof(struct sock *));
> -	more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
> +	memcpy(more_reuse->socks +
> +	       (more_reuse->max_socks - more_reuse->num_closed_socks),
> +	       reuse->socks + reuse->num_socks,
> +	       reuse->num_closed_socks * sizeof(struct sock *));
>  
> -	for (i = 0; i < reuse->num_socks; ++i)
> +	for (i = 0; i < reuse->max_socks; ++i)
>  		rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
>  				   more_reuse);
>  
> @@ -129,6 +134,25 @@ static void reuseport_free_rcu(struct rcu_head *head)
>  	kfree(reuse);
>  }
>  
> +static int reuseport_sock_index(struct sock_reuseport *reuse, struct sock *sk,
> +				bool closed)
> +{
> +	int left, right;
> +
> +	if (!closed) {
> +		left = 0;
> +		right = reuse->num_socks;
> +	} else {
> +		left = reuse->max_socks - reuse->num_closed_socks;
> +		right = reuse->max_socks;
> +	}
> +
> +	for (; left < right; left++)
> +		if (reuse->socks[left] == sk)
> +			return left;
> +	return -1;
> +}
> +
>  /**
>   *  reuseport_add_sock - Add a socket to the reuseport group of another.
>   *  @sk:  New socket to add to the group.
> @@ -153,12 +177,23 @@ int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
>  					  lockdep_is_held(&reuseport_lock));
>  	old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
>  					     lockdep_is_held(&reuseport_lock));
> -	if (old_reuse && old_reuse->num_socks != 1) {
> +
> +	if (old_reuse == reuse) {
> +		int i = reuseport_sock_index(reuse, sk, true);
> +
> +		if (i == -1) {
When will this happen?

I found the new logic in the closed sk shuffling within socks[] quite
complicated to read.  I can see why the closed sk wants to keep its
sk->sk_reuseport_cb.  However, does it need to stay
in socks[]?


> +			spin_unlock_bh(&reuseport_lock);
> +			return -EBUSY;
> +		}
> +
> +		reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks];
> +		reuse->num_closed_socks--;
> +	} else if (old_reuse && old_reuse->num_socks != 1) {
>  		spin_unlock_bh(&reuseport_lock);
>  		return -EBUSY;
>  	}
>  
> -	if (reuse->num_socks == reuse->max_socks) {
> +	if (reuse->num_socks + reuse->num_closed_socks == reuse->max_socks) {
>  		reuse = reuseport_grow(reuse);
>  		if (!reuse) {
>  			spin_unlock_bh(&reuseport_lock);
Kuniyuki Iwashima Dec. 6, 2020, 4:38 a.m. UTC | #2
I'm resending this mail for the record because my reply yesterday failed to
reach LKML, netdev, and bpf.


From:   Martin KaFai Lau <kafai@fb.com>
Date:   Fri, 4 Dec 2020 17:31:03 -0800
> On Tue, Dec 01, 2020 at 11:44:08PM +0900, Kuniyuki Iwashima wrote:
> > diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
> > index 505f1e18e9bf..0e558ca7afbf 100644
> > --- a/include/net/sock_reuseport.h
> > +++ b/include/net/sock_reuseport.h
> > @@ -13,8 +13,9 @@ extern spinlock_t reuseport_lock;
> >  struct sock_reuseport {
> >  	struct rcu_head		rcu;
> >  
> > -	u16			max_socks;	/* length of socks */
> > -	u16			num_socks;	/* elements in socks */
> > +	u16			max_socks;		/* length of socks */
> > +	u16			num_socks;		/* elements in socks */
> > +	u16			num_closed_socks;	/* closed elements in socks */
> >  	/* The last synq overflow event timestamp of this
> >  	 * reuse->socks[] group.
> >  	 */
> > diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
> > index bbdd3c7b6cb5..fd133516ac0e 100644
> > --- a/net/core/sock_reuseport.c
> > +++ b/net/core/sock_reuseport.c
> > @@ -98,16 +98,21 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
> >  		return NULL;
> >  
> >  	more_reuse->num_socks = reuse->num_socks;
> > +	more_reuse->num_closed_socks = reuse->num_closed_socks;
> >  	more_reuse->prog = reuse->prog;
> >  	more_reuse->reuseport_id = reuse->reuseport_id;
> >  	more_reuse->bind_inany = reuse->bind_inany;
> >  	more_reuse->has_conns = reuse->has_conns;
> > +	more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
> >  
> >  	memcpy(more_reuse->socks, reuse->socks,
> >  	       reuse->num_socks * sizeof(struct sock *));
> > -	more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
> > +	memcpy(more_reuse->socks +
> > +	       (more_reuse->max_socks - more_reuse->num_closed_socks),
> > +	       reuse->socks + reuse->num_socks,
> > +	       reuse->num_closed_socks * sizeof(struct sock *));
> >  
> > -	for (i = 0; i < reuse->num_socks; ++i)
> > +	for (i = 0; i < reuse->max_socks; ++i)
> >  		rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
> >  				   more_reuse);
> >  
> > @@ -129,6 +134,25 @@ static void reuseport_free_rcu(struct rcu_head *head)
> >  	kfree(reuse);
> >  }
> >  
> > +static int reuseport_sock_index(struct sock_reuseport *reuse, struct sock *sk,
> > +				bool closed)
> > +{
> > +	int left, right;
> > +
> > +	if (!closed) {
> > +		left = 0;
> > +		right = reuse->num_socks;
> > +	} else {
> > +		left = reuse->max_socks - reuse->num_closed_socks;
> > +		right = reuse->max_socks;
> > +	}
> > +
> > +	for (; left < right; left++)
> > +		if (reuse->socks[left] == sk)
> > +			return left;
> > +	return -1;
> > +}
> > +
> >  /**
> >   *  reuseport_add_sock - Add a socket to the reuseport group of another.
> >   *  @sk:  New socket to add to the group.
> > @@ -153,12 +177,23 @@ int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
> >  					  lockdep_is_held(&reuseport_lock));
> >  	old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
> >  					     lockdep_is_held(&reuseport_lock));
> > -	if (old_reuse && old_reuse->num_socks != 1) {
> > +
> > +	if (old_reuse == reuse) {
> > +		int i = reuseport_sock_index(reuse, sk, true);
> > +
> > +		if (i == -1) {
> When will this happen?

My understanding was that the original code did nothing if the sk was not
found in socks[], so I kept that behavior, but I also think `i` will never
be -1.

If I rewrite it, the loop will look like this:

---8<---
for (; left < right; left++)
    if (reuse->socks[left] == sk)
        break;
return left;
---8<---
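
For reference, a standalone version of that loop might look like the sketch
below. find_in_range() and its parameters are hypothetical names; the point
it illustrates is an assumption about the rewrite, namely that a miss is
reported by returning the right boundary instead of -1, so a caller would
compare the result against `right`.

```c
#include <assert.h>
#include <stddef.h>

/*
 * Hypothetical helper: search socks[left..right) for sk.
 * Returns the matching index, or right when sk is absent.
 */
static int find_in_range(const void *const *socks, int left, int right,
			 const void *sk)
{
	for (; left < right; left++)
		if (socks[left] == sk)
			break;
	return left;
}
```

With this shape an empty range also returns `right`, so the "not found"
check stays the same whether or not the region holds any sockets.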


> I found the new logic in the closed sk shuffling within socks[] quite
> complicated to read.  I can see why the closed sk wants to keep its
> sk->sk_reuseport_cb.  However, does it need to stay
> in socks[]?

Currently, I do not use the closed sockets in socks[], so it seems the only
thing I need is the num_closed_socks count, to know when struct
sock_reuseport can be freed. I will change the code so that it only keeps
sk_reuseport_cb and counts num_closed_socks.

(As a side note, I wrote the code with the image of a stack and a heap
sharing one array, but I also find it a bit difficult to read.)
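
One possible reading of that plan, as a userspace sketch: closed sockets
leave socks[] at the first detach but keep their group pointer, and the
group is released once both counters reach zero. This is an assumption
about the direction described above, not the code of a later revision;
struct model_group, struct model_sock, and the model_* functions are
hypothetical, and the `freed` flag stands in for
call_rcu(&reuse->rcu, reuseport_free_rcu).

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical model of the group: counters only, no socks[] array. */
struct model_group {
	int num_socks;		/* listening sockets */
	int num_closed_socks;	/* closed sockets that still hold the group */
	int freed;		/* set where the kernel would free the group */
};

/* Hypothetical model of a socket: cb plays the role of sk_reuseport_cb. */
struct model_sock {
	struct model_group *cb;
};

/* First detach: stop listening, but keep the group pointer. */
static void model_unhash(struct model_sock *sk)
{
	struct model_group *g = sk->cb;

	g->num_socks--;
	g->num_closed_socks++;
}

/* Second detach: drop the group pointer and free the group if unused. */
static void model_destruct(struct model_sock *sk)
{
	struct model_group *g = sk->cb;

	g->num_closed_socks--;
	sk->cb = NULL;

	if (g->num_socks + g->num_closed_socks == 0)
		g->freed = 1;
}
```

Between the two calls the closed socket can still reach the group through
cb, which is what a TCP_NEW_SYN_RECV child would need in order to pick
another listener.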
Patch

diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
index 505f1e18e9bf..0e558ca7afbf 100644
--- a/include/net/sock_reuseport.h
+++ b/include/net/sock_reuseport.h
@@ -13,8 +13,9 @@  extern spinlock_t reuseport_lock;
 struct sock_reuseport {
 	struct rcu_head		rcu;
 
-	u16			max_socks;	/* length of socks */
-	u16			num_socks;	/* elements in socks */
+	u16			max_socks;		/* length of socks */
+	u16			num_socks;		/* elements in socks */
+	u16			num_closed_socks;	/* closed elements in socks */
 	/* The last synq overflow event timestamp of this
 	 * reuse->socks[] group.
 	 */
diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
index bbdd3c7b6cb5..fd133516ac0e 100644
--- a/net/core/sock_reuseport.c
+++ b/net/core/sock_reuseport.c
@@ -98,16 +98,21 @@  static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
 		return NULL;
 
 	more_reuse->num_socks = reuse->num_socks;
+	more_reuse->num_closed_socks = reuse->num_closed_socks;
 	more_reuse->prog = reuse->prog;
 	more_reuse->reuseport_id = reuse->reuseport_id;
 	more_reuse->bind_inany = reuse->bind_inany;
 	more_reuse->has_conns = reuse->has_conns;
+	more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
 
 	memcpy(more_reuse->socks, reuse->socks,
 	       reuse->num_socks * sizeof(struct sock *));
-	more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
+	memcpy(more_reuse->socks +
+	       (more_reuse->max_socks - more_reuse->num_closed_socks),
+	       reuse->socks + reuse->num_socks,
+	       reuse->num_closed_socks * sizeof(struct sock *));
 
-	for (i = 0; i < reuse->num_socks; ++i)
+	for (i = 0; i < reuse->max_socks; ++i)
 		rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
 				   more_reuse);
 
@@ -129,6 +134,25 @@  static void reuseport_free_rcu(struct rcu_head *head)
 	kfree(reuse);
 }
 
+static int reuseport_sock_index(struct sock_reuseport *reuse, struct sock *sk,
+				bool closed)
+{
+	int left, right;
+
+	if (!closed) {
+		left = 0;
+		right = reuse->num_socks;
+	} else {
+		left = reuse->max_socks - reuse->num_closed_socks;
+		right = reuse->max_socks;
+	}
+
+	for (; left < right; left++)
+		if (reuse->socks[left] == sk)
+			return left;
+	return -1;
+}
+
 /**
  *  reuseport_add_sock - Add a socket to the reuseport group of another.
  *  @sk:  New socket to add to the group.
@@ -153,12 +177,23 @@  int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
 					  lockdep_is_held(&reuseport_lock));
 	old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
 					     lockdep_is_held(&reuseport_lock));
-	if (old_reuse && old_reuse->num_socks != 1) {
+
+	if (old_reuse == reuse) {
+		int i = reuseport_sock_index(reuse, sk, true);
+
+		if (i == -1) {
+			spin_unlock_bh(&reuseport_lock);
+			return -EBUSY;
+		}
+
+		reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks];
+		reuse->num_closed_socks--;
+	} else if (old_reuse && old_reuse->num_socks != 1) {
 		spin_unlock_bh(&reuseport_lock);
 		return -EBUSY;
 	}
 
-	if (reuse->num_socks == reuse->max_socks) {
+	if (reuse->num_socks + reuse->num_closed_socks == reuse->max_socks) {
 		reuse = reuseport_grow(reuse);
 		if (!reuse) {
 			spin_unlock_bh(&reuseport_lock);
@@ -174,8 +209,9 @@  int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
 
 	spin_unlock_bh(&reuseport_lock);
 
-	if (old_reuse)
+	if (old_reuse && old_reuse != reuse)
 		call_rcu(&old_reuse->rcu, reuseport_free_rcu);
+
 	return 0;
 }
 EXPORT_SYMBOL(reuseport_add_sock);
@@ -199,17 +235,34 @@  void reuseport_detach_sock(struct sock *sk)
 	 */
 	bpf_sk_reuseport_detach(sk);
 
-	rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
+	if (sk->sk_protocol != IPPROTO_TCP || sk->sk_state == TCP_LISTEN) {
+		i = reuseport_sock_index(reuse, sk, false);
+		if (i == -1)
+			goto out;
+
+		reuse->num_socks--;
+		reuse->socks[i] = reuse->socks[reuse->num_socks];
 
-	for (i = 0; i < reuse->num_socks; i++) {
-		if (reuse->socks[i] == sk) {
-			reuse->socks[i] = reuse->socks[reuse->num_socks - 1];
-			reuse->num_socks--;
-			if (reuse->num_socks == 0)
-				call_rcu(&reuse->rcu, reuseport_free_rcu);
-			break;
+		if (sk->sk_protocol == IPPROTO_TCP) {
+			reuse->num_closed_socks++;
+			reuse->socks[reuse->max_socks - reuse->num_closed_socks] = sk;
+		} else {
+			rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
 		}
+	} else {
+		i = reuseport_sock_index(reuse, sk, true);
+		if (i == -1)
+			goto out;
+
+		reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks];
+		reuse->num_closed_socks--;
+
+		rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
 	}
+
+	if (reuse->num_socks + reuse->num_closed_socks == 0)
+		call_rcu(&reuse->rcu, reuseport_free_rcu);
+out:
 	spin_unlock_bh(&reuseport_lock);
 }
 EXPORT_SYMBOL(reuseport_detach_sock);
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index f60869acbef0..1451aa9712b0 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -138,6 +138,7 @@  static int inet_csk_bind_conflict(const struct sock *sk,
 	bool reuse = sk->sk_reuse;
 	bool reuseport = !!sk->sk_reuseport;
 	kuid_t uid = sock_i_uid((struct sock *)sk);
+	struct sock_reuseport *reuseport_cb = rcu_access_pointer(sk->sk_reuseport_cb);
 
 	/*
 	 * Unlike other sk lookup places we do not check
@@ -156,14 +157,16 @@  static int inet_csk_bind_conflict(const struct sock *sk,
 				if ((!relax ||
 				     (!reuseport_ok &&
 				      reuseport && sk2->sk_reuseport &&
-				      !rcu_access_pointer(sk->sk_reuseport_cb) &&
+				      (!reuseport_cb ||
+				       reuseport_cb == rcu_access_pointer(sk2->sk_reuseport_cb)) &&
 				      (sk2->sk_state == TCP_TIME_WAIT ||
 				       uid_eq(uid, sock_i_uid(sk2))))) &&
 				    inet_rcv_saddr_equal(sk, sk2, true))
 					break;
 			} else if (!reuseport_ok ||
 				   !reuseport || !sk2->sk_reuseport ||
-				   rcu_access_pointer(sk->sk_reuseport_cb) ||
+				   (reuseport_cb &&
+				    reuseport_cb != rcu_access_pointer(sk2->sk_reuseport_cb)) ||
 				   (sk2->sk_state != TCP_TIME_WAIT &&
 				    !uid_eq(uid, sock_i_uid(sk2)))) {
 				if (inet_rcv_saddr_equal(sk, sk2, true))