diff mbox series

[11/13] net: Track socket refcounts in skb_steal_sock()

Message ID 20200831040333.6058-12-khalid.elmously@canonical.com
State New
Headers show
Series Requested eBPF improvements | expand

Commit Message

Khalid Elmously Aug. 31, 2020, 4:03 a.m. UTC
From: Joe Stringer <joe@wand.net.nz>

[ upstream commit 71489e21d720a09388b565d60ef87ae993c10528 ]

Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
the determination of whether the socket is reference counted in the case
where it is prefetched by earlier logic such as early_demux.

Signed-off-by: Joe Stringer <joe@wand.net.nz>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Martin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: Khalid Elmously <khalid.elmously@canonical.com>
---
 include/net/inet6_hashtables.h |  3 +--
 include/net/inet_hashtables.h  |  3 +--
 include/net/sock.h             | 10 +++++++++-
 net/ipv4/udp.c                 |  6 ++++--
 net/ipv6/udp.c                 |  9 ++++++---
 5 files changed, 21 insertions(+), 10 deletions(-)

Comments

Kleber Souza Sept. 3, 2020, 10:57 a.m. UTC | #1
On 31.08.20 06:03, Khalid Elmously wrote:
> From: Joe Stringer <joe@wand.net.nz>

This patch is missing:

BugLink: https://bugs.launchpad.net/bugs/1887740

> 
> [ upstream commit 71489e21d720a09388b565d60ef87ae993c10528 ]
> 
> Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
> the determination of whether the socket is reference counted in the case
> where it is prefetched by earlier logic such as early_demux.
> 
> Signed-off-by: Joe Stringer <joe@wand.net.nz>
> Signed-off-by: Alexei Starovoitov <ast@kernel.org>
> Acked-by: Martin KaFai Lau <kafai@fb.com>
> Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
> Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
> Signed-off-by: Khalid Elmously <khalid.elmously@canonical.com>
> ---
>  include/net/inet6_hashtables.h |  3 +--
>  include/net/inet_hashtables.h  |  3 +--
>  include/net/sock.h             | 10 +++++++++-
>  net/ipv4/udp.c                 |  6 ++++--
>  net/ipv6/udp.c                 |  9 ++++++---
>  5 files changed, 21 insertions(+), 10 deletions(-)
> 
> diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
> index fe96bf247aac..81b965953036 100644
> --- a/include/net/inet6_hashtables.h
> +++ b/include/net/inet6_hashtables.h
> @@ -85,9 +85,8 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
>  					      int iif, int sdif,
>  					      bool *refcounted)
>  {
> -	struct sock *sk = skb_steal_sock(skb);
> +	struct sock *sk = skb_steal_sock(skb, refcounted);
>  
> -	*refcounted = true;
>  	if (sk)
>  		return sk;
>  
> diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> index d0019d3395cf..ad64ba6a057f 100644
> --- a/include/net/inet_hashtables.h
> +++ b/include/net/inet_hashtables.h
> @@ -379,10 +379,9 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
>  					     const int sdif,
>  					     bool *refcounted)
>  {
> -	struct sock *sk = skb_steal_sock(skb);
> +	struct sock *sk = skb_steal_sock(skb, refcounted);
>  	const struct iphdr *iph = ip_hdr(skb);
>  
> -	*refcounted = true;
>  	if (sk)
>  		return sk;
>  
> diff --git a/include/net/sock.h b/include/net/sock.h
> index b754050401d8..6cb1f0efa01b 100644
> --- a/include/net/sock.h
> +++ b/include/net/sock.h
> @@ -2492,15 +2492,23 @@ skb_sk_is_prefetched(struct sk_buff *skb)
>  #endif /* CONFIG_INET */
>  }
>  
> -static inline struct sock *skb_steal_sock(struct sk_buff *skb)
> +/**
> + * skb_steal_sock
> + * @skb to steal the socket from
> + * @refcounted is set to true if the socket is reference-counted
> + */
> +static inline struct sock *
> +skb_steal_sock(struct sk_buff *skb, bool *refcounted)
>  {
>  	if (skb->sk) {
>  		struct sock *sk = skb->sk;
>  
> +		*refcounted = true;
>  		skb->destructor = NULL;
>  		skb->sk = NULL;
>  		return sk;
>  	}
> +	*refcounted = false;
>  	return NULL;
>  }
>  
> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index f3b7cb725c1b..b7b01f721310 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -2286,6 +2286,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
>  	struct rtable *rt = skb_rtable(skb);
>  	__be32 saddr, daddr;
>  	struct net *net = dev_net(skb->dev);
> +	bool refcounted;
>  
>  	/*
>  	 *  Validate the packet.
> @@ -2311,7 +2312,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
>  	if (udp4_csum_init(skb, uh, proto))
>  		goto csum_error;
>  
> -	sk = skb_steal_sock(skb);
> +	sk = skb_steal_sock(skb, &refcounted);
>  	if (sk) {
>  		struct dst_entry *dst = skb_dst(skb);
>  		int ret;
> @@ -2320,7 +2321,8 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
>  			udp_sk_rx_dst_set(sk, dst);
>  
>  		ret = udp_unicast_rcv_skb(sk, skb, uh);
> -		sock_put(sk);
> +		if (refcounted)
> +			sock_put(sk);
>  		return ret;
>  	}
>  
> diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
> index 9fec580c968e..3d34e00124ff 100644
> --- a/net/ipv6/udp.c
> +++ b/net/ipv6/udp.c
> @@ -844,6 +844,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
>  	struct net *net = dev_net(skb->dev);
>  	struct udphdr *uh;
>  	struct sock *sk;
> +	bool refcounted;
>  	u32 ulen = 0;
>  
>  	if (!pskb_may_pull(skb, sizeof(struct udphdr)))
> @@ -880,7 +881,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
>  		goto csum_error;
>  
>  	/* Check if the socket is already available, e.g. due to early demux */
> -	sk = skb_steal_sock(skb);
> +	sk = skb_steal_sock(skb, &refcounted);
>  	if (sk) {
>  		struct dst_entry *dst = skb_dst(skb);
>  		int ret;
> @@ -889,12 +890,14 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
>  			udp6_sk_rx_dst_set(sk, dst);
>  
>  		if (!uh->check && !udp_sk(sk)->no_check6_rx) {
> -			sock_put(sk);
> +			if (refcounted)
> +				sock_put(sk);
>  			goto report_csum_error;
>  		}
>  
>  		ret = udp6_unicast_rcv_skb(sk, skb, uh);
> -		sock_put(sk);
> +		if (refcounted)
> +			sock_put(sk);
>  		return ret;
>  	}
>  
>
diff mbox series

Patch

diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
index fe96bf247aac..81b965953036 100644
--- a/include/net/inet6_hashtables.h
+++ b/include/net/inet6_hashtables.h
@@ -85,9 +85,8 @@  static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
 					      int iif, int sdif,
 					      bool *refcounted)
 {
-	struct sock *sk = skb_steal_sock(skb);
+	struct sock *sk = skb_steal_sock(skb, refcounted);
 
-	*refcounted = true;
 	if (sk)
 		return sk;
 
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index d0019d3395cf..ad64ba6a057f 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -379,10 +379,9 @@  static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
 					     const int sdif,
 					     bool *refcounted)
 {
-	struct sock *sk = skb_steal_sock(skb);
+	struct sock *sk = skb_steal_sock(skb, refcounted);
 	const struct iphdr *iph = ip_hdr(skb);
 
-	*refcounted = true;
 	if (sk)
 		return sk;
 
diff --git a/include/net/sock.h b/include/net/sock.h
index b754050401d8..6cb1f0efa01b 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -2492,15 +2492,23 @@  skb_sk_is_prefetched(struct sk_buff *skb)
 #endif /* CONFIG_INET */
 }
 
-static inline struct sock *skb_steal_sock(struct sk_buff *skb)
+/**
+ * skb_steal_sock - steal a socket from an sk_buff
+ * @skb: sk_buff to steal the socket from
+ * @refcounted: is set to true if the socket is reference-counted
+ */
+static inline struct sock *
+skb_steal_sock(struct sk_buff *skb, bool *refcounted)
 {
 	if (skb->sk) {
 		struct sock *sk = skb->sk;
 
+		*refcounted = true;
 		skb->destructor = NULL;
 		skb->sk = NULL;
 		return sk;
 	}
+	*refcounted = false;
 	return NULL;
 }
 
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index f3b7cb725c1b..b7b01f721310 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2286,6 +2286,7 @@  int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 	struct rtable *rt = skb_rtable(skb);
 	__be32 saddr, daddr;
 	struct net *net = dev_net(skb->dev);
+	bool refcounted;
 
 	/*
 	 *  Validate the packet.
@@ -2311,7 +2312,7 @@  int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 	if (udp4_csum_init(skb, uh, proto))
 		goto csum_error;
 
-	sk = skb_steal_sock(skb);
+	sk = skb_steal_sock(skb, &refcounted);
 	if (sk) {
 		struct dst_entry *dst = skb_dst(skb);
 		int ret;
@@ -2320,7 +2321,8 @@  int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 			udp_sk_rx_dst_set(sk, dst);
 
 		ret = udp_unicast_rcv_skb(sk, skb, uh);
-		sock_put(sk);
+		if (refcounted)
+			sock_put(sk);
 		return ret;
 	}
 
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 9fec580c968e..3d34e00124ff 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -844,6 +844,7 @@  int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 	struct net *net = dev_net(skb->dev);
 	struct udphdr *uh;
 	struct sock *sk;
+	bool refcounted;
 	u32 ulen = 0;
 
 	if (!pskb_may_pull(skb, sizeof(struct udphdr)))
@@ -880,7 +881,7 @@  int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 		goto csum_error;
 
 	/* Check if the socket is already available, e.g. due to early demux */
-	sk = skb_steal_sock(skb);
+	sk = skb_steal_sock(skb, &refcounted);
 	if (sk) {
 		struct dst_entry *dst = skb_dst(skb);
 		int ret;
@@ -889,12 +890,14 @@  int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 			udp6_sk_rx_dst_set(sk, dst);
 
 		if (!uh->check && !udp_sk(sk)->no_check6_rx) {
-			sock_put(sk);
+			if (refcounted)
+				sock_put(sk);
 			goto report_csum_error;
 		}
 
 		ret = udp6_unicast_rcv_skb(sk, skb, uh);
-		sock_put(sk);
+		if (refcounted)
+			sock_put(sk);
 		return ret;
 	}