diff mbox series

[RFC,net-next,02/11] ip6mr: Make mroute_sk rcu-based

Message ID 1519165995-50854-3-git-send-email-yuvalm@mellanox.com
State RFC, archived
Delegated to: David Miller
Headers show
Series ipmr, ip6mr: Align multicast routing for IPv4 & IPv6 | expand

Commit Message

Yuval Mintz Feb. 20, 2018, 10:33 p.m. UTC
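In ipmr, the mr_table's mroute socket is handled under RCU. Introduce the
same handling for ip6mr's mroute6_sk. As part of this, mroute6_socket() is
replaced by mroute6_is_socket(), which returns a bool rather than the
socket pointer.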

Signed-off-by: Yuval Mintz <yuvalm@mellanox.com>
---
 include/linux/mroute6.h |  6 +++---
 net/ipv6/ip6_output.c   |  2 +-
 net/ipv6/ip6mr.c        | 43 ++++++++++++++++++++++++++-----------------
 3 files changed, 30 insertions(+), 21 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/mroute6.h b/include/linux/mroute6.h
index e5e5b82..e1b9fb0 100644
--- a/include/linux/mroute6.h
+++ b/include/linux/mroute6.h
@@ -111,12 +111,12 @@  extern int ip6mr_get_route(struct net *net, struct sk_buff *skb,
 			   struct rtmsg *rtm, u32 portid);
 
 #ifdef CONFIG_IPV6_MROUTE
-extern struct sock *mroute6_socket(struct net *net, struct sk_buff *skb);
+bool mroute6_is_socket(struct net *net, struct sk_buff *skb);
 extern int ip6mr_sk_done(struct sock *sk);
 #else
-static inline struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
+static inline bool mroute6_is_socket(struct net *net, struct sk_buff *skb)
 {
-	return NULL;
+	return false;
 }
 static inline int ip6mr_sk_done(struct sock *sk)
 {
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index 997c7f1..a6eb0e6 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -71,7 +71,7 @@  static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *
 		struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
 
 		if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) &&
-		    ((mroute6_socket(net, skb) &&
+		    ((mroute6_is_socket(net, skb) &&
 		     !(IP6CB(skb)->flags & IP6SKB_FORWARDED)) ||
 		     ipv6_chk_mcast_addr(dev, &ipv6_hdr(skb)->daddr,
 					 &ipv6_hdr(skb)->saddr))) {
diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c
index e397990..7792fc5 100644
--- a/net/ipv6/ip6mr.c
+++ b/net/ipv6/ip6mr.c
@@ -58,7 +58,7 @@  struct mr6_table {
 	struct list_head	list;
 	possible_net_t		net;
 	u32			id;
-	struct sock		*mroute6_sk;
+	struct sock __rcu	*mroute6_sk;
 	struct timer_list	ipmr_expire_timer;
 	struct list_head	mfc6_unres_queue;
 	struct list_head	mfc6_cache_array[MFC6_LINES];
@@ -1121,6 +1121,7 @@  static void ip6mr_cache_resolve(struct net *net, struct mr6_table *mrt,
 static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt,
 			      mifi_t mifi, int assert)
 {
+	struct sock *mroute6_sk;
 	struct sk_buff *skb;
 	struct mrt6msg *msg;
 	int ret;
@@ -1190,17 +1191,19 @@  static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt,
 	skb->ip_summed = CHECKSUM_UNNECESSARY;
 	}
 
-	if (!mrt->mroute6_sk) {
+	rcu_read_lock();
+	mroute6_sk = rcu_dereference(mrt->mroute6_sk);
+	if (!mroute6_sk) {
+		rcu_read_unlock();
 		kfree_skb(skb);
 		return -EINVAL;
 	}
 
 	mrt6msg_netlink_event(mrt, skb);
 
-	/*
-	 *	Deliver to user space multicast routing algorithms
-	 */
+	/* Deliver to user space multicast routing algorithms */
 	ret = sock_queue_rcv_skb(mrt->mroute6_sk, skb);
+	rcu_read_unlock();
 	if (ret < 0) {
 		net_warn_ratelimited("mroute6: pending queue full, dropping entries\n");
 		kfree_skb(skb);
@@ -1584,11 +1587,11 @@  static int ip6mr_sk_init(struct mr6_table *mrt, struct sock *sk)
 
 	rtnl_lock();
 	write_lock_bh(&mrt_lock);
-	if (likely(mrt->mroute6_sk == NULL)) {
-		mrt->mroute6_sk = sk;
-		net->ipv6.devconf_all->mc_forwarding++;
-	} else {
+	if (rtnl_dereference(mrt->mroute6_sk)) {
 		err = -EADDRINUSE;
+	} else {
+		rcu_assign_pointer(mrt->mroute6_sk, sk);
+		net->ipv6.devconf_all->mc_forwarding++;
 	}
 	write_unlock_bh(&mrt_lock);
 
@@ -1614,9 +1617,9 @@  int ip6mr_sk_done(struct sock *sk)
 
 	rtnl_lock();
 	ip6mr_for_each_table(mrt, net) {
-		if (sk == mrt->mroute6_sk) {
+		if (sk == rtnl_dereference(mrt->mroute6_sk)) {
 			write_lock_bh(&mrt_lock);
-			mrt->mroute6_sk = NULL;
+			RCU_INIT_POINTER(mrt->mroute6_sk, NULL);
 			net->ipv6.devconf_all->mc_forwarding--;
 			write_unlock_bh(&mrt_lock);
 			inet6_netconf_notify_devconf(net, RTM_NEWNETCONF,
@@ -1630,11 +1633,12 @@  int ip6mr_sk_done(struct sock *sk)
 		}
 	}
 	rtnl_unlock();
+	synchronize_rcu();
 
 	return err;
 }
 
-struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
+bool mroute6_is_socket(struct net *net, struct sk_buff *skb)
 {
 	struct mr6_table *mrt;
 	struct flowi6 fl6 = {
@@ -1646,8 +1650,9 @@  struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
 	if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
 		return NULL;
 
-	return mrt->mroute6_sk;
+	return rcu_access_pointer(mrt->mroute6_sk);
 }
+EXPORT_SYMBOL(mroute6_is_socket);
 
 /*
  *	Socket options and virtual interface manipulation. The whole
@@ -1674,7 +1679,8 @@  int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
 		return -ENOENT;
 
 	if (optname != MRT6_INIT) {
-		if (sk != mrt->mroute6_sk && !ns_capable(net->user_ns, CAP_NET_ADMIN))
+		if (sk != rcu_access_pointer(mrt->mroute6_sk) &&
+		    !ns_capable(net->user_ns, CAP_NET_ADMIN))
 			return -EACCES;
 	}
 
@@ -1696,7 +1702,8 @@  int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
 		if (vif.mif6c_mifi >= MAXMIFS)
 			return -ENFILE;
 		rtnl_lock();
-		ret = mif6_add(net, mrt, &vif, sk == mrt->mroute6_sk);
+		ret = mif6_add(net, mrt, &vif,
+			       sk == rtnl_dereference(mrt->mroute6_sk));
 		rtnl_unlock();
 		return ret;
 
@@ -1731,7 +1738,9 @@  int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
 			ret = ip6mr_mfc_delete(mrt, &mfc, parent);
 		else
 			ret = ip6mr_mfc_add(net, mrt, &mfc,
-					    sk == mrt->mroute6_sk, parent);
+					    sk ==
+					    rtnl_dereference(mrt->mroute6_sk),
+					    parent);
 		rtnl_unlock();
 		return ret;
 
@@ -1783,7 +1792,7 @@  int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
 		/* "pim6reg%u" should not exceed 16 bytes (IFNAMSIZ) */
 		if (v != RT_TABLE_DEFAULT && v >= 100000000)
 			return -EINVAL;
-		if (sk == mrt->mroute6_sk)
+		if (sk == rcu_access_pointer(mrt->mroute6_sk))
 			return -EBUSY;
 
 		rtnl_lock();