diff mbox

netlink: Replace rhash_portid with bound

Message ID 20150921133415.GA1740@gondor.apana.org.au
State Superseded, archived
Delegated to: David Miller
Headers show

Commit Message

Herbert Xu Sept. 21, 2015, 1:34 p.m. UTC
On Sun, Sep 20, 2015 at 11:11:04PM -0700, David Miller wrote:
>
> Yeah at this point incremental patches work the best.

OK here is the patch:

---8<---
The commit 1f770c0a09da855a2b51af6d19de97fb955eca85 ("netlink:
Fix autobind race condition that leads to zero port ID") created
some new races that can occur due to inconcsistencies between the
two port IDs.

Tejun is right that a barrier is unavoidable.  Therefore I am
reverting to the original patch that used a boolean to indicate
that a user netlink socket has been bound.

Barriers have been added where necessary to ensure that a valid
portid is used.

Fixes: 1f770c0a09da ("netlink: Fix autobind race condition that leads to zero port ID")
Reported-by: Tejun Heo <tj@kernel.org>
Reported-by: Linus Torvalds <torvalds@linux-foundation.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>

Comments

Tejun Heo Sept. 21, 2015, 6:20 p.m. UTC | #1
Hello, Herbert.

On Mon, Sep 21, 2015 at 09:34:16PM +0800, Herbert Xu wrote:
> @@ -1119,7 +1120,11 @@ static int netlink_insert(struct sock *sk, u32 portid)
>  		goto err;
>  	}
>  
> -	nlk_sk(sk)->portid = portid;
> +	/* rhashtable_insert carries an implicit write memory barrier
> +	 * so we don't need an smp_wmb here in order to ensure that
> +	 * portid is set before bound.
> +	 */
> +	nlk_sk(sk)->bound = portid;

store_release and load_acquire are different from the usual memory
barriers and can't be paired this way.  You have to pair store_release
and load_acquire.  Besides, it isn't a particularly good idea to
depend on memory barriers embedded in other data structures like the
above.  Here, especially, rhashtable_insert() would have write barrier
*before* the entry is hashed not necessarily *after*, which means that
in the above case, a socket which appears to have set bound to a
reader might not visible when the reader tries to look up the socket
on the hashtable.

There's no reason to be overly smart here.  This isn't a crazy hot
path, write barriers tend to be very cheap, store_release more so.
Please just do smp_store_release() and note what it's paired with.

> @@ -1539,7 +1546,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
>  		}
>  	}
>  
> -	if (!nlk->portid) {
> +	if (!nlk->bound) {

I don't think you can skip load_acquire here just because this is the
second deref of the variable.  That doesn't change anything.  Race
condition could still happen between the first and second tests and
skipping the second would lead to the same kind of bug.

> @@ -1587,7 +1594,7 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr,
>  	    !netlink_allowed(sock, NL_CFG_F_NONROOT_SEND))
>  		return -EPERM;
>  
> -	if (!nlk->portid)
> +	if (!nlk->bound)

Don't we need load_acquire here too?  Is this path holding a lock
which makes that unnecessary?

I'd suggest making it clear that ->bound is internal (name it
->__bound or sth) and provide a test macro which always uses
load_acquire.  It could be that there are a couple places which can
avoid load_acquire but it just isn't worth it.  load_acquire is very
cheap but bugs around it can be extremely subtle.  Let's please keep
it straight-forward.

Thanks.
diff mbox

Patch

diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 303efb7..f5362aae 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -24,6 +24,7 @@ 
 
 #include <linux/module.h>
 
+#include <asm/barrier.h>
 #include <linux/capability.h>
 #include <linux/kernel.h>
 #include <linux/init.h>
@@ -1015,7 +1016,7 @@  static inline int netlink_compare(struct rhashtable_compare_arg *arg,
 	const struct netlink_compare_arg *x = arg->key;
 	const struct netlink_sock *nlk = ptr;
 
-	return nlk->rhash_portid != x->portid ||
+	return nlk->portid != x->portid ||
 	       !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet));
 }
 
