@@ -24,6 +24,58 @@
#define MPTCP_SAME_STATE TCP_MAX_STATES
+static void __mptcp_close(struct sock *sk, long timeout);
+
+static const struct proto_ops *tcp_proto_ops(struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return &inet6_stream_ops;
+#endif
+ return &inet_stream_ops;
+}
+
+/* The MP_CAPABLE handshake failed: convert msk to plain TCP, replacing
+ * socket->sk and the stream ops and destroying msk.
+ * Returns the msk's struct socket, as we can't access msk anymore after
+ * this function completes.
+ * Called with the msk lock held; releases that lock before returning.
+ */
+static struct socket *__mptcp_fallback_to_tcp(struct mptcp_sock *msk,
+ struct sock *ssk)
+{
+ struct mptcp_subflow_context *subflow;
+ struct socket *sock;
+ struct sock *sk;
+
+ sk = (struct sock *)msk;
+ sock = sk->sk_socket;
+ subflow = mptcp_subflow_ctx(ssk);
+
+ /* detach the msk socket */
+ list_del_init(&subflow->node);
+ sock_orphan(sk);
+ sock->sk = NULL;
+
+ /* socket is now TCP */
+ lock_sock(ssk);
+ sock_graft(ssk, sock);
+ if (subflow->conn) {
+ /* We can't release the ULP data on a live socket:
+ * just restore the tcp callback
+ */
+ mptcp_subflow_tcp_fallback(ssk, subflow);
+ sock_put(subflow->conn);
+ subflow->conn = NULL;
+ }
+ release_sock(ssk);
+ sock->ops = tcp_proto_ops(ssk);
+
+ /* destroy the left-over msk sock */
+ __mptcp_close(sk, 0);
+ return sock;
+}
+
/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
* completed yet or has failed, return the subflow socket.
* Otherwise return NULL.
@@ -36,25 +88,37 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
return msk->subflow;
}
-/* if msk has a single subflow, and the mp_capable handshake is failed,
- * return it.
+static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk)
+{
+ return msk->first && !sk_is_mptcp(msk->first);
+}
+
+/* if the MP_CAPABLE handshake has failed, converts msk to plain TCP,
+ * releases the socket lock and returns a reference to the now-TCP socket.
* Otherwise returns NULL
*/
-static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk)
+static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
{
- struct socket *ssock = __mptcp_nmpc_socket(msk);
-
sock_owned_by_me((const struct sock *)msk);
- if (!ssock || sk_is_mptcp(ssock->sk))
+ if (likely(!__mptcp_needs_tcp_fallback(msk)))
return NULL;
- return ssock;
+ if (msk->subflow) {
+ /* the first subflow is an active connection: keep its sk alive
+ * and discard only the paired struct socket
+ */
+ msk->subflow->sk = NULL;
+ sock_release(msk->subflow);
+ msk->subflow = NULL;
+ }
+
+ return __mptcp_fallback_to_tcp(msk, msk->first);
}
static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
{
- return ((struct sock *)msk)->sk_state == TCP_CLOSE;
+ return !msk->first;
}
static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
@@ -75,6 +139,7 @@ static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
if (err)
return ERR_PTR(err);
+ msk->first = ssock->sk;
msk->subflow = ssock;
subflow = mptcp_subflow_ctx(ssock->sk);
list_add(&subflow->node, &msk->conn_list);
@@ -154,6 +219,8 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
ret = sk_stream_wait_memory(ssk, timeo);
if (ret)
return ret;
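+ /* the MP_CAPABLE handshake failed while we waited for memory:
+ * return 0 and let the caller perform the TCP fallback
+ */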
+ if (unlikely(__mptcp_needs_tcp_fallback(msk)))
+ return 0;
}
/* compute copy limit */
@@ -265,11 +332,11 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
lock_sock(sk);
ssock = __mptcp_tcp_fallback(msk);
- if (ssock) {
+ if (unlikely(ssock)) {
+fallback:
pr_debug("fallback passthrough");
ret = sock_sendmsg(ssock, msg);
- release_sock(sk);
- return ret;
+ return ret >= 0 ? ret + copied : (copied ? copied : ret);
}
timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
@@ -288,6 +355,11 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
&size_goal);
if (ret < 0)
break;
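+ /* mptcp_sendmsg_frag() returned 0 because the MP_CAPABLE
+ * handshake failed: convert the msk to plain TCP and send
+ * the remaining data through it
+ */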
+ if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
+ release_sock(ssk);
+ ssock = __mptcp_tcp_fallback(msk);
+ goto fallback;
+ }
copied += ret;
}
@@ -368,11 +440,11 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
lock_sock(sk);
ssock = __mptcp_tcp_fallback(msk);
- if (ssock) {
+ if (unlikely(ssock)) {
+fallback:
pr_debug("fallback-read subflow=%p",
mptcp_subflow_ctx(ssock->sk));
copied = sock_recvmsg(ssock, msg, flags);
- release_sock(sk);
return copied;
}
@@ -477,6 +549,8 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
pr_debug("block timeout %ld", timeo);
wait_data = true;
mptcp_wait_data(sk, &timeo);
+ ssock = __mptcp_tcp_fallback(msk);
+ if (unlikely(ssock))
+ goto fallback;
}
if (more_data_avail) {
@@ -529,6 +603,8 @@ static int __mptcp_init_sock(struct sock *sk)
INIT_LIST_HEAD(&msk->conn_list);
__set_bit(MPTCP_SEND_SPACE, &msk->flags);
+ msk->first = NULL;
+
return 0;
}
@@ -563,7 +639,8 @@ static void mptcp_subflow_shutdown(struct sock *ssk, int how)
release_sock(ssk);
}
-static void mptcp_close(struct sock *sk, long timeout)
+/* Called with the msk lock held; releases that lock before returning */
+static void __mptcp_close(struct sock *sk, long timeout)
{
struct mptcp_subflow_context *subflow, *tmp;
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -571,8 +648,6 @@ static void mptcp_close(struct sock *sk, long timeout)
mptcp_token_destroy(msk->token);
inet_sk_state_store(sk, TCP_CLOSE);
- lock_sock(sk);
-
list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
@@ -585,6 +660,12 @@ static void mptcp_close(struct sock *sk, long timeout)
sk_common_release(sk);
}
+static void mptcp_close(struct sock *sk, long timeout)
+{
+ lock_sock(sk);
+ __mptcp_close(sk, timeout);
+}
+
static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
@@ -654,6 +735,7 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
msk->local_key = subflow->local_key;
msk->token = subflow->token;
msk->subflow = NULL;
+ msk->first = newsk;
mptcp_token_update_accept(newsk, new_mptcp_sock);
@@ -1007,8 +1089,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
static __poll_t mptcp_poll(struct file *file, struct socket *sock,
struct poll_table_struct *wait)
{
- const struct mptcp_sock *msk;
struct sock *sk = sock->sk;
+ struct mptcp_sock *msk;
struct socket *ssock;
__poll_t mask = 0;
@@ -1024,6 +1106,9 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
release_sock(sk);
sock_poll_wait(file, sock, wait);
lock_sock(sk);
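+ /* if the MP_CAPABLE handshake failed, __mptcp_tcp_fallback() completes
+ * the fallback, releasing the msk lock: poll the plain TCP socket instead
+ */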
+ ssock = __mptcp_tcp_fallback(msk);
+ if (unlikely(ssock))
+ return ssock->ops->poll(file, ssock, NULL);
if (test_bit(MPTCP_DATA_READY, &msk->flags))
mask = EPOLLIN | EPOLLRDNORM;
@@ -73,6 +73,7 @@ struct mptcp_sock {
struct list_head conn_list;
struct skb_ext *cached_ext; /* for the next sendmsg */
struct socket *subflow; /* outgoing connect/listener/!mp_capable */
+ struct sock *first; /* sk of the first subflow */
};
#define mptcp_for_each_subflow(__msk, __subflow) \