diff mbox series

[bpf-next,v5,2/8] net: netfilter: Deduplicate code in bpf_{xdp,skb}_ct_lookup

Message ID 20220623192637.3866852-3-memxor@gmail.com
State Handled Elsewhere
Delegated to: Pablo Neira
Headers show
Series New nf_conntrack kfuncs for insertion, changing timeout, status | expand

Commit Message

Kumar Kartikeya Dwivedi June 23, 2022, 7:26 p.m. UTC
Move common checks inside the common function, and maintain the only
difference the two being how to obtain the struct net * from ctx.
No functional change intended.

Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
---
 net/netfilter/nf_conntrack_bpf.c | 52 +++++++++++---------------------
 1 file changed, 18 insertions(+), 34 deletions(-)
diff mbox series

Patch

diff --git a/net/netfilter/nf_conntrack_bpf.c b/net/netfilter/nf_conntrack_bpf.c
index bc4d5cd63a94..5cb1820054fb 100644
--- a/net/netfilter/nf_conntrack_bpf.c
+++ b/net/netfilter/nf_conntrack_bpf.c
@@ -57,16 +57,19 @@  enum {
 
 static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
 					  struct bpf_sock_tuple *bpf_tuple,
-					  u32 tuple_len, u8 protonum,
-					  s32 netns_id, u8 *dir)
+					  u32 tuple_len, struct bpf_ct_opts *opts,
+					  u32 opts_len)
 {
 	struct nf_conntrack_tuple_hash *hash;
 	struct nf_conntrack_tuple tuple;
 	struct nf_conn *ct;
 
-	if (unlikely(protonum != IPPROTO_TCP && protonum != IPPROTO_UDP))
+	if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
+	    opts_len != NF_BPF_CT_OPTS_SZ)
+		return ERR_PTR(-EINVAL);
+	if (unlikely(opts->l4proto != IPPROTO_TCP && opts->l4proto != IPPROTO_UDP))
 		return ERR_PTR(-EPROTO);
-	if (unlikely(netns_id < BPF_F_CURRENT_NETNS))
+	if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS))
 		return ERR_PTR(-EINVAL);
 
 	memset(&tuple, 0, sizeof(tuple));
@@ -89,23 +92,22 @@  static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
 		return ERR_PTR(-EAFNOSUPPORT);
 	}
 
-	tuple.dst.protonum = protonum;
+	tuple.dst.protonum = opts->l4proto;
 
-	if (netns_id >= 0) {
-		net = get_net_ns_by_id(net, netns_id);
+	if (opts->netns_id >= 0) {
+		net = get_net_ns_by_id(net, opts->netns_id);
 		if (unlikely(!net))
 			return ERR_PTR(-ENONET);
 	}
 
 	hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple);
-	if (netns_id >= 0)
+	if (opts->netns_id >= 0)
 		put_net(net);
 	if (!hash)
 		return ERR_PTR(-ENOENT);
 
 	ct = nf_ct_tuplehash_to_ctrack(hash);
-	if (dir)
-		*dir = NF_CT_DIRECTION(hash);
+	opts->dir = NF_CT_DIRECTION(hash);
 
 	return ct;
 }
@@ -138,20 +140,11 @@  bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple,
 	struct net *caller_net;
 	struct nf_conn *nfct;
 
-	BUILD_BUG_ON(sizeof(struct bpf_ct_opts) != NF_BPF_CT_OPTS_SZ);
-
-	if (!opts)
-		return NULL;
-	if (!bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
-	    opts__sz != NF_BPF_CT_OPTS_SZ) {
-		opts->error = -EINVAL;
-		return NULL;
-	}
 	caller_net = dev_net(ctx->rxq->dev);
-	nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts->l4proto,
-				  opts->netns_id, &opts->dir);
+	nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
 	if (IS_ERR(nfct)) {
-		opts->error = PTR_ERR(nfct);
+		if (opts)
+			opts->error = PTR_ERR(nfct);
 		return NULL;
 	}
 	return nfct;
@@ -181,20 +174,11 @@  bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple,
 	struct net *caller_net;
 	struct nf_conn *nfct;
 
-	BUILD_BUG_ON(sizeof(struct bpf_ct_opts) != NF_BPF_CT_OPTS_SZ);
-
-	if (!opts)
-		return NULL;
-	if (!bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
-	    opts__sz != NF_BPF_CT_OPTS_SZ) {
-		opts->error = -EINVAL;
-		return NULL;
-	}
 	caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk);
-	nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts->l4proto,
-				  opts->netns_id, &opts->dir);
+	nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
 	if (IS_ERR(nfct)) {
-		opts->error = PTR_ERR(nfct);
+		if (opts)
+			opts->error = PTR_ERR(nfct);
 		return NULL;
 	}
 	return nfct;