@@ -1041,7 +1042,7 @@  static int __netlink_insert(struct netlink_table *table, struct sock *sk)
 {
 	struct netlink_compare_arg arg;
 
-	netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->rhash_portid);
+	netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid);
 	return rhashtable_lookup_insert_key(&table->hash, &arg,
 					    &nlk_sk(sk)->node,
 					    netlink_rhashtable_params);
@@ -1095,7 +1096,7 @@  static int netlink_insert(struct sock *sk, u32 portid)
 	lock_sock(sk);
 
 	err = -EBUSY;
-	if (nlk_sk(sk)->portid)
+	if (nlk_sk(sk)->bound)
 		goto err;
 
 	err = -ENOMEM;
@@ -1103,7 +1104,7 @@  static int netlink_insert(struct sock *sk, u32 portid)
 	    unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX))
 		goto err;
 
-	nlk_sk(sk)->rhash_portid = portid;
+	nlk_sk(sk)->portid = portid;
 	sock_hold(sk);
 
 	err = __netlink_insert(table, sk);
@@ -1119,7 +1120,11 @@  static int netlink_insert(struct sock *sk, u32 portid)
 		goto err;
 	}
 
-	nlk_sk(sk)->portid = portid;
+	/* rhashtable_insert carries an implicit write memory barrier
+	 * so we don't need an smp_wmb here in order to ensure that
+	 * portid is set before bound.
+	 */
+	nlk_sk(sk)->bound = portid;
 
 err:
 	release_sock(sk);
@@ -1521,9 +1526,11 @@  static int netlink_bind(struct socket *sock, struct sockaddr *addr,
 			return err;
 	}
 
-	if (nlk->portid)
+	/* Ensure nlk->portid is up-to-date. */
+	if (smp_load_acquire(&nlk->bound)) {
 		if (nladdr->nl_pid != nlk->portid)
 			return -EINVAL;
+	}
 
 	if (nlk->netlink_bind && groups) {
 		int group;
@@ -1539,7 +1546,7 @@  static int netlink_bind(struct socket *sock, struct sockaddr *addr,
 		}
 	}
 
-	if (!nlk->portid) {
+	if (!nlk->bound) {
 		err = nladdr->nl_pid ?
 			netlink_insert(sk, nladdr->nl_pid) :
 			netlink_autobind(sock);
@@ -1587,7 +1594,7 @@  static int netlink_connect(struct socket *sock, struct sockaddr *addr,
 	    !netlink_allowed(sock, NL_CFG_F_NONROOT_SEND))
 		return -EPERM;
 
-	if (!nlk->portid)
+	if (!nlk->bound)
 		err = netlink_autobind(sock);
 
 	if (err == 0) {
@@ -2428,7 +2435,8 @@  static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 		dst_group = nlk->dst_group;
 	}
 
-	if (!nlk->portid) {
+	/* Ensure nlk->portid is up-to-date. */
+	if (!smp_load_acquire(&nlk->bound)) {
 		err = netlink_autobind(sock);
 		if (err)
 			goto out;
@@ -3257,7 +3265,7 @@  static inline u32 netlink_hash(const void *data, u32 len, u32 seed)
 	const struct netlink_sock *nlk = data;
 	struct netlink_compare_arg arg;
 
-	netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->rhash_portid);
+	netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid);
 	return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed);
 }
 
diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
index c96dfa3..e6aae40 100644
--- a/net/netlink/af_netlink.h
+++ b/net/netlink/af_netlink.h
@@ -25,7 +25,6 @@  struct netlink_ring {
 struct netlink_sock {
 	/* struct sock has to be the first member of netlink_sock */
 	struct sock		sk;
-	u32			rhash_portid;
 	u32			portid;
 	u32			dst_portid;
 	u32			dst_group;
@@ -36,6 +35,7 @@  struct netlink_sock {
 	unsigned long		state;
 	size_t			max_recvmsg_len;
 	wait_queue_head_t	wait;
+	bool			bound;
 	bool			cb_running;
 	struct netlink_callback	cb;
 	struct mutex		*cb_mutex;