@@ -98,8 +98,9 @@ static struct socket *__mptcp_tcp_fallba
return NULL;
if (msk->subflow) {
+ sock = msk->subflow;
release_sock((struct sock *)msk);
- return msk->subflow;
+ return sock;
}
return NULL;
@@ -1843,17 +1844,19 @@ static __poll_t mptcp_poll(struct file *
{
struct sock *sk = sock->sk;
struct mptcp_sock *msk;
- struct socket *ssock;
+ struct socket *ssock, *subflow;
__poll_t mask = 0;
msk = mptcp_sk(sk);
lock_sock(sk);
+ subflow = msk->subflow;
ssock = __mptcp_tcp_fallback(msk);
if (!ssock)
ssock = __mptcp_nmpc_socket(msk);
if (ssock) {
mask = ssock->ops->poll(file, ssock, wait);
- release_sock(sk);
+ if (ssock != subflow)
+ release_sock(sk);
return mask;
}
@@ -1878,15 +1881,17 @@ static int mptcp_shutdown(struct socket
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
struct mptcp_subflow_context *subflow;
- struct socket *ssock;
+ struct socket *ssock, *subf_sock;
int ret = 0;
pr_debug("sk=%p, how=%d", msk, how);
lock_sock(sock->sk);
+ subf_sock = msk->subflow;
ssock = __mptcp_tcp_fallback(msk);
if (ssock) {
- release_sock(sock->sk);
+ if (ssock != subf_sock)
+ release_sock(sock->sk);
return inet_shutdown(ssock, how);
}