diff mbox series

[v2,2/4] mptcp: check for plain TCP sock at accept time

Message ID 5d8722384e1d1370f75ccb52ed7000029f6fcf78.1592477699.git.pabeni@redhat.com
State Accepted, archived
Delegated to: Matthieu Baerts
Headers show
Series mptcp: fallback refactor follow-up | expand

Commit Message

Paolo Abeni June 18, 2020, 10:55 a.m. UTC
This cleanup the code a bit and avoid corrupted states
on weird syscall sequence (accept(), connect()).

Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
v1 -> v2:
 - mptcp_is_tcpsk() returns a bool value (Mat)
---
 net/mptcp/protocol.c | 69 +++++---------------------------------------
 1 file changed, 7 insertions(+), 62 deletions(-)

Comments

Mat Martineau June 18, 2020, 9:06 p.m. UTC | #1
On Thu, 18 Jun 2020, Paolo Abeni wrote:

> This cleanup the code a bit and avoid corrupted states
> on weird syscall sequence (accept(), connect()).
>
> Signed-off-by: Paolo Abeni <pabeni@redhat.com>
> ---
> v1 -> v2:
> - mptcp_is_tcpsk() returns a bool value (Mat)

Thanks for the v2, looks ready for merging.

Mat


> ---
> net/mptcp/protocol.c | 69 +++++---------------------------------------
> 1 file changed, 7 insertions(+), 62 deletions(-)
>
> diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
> index 2c15bf4cdb53..ff160b95cc84 100644
> --- a/net/mptcp/protocol.c
> +++ b/net/mptcp/protocol.c
> @@ -52,13 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
> 	return msk->subflow;
> }
>
> -static struct socket *mptcp_is_tcpsk(struct sock *sk)
> +static bool mptcp_is_tcpsk(struct sock *sk)
> {
> 	struct socket *sock = sk->sk_socket;
>
> -	if (sock->sk != sk)
> -		return NULL;
> -
> 	if (unlikely(sk->sk_prot == &tcp_prot)) {
> 		/* we are being invoked after mptcp_accept() has
> 		 * accepted a non-mp-capable flow: sk is a tcp_sk,
> @@ -68,27 +65,21 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
> 		 * bypass mptcp.
> 		 */
> 		sock->ops = &inet_stream_ops;
> -		return sock;
> +		return true;
> #if IS_ENABLED(CONFIG_MPTCP_IPV6)
> 	} else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
> 		sock->ops = &inet6_stream_ops;
> -		return sock;
> +		return true;
> #endif
> 	}
>
> -	return NULL;
> +	return false;
> }
>
> static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
> {
> -	struct socket *sock;
> -
> 	sock_owned_by_me((const struct sock *)msk);
>
> -	sock = mptcp_is_tcpsk((struct sock *)msk);
> -	if (unlikely(sock))
> -		return sock;
> -
> 	if (likely(!__mptcp_check_fallback(msk)))
> 		return NULL;
>
> @@ -1572,7 +1563,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
> 		return NULL;
>
> 	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
> -
> 	if (sk_is_mptcp(newsk)) {
> 		struct mptcp_subflow_context *subflow;
> 		struct sock *new_mptcp_sock;
> @@ -1930,42 +1920,6 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
> 	return err;
> }
>
> -static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
> -			    int peer)
> -{
> -	if (sock->sk->sk_prot == &tcp_prot) {
> -		/* we are being invoked from __sys_accept4, after
> -		 * mptcp_accept() has just accepted a non-mp-capable
> -		 * flow: sk is a tcp_sk, not an mptcp one.
> -		 *
> -		 * Hand the socket over to tcp so all further socket ops
> -		 * bypass mptcp.
> -		 */
> -		sock->ops = &inet_stream_ops;
> -	}
> -
> -	return inet_getname(sock, uaddr, peer);
> -}
> -
> -#if IS_ENABLED(CONFIG_MPTCP_IPV6)
> -static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
> -			    int peer)
> -{
> -	if (sock->sk->sk_prot == &tcpv6_prot) {
> -		/* we are being invoked from __sys_accept4 after
> -		 * mptcp_accept() has accepted a non-mp-capable
> -		 * subflow: sk is a tcp_sk, not mptcp.
> -		 *
> -		 * Hand the socket over to tcp so all further
> -		 * socket ops bypass mptcp.
> -		 */
> -		sock->ops = &inet6_stream_ops;
> -	}
> -
> -	return inet6_getname(sock, uaddr, peer);
> -}
> -#endif
> -
> static int mptcp_listen(struct socket *sock, int backlog)
> {
> 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
> @@ -1994,15 +1948,6 @@ static int mptcp_listen(struct socket *sock, int backlog)
> 	return err;
> }
>
> -static bool is_tcp_proto(const struct proto *p)
> -{
> -#if IS_ENABLED(CONFIG_MPTCP_IPV6)
> -	return p == &tcp_prot || p == &tcpv6_prot;
> -#else
> -	return p == &tcp_prot;
> -#endif
> -}
> -
> static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
> 			       int flags, bool kern)
> {
> @@ -2025,7 +1970,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
> 	release_sock(sock->sk);
>
> 	err = ssock->ops->accept(sock, newsock, flags, kern);
> -	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
> +	if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
> 		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
> 		struct mptcp_subflow_context *subflow;
>
> @@ -2136,7 +2081,7 @@ static const struct proto_ops mptcp_stream_ops = {
> 	.connect	   = mptcp_stream_connect,
> 	.socketpair	   = sock_no_socketpair,
> 	.accept		   = mptcp_stream_accept,
> -	.getname	   = mptcp_v4_getname,
> +	.getname	   = inet_getname,
> 	.poll		   = mptcp_poll,
> 	.ioctl		   = inet_ioctl,
> 	.gettstamp	   = sock_gettstamp,
> @@ -2196,7 +2141,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
> 	.connect	   = mptcp_stream_connect,
> 	.socketpair	   = sock_no_socketpair,
> 	.accept		   = mptcp_stream_accept,
> -	.getname	   = mptcp_v6_getname,
> +	.getname	   = inet6_getname,
> 	.poll		   = mptcp_poll,
> 	.ioctl		   = inet6_ioctl,
> 	.gettstamp	   = sock_gettstamp,
> -- 
> 2.26.2
> _______________________________________________
> mptcp mailing list -- mptcp@lists.01.org
> To unsubscribe send an email to mptcp-leave@lists.01.org
>

--
Mat Martineau
Intel
diff mbox series

Patch

diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 2c15bf4cdb53..ff160b95cc84 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -52,13 +52,10 @@  static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
 	return msk->subflow;
 }
 
-static struct socket *mptcp_is_tcpsk(struct sock *sk)
+static bool mptcp_is_tcpsk(struct sock *sk)
 {
 	struct socket *sock = sk->sk_socket;
 
-	if (sock->sk != sk)
-		return NULL;
-
 	if (unlikely(sk->sk_prot == &tcp_prot)) {
 		/* we are being invoked after mptcp_accept() has
 		 * accepted a non-mp-capable flow: sk is a tcp_sk,
@@ -68,27 +65,21 @@  static struct socket *mptcp_is_tcpsk(struct sock *sk)
 		 * bypass mptcp.
 		 */
 		sock->ops = &inet_stream_ops;
-		return sock;
+		return true;
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 	} else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
 		sock->ops = &inet6_stream_ops;
-		return sock;
+		return true;
 #endif
 	}
 
-	return NULL;
+	return false;
 }
 
 static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
 {
-	struct socket *sock;
-
 	sock_owned_by_me((const struct sock *)msk);
 
-	sock = mptcp_is_tcpsk((struct sock *)msk);
-	if (unlikely(sock))
-		return sock;
-
 	if (likely(!__mptcp_check_fallback(msk)))
 		return NULL;
 
@@ -1572,7 +1563,6 @@  static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 		return NULL;
 
 	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
-
 	if (sk_is_mptcp(newsk)) {
 		struct mptcp_subflow_context *subflow;
 		struct sock *new_mptcp_sock;
@@ -1930,42 +1920,6 @@  static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
 	return err;
 }
 
-static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
-			    int peer)
-{
-	if (sock->sk->sk_prot == &tcp_prot) {
-		/* we are being invoked from __sys_accept4, after
-		 * mptcp_accept() has just accepted a non-mp-capable
-		 * flow: sk is a tcp_sk, not an mptcp one.
-		 *
-		 * Hand the socket over to tcp so all further socket ops
-		 * bypass mptcp.
-		 */
-		sock->ops = &inet_stream_ops;
-	}
-
-	return inet_getname(sock, uaddr, peer);
-}
-
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
-			    int peer)
-{
-	if (sock->sk->sk_prot == &tcpv6_prot) {
-		/* we are being invoked from __sys_accept4 after
-		 * mptcp_accept() has accepted a non-mp-capable
-		 * subflow: sk is a tcp_sk, not mptcp.
-		 *
-		 * Hand the socket over to tcp so all further
-		 * socket ops bypass mptcp.
-		 */
-		sock->ops = &inet6_stream_ops;
-	}
-
-	return inet6_getname(sock, uaddr, peer);
-}
-#endif
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -1994,15 +1948,6 @@  static int mptcp_listen(struct socket *sock, int backlog)
 	return err;
 }
 
-static bool is_tcp_proto(const struct proto *p)
-{
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-	return p == &tcp_prot || p == &tcpv6_prot;
-#else
-	return p == &tcp_prot;
-#endif
-}
-
 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 			       int flags, bool kern)
 {
@@ -2025,7 +1970,7 @@  static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 	release_sock(sock->sk);
 
 	err = ssock->ops->accept(sock, newsock, flags, kern);
-	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
+	if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
 		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
 		struct mptcp_subflow_context *subflow;
 
@@ -2136,7 +2081,7 @@  static const struct proto_ops mptcp_stream_ops = {
 	.connect	   = mptcp_stream_connect,
 	.socketpair	   = sock_no_socketpair,
 	.accept		   = mptcp_stream_accept,
-	.getname	   = mptcp_v4_getname,
+	.getname	   = inet_getname,
 	.poll		   = mptcp_poll,
 	.ioctl		   = inet_ioctl,
 	.gettstamp	   = sock_gettstamp,
@@ -2196,7 +2141,7 @@  static const struct proto_ops mptcp_v6_stream_ops = {
 	.connect	   = mptcp_stream_connect,
 	.socketpair	   = sock_no_socketpair,
 	.accept		   = mptcp_stream_accept,
-	.getname	   = mptcp_v6_getname,
+	.getname	   = inet6_getname,
 	.poll		   = mptcp_poll,
 	.ioctl		   = inet6_ioctl,
 	.gettstamp	   = sock_gettstamp,