diff mbox series

[net-next,4/9] netfilter: conntrack: move rcu read lock to nf_conntrack_find_get

Message ID 20230118123208.17167-5-fw@strlen.de
State Accepted
Delegated to: Florian Westphal
Headers show
Series [net-next,1/9] netfilter: conntrack: sctp: use nf log infrastructure for invalid packets | expand

Commit Message

Florian Westphal Jan. 18, 2023, 12:32 p.m. UTC
Move rcu_read_lock/unlock to nf_conntrack_find_get(), this avoids
nested rcu_read_lock call from resolve_normal_ct().

Signed-off-by: Florian Westphal <fw@strlen.de>
---
 net/netfilter/nf_conntrack_core.c | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)
diff mbox series

Patch

diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c
index 9e12cade4e0f..c00858344f02 100644
--- a/net/netfilter/nf_conntrack_core.c
+++ b/net/netfilter/nf_conntrack_core.c
@@ -783,8 +783,6 @@  __nf_conntrack_find_get(struct net *net, const struct nf_conntrack_zone *zone,
 	struct nf_conntrack_tuple_hash *h;
 	struct nf_conn *ct;
 
-	rcu_read_lock();
-
 	h = ____nf_conntrack_find(net, zone, tuple, hash);
 	if (h) {
 		/* We have a candidate that matches the tuple we're interested
@@ -796,7 +794,7 @@  __nf_conntrack_find_get(struct net *net, const struct nf_conntrack_zone *zone,
 			smp_acquire__after_ctrl_dep();
 
 			if (likely(nf_ct_key_equal(h, tuple, zone, net)))
-				goto found;
+				return h;
 
 			/* TYPESAFE_BY_RCU recycled the candidate */
 			nf_ct_put(ct);
@@ -804,8 +802,6 @@  __nf_conntrack_find_get(struct net *net, const struct nf_conntrack_zone *zone,
 
 		h = NULL;
 	}
-found:
-	rcu_read_unlock();
 
 	return h;
 }
@@ -817,16 +813,21 @@  nf_conntrack_find_get(struct net *net, const struct nf_conntrack_zone *zone,
 	unsigned int rid, zone_id = nf_ct_zone_id(zone, IP_CT_DIR_ORIGINAL);
 	struct nf_conntrack_tuple_hash *thash;
 
+	rcu_read_lock();
+
 	thash = __nf_conntrack_find_get(net, zone, tuple,
 					hash_conntrack_raw(tuple, zone_id, net));
 
 	if (thash)
-		return thash;
+		goto out_unlock;
 
 	rid = nf_ct_zone_id(zone, IP_CT_DIR_REPLY);
 	if (rid != zone_id)
-		return __nf_conntrack_find_get(net, zone, tuple,
-					       hash_conntrack_raw(tuple, rid, net));
+		thash = __nf_conntrack_find_get(net, zone, tuple,
+						hash_conntrack_raw(tuple, rid, net));
+
+out_unlock:
+	rcu_read_unlock();
 	return thash;
 }
 EXPORT_SYMBOL_GPL(nf_conntrack_find_get);