@@ -178,6 +178,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
struct sock *ssk;
long timeo;
+ pr_debug("msk=%p", msk);
lock_sock(sk);
ssock = __mptcp_fallback_get_ref(msk);
if (ssock) {
@@ -846,38 +847,72 @@ static struct proto mptcp_prot = {
.no_autobind = 1,
};
+static struct socket *mptcp_socket_create_get(struct mptcp_sock *msk)
+{
+ struct mptcp_subflow_context *subflow;
+ struct sock *sk = (struct sock *)msk;
+ struct socket *ssock;
+ int err;
+
+ lock_sock(sk);
+ ssock = __mptcp_fallback_get_ref(msk);
+ if (ssock)
+ goto release;
+
+ err = mptcp_subflow_create_socket(sk, &ssock);
+ if (err) {
+ ssock = ERR_PTR(err);
+ goto release;
+ }
+
+ msk->subflow = ssock;
+ subflow = mptcp_subflow_ctx(msk->subflow->sk);
+ subflow->request_mptcp = 1; /* @@ if MPTCP enabled */
+ subflow->request_cksum = 0; /* checksum not supported */
+ subflow->request_version = 0; /* only v0 supported */
+
+ sock_hold(ssock->sk);
+
+release:
+ release_sock(sk);
+ return ssock;
+}
+
static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
+ struct socket *ssock;
int err = -ENOTSUPP;
if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now
return err;
- if (!msk->subflow) {
- err = mptcp_subflow_create_socket(sock->sk, &msk->subflow);
- if (err)
- return err;
- }
- return inet_bind(msk->subflow, uaddr, addr_len);
+ ssock = mptcp_socket_create_get(msk);
+ if (IS_ERR(ssock))
+ return PTR_ERR(ssock);
+
+ err = inet_bind(ssock, uaddr, addr_len);
+ sock_put(ssock->sk);
+ return err;
}
static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
int addr_len, int flags)
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
+ struct socket *ssock;
int err = -ENOTSUPP;
if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now
return err;
- if (!msk->subflow) {
- err = mptcp_subflow_create_socket(sock->sk, &msk->subflow);
- if (err)
- return err;
- }
+ ssock = mptcp_socket_create_get(msk);
+ if (IS_ERR(ssock))
+ return PTR_ERR(ssock);
- return inet_stream_connect(msk->subflow, uaddr, addr_len, flags);
+ err = inet_stream_connect(ssock, uaddr, addr_len, flags);
+ sock_put(ssock->sk);
+ return err;
}
static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr,
@@ -929,29 +964,36 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr,
static int mptcp_listen(struct socket *sock, int backlog)
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
+ struct socket *ssock;
int err;
pr_debug("msk=%p", msk);
- if (!msk->subflow) {
- err = mptcp_subflow_create_socket(sock->sk, &msk->subflow);
- if (err)
- return err;
- }
- return inet_listen(msk->subflow, backlog);
+ ssock = mptcp_socket_create_get(msk);
+ if (IS_ERR(ssock))
+ return PTR_ERR(ssock);
+
+ err = inet_listen(ssock, backlog);
+ sock_put(ssock->sk);
+ return err;
}
static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
int flags, bool kern)
{
struct mptcp_sock *msk = mptcp_sk(sock->sk);
+ struct socket *ssock;
+ int err;
pr_debug("msk=%p", msk);
- if (!msk->subflow)
+ ssock = mptcp_fallback_get_ref(msk);
+ if (!ssock)
return -EINVAL;
- return inet_accept(sock, newsock, flags, kern);
+ err = inet_accept(sock, newsock, flags, kern);
+ sock_put(ssock->sk);
+ return err;
}
static __poll_t mptcp_poll(struct file *file, struct socket *sock,
@@ -293,9 +293,6 @@ int mptcp_subflow_create_socket(struct sock *sk, struct socket **new_sock)
*new_sock = sf;
sock_hold(sk);
subflow->conn = sk;
- subflow->request_mptcp = 1; // @@ if MPTCP enabled
- subflow->request_cksum = 1; // @@ if checksum enabled
- subflow->request_version = 0;
return 0;
}