@@ -136,51 +136,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net)
}
-/* Lookup the tunnel socket, possibly involving the fs code if the socket is
- * owned by userspace. A struct sock returned from this function must be
- * released using l2tp_tunnel_sock_put once you're done with it.
- */
-static struct sock *l2tp_tunnel_sock_lookup(struct l2tp_tunnel *tunnel)
-{
- int err = 0;
- struct socket *sock = NULL;
- struct sock *sk = NULL;
-
- if (!tunnel)
- goto out;
-
- if (tunnel->fd >= 0) {
- /* Socket is owned by userspace, who might be in the process
- * of closing it. Look the socket up using the fd to ensure
- * consistency.
- */
- sock = sockfd_lookup(tunnel->fd, &err);
- if (sock)
- sk = sock->sk;
- } else {
- /* Socket is owned by kernelspace */
- sk = tunnel->sock;
- sock_hold(sk);
- }
-
-out:
- return sk;
-}
-
-/* Drop a reference to a tunnel socket obtained via. l2tp_tunnel_sock_put */
-static void l2tp_tunnel_sock_put(struct sock *sk)
-{
- struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk);
- if (tunnel) {
- if (tunnel->fd >= 0) {
- /* Socket is owned by userspace */
- sockfd_put(sk->sk_socket);
- }
- sock_put(sk);
- }
- sock_put(sk);
-}
-
/* Session hash list.
* The session_id SHOULD be random according to RFC2661, but several
* L2TP implementations (Cisco and Microsoft) use incrementing
@@ -193,6 +148,12 @@ static void l2tp_tunnel_sock_put(struct sock *sk)
return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)];
}
+void l2tp_tunnel_free(struct l2tp_tunnel *tunnel)
+{
+ sock_put(tunnel->sock);
+ /* the tunnel is freed in the socket destructor */
+}
+
/* Lookup a tunnel. A new reference is held on the returned tunnel. */
struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
{
@@ -202,6 +163,13 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
rcu_read_lock_bh();
list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
if (tunnel->tunnel_id == tunnel_id) {
+ spin_lock_bh(&tunnel->lock);
+ if (tunnel->closing) {
+ spin_unlock_bh(&tunnel->lock);
+ rcu_read_unlock_bh();
+ return NULL;
+ }
+ spin_unlock_bh(&tunnel->lock);
l2tp_tunnel_inc_refcount(tunnel);
rcu_read_unlock_bh();
@@ -230,7 +198,14 @@ struct l2tp_session *l2tp_session_get(const struct net *net,
rcu_read_lock_bh();
hlist_for_each_entry_rcu(session, session_list, global_hlist) {
if (session->session_id == session_id) {
+ spin_lock_bh(&session->lock);
+ if (session->closing) {
+ spin_unlock_bh(&session->lock);
+ rcu_read_unlock_bh();
+ return NULL;
+ }
l2tp_session_inc_refcount(session);
+ spin_unlock_bh(&session->lock);
rcu_read_unlock_bh();
return session;
@@ -245,7 +220,14 @@ struct l2tp_session *l2tp_session_get(const struct net *net,
read_lock_bh(&tunnel->hlist_lock);
hlist_for_each_entry(session, session_list, hlist) {
if (session->session_id == session_id) {
+ spin_lock_bh(&session->lock);
+ if (session->closing) {
+ spin_unlock_bh(&session->lock);
+ read_unlock_bh(&tunnel->hlist_lock);
+ return NULL;
+ }
l2tp_session_inc_refcount(session);
+ spin_unlock_bh(&session->lock);
read_unlock_bh(&tunnel->hlist_lock);
return session;
@@ -266,6 +248,12 @@ struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth)
read_lock_bh(&tunnel->hlist_lock);
for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
hlist_for_each_entry(session, &tunnel->session_hlist[hash], hlist) {
+ spin_lock_bh(&session->lock);
+ if (session->closing) {
+ spin_unlock_bh(&session->lock);
+ continue;
+ }
+ spin_unlock_bh(&session->lock);
if (++count > nth) {
l2tp_session_inc_refcount(session);
read_unlock_bh(&tunnel->hlist_lock);
@@ -293,6 +281,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
rcu_read_lock_bh();
for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
+ spin_lock_bh(&session->lock);
+ if (session->closing) {
+ spin_unlock_bh(&session->lock);
+ continue;
+ }
+ spin_unlock_bh(&session->lock);
if (!strcmp(session->ifname, ifname)) {
l2tp_session_inc_refcount(session);
rcu_read_unlock_bh();
@@ -317,13 +311,17 @@ int l2tp_session_register(struct l2tp_session *session,
struct l2tp_net *pn;
int err;
+ spin_lock_bh(&tunnel->lock);
+ if (tunnel->closing) {
+ spin_unlock_bh(&tunnel->lock);
+ return -ENODEV;
+ }
+ l2tp_tunnel_inc_refcount(tunnel);
+ spin_unlock_bh(&tunnel->lock);
+
head = l2tp_session_id_hash(tunnel, session->session_id);
write_lock_bh(&tunnel->hlist_lock);
- if (!tunnel->acpt_newsess) {
- err = -ENODEV;
- goto err_tlock;
- }
hlist_for_each_entry(session_walk, head, hlist)
if (session_walk->session_id == session->session_id) {
@@ -344,14 +342,9 @@ int l2tp_session_register(struct l2tp_session *session,
goto err_tlock_pnlock;
}
- l2tp_tunnel_inc_refcount(tunnel);
- sock_hold(tunnel->sock);
hlist_add_head_rcu(&session->global_hlist, g_head);
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
- } else {
- l2tp_tunnel_inc_refcount(tunnel);
- sock_hold(tunnel->sock);
}
hlist_add_head(&session->hlist, head);
@@ -363,6 +356,7 @@ int l2tp_session_register(struct l2tp_session *session,
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
err_tlock:
write_unlock_bh(&tunnel->hlist_lock);
+ l2tp_tunnel_dec_refcount(tunnel);
return err;
}
@@ -969,7 +963,7 @@ int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb)
{
struct l2tp_tunnel *tunnel;
- tunnel = l2tp_sock_to_tunnel(sk);
+ tunnel = l2tp_tunnel(sk);
if (tunnel == NULL)
goto pass_up;
@@ -977,13 +971,10 @@ int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb)
tunnel->name, skb->len);
if (l2tp_udp_recv_core(tunnel, skb, tunnel->recv_payload_hook))
- goto pass_up_put;
+ goto pass_up;
- sock_put(sk);
return 0;
-pass_up_put:
- sock_put(sk);
pass_up:
return 1;
}
@@ -1214,8 +1205,8 @@ static void l2tp_tunnel_destruct(struct sock *sk)
l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: closing...\n", tunnel->name);
-
/* Disable udp encapsulation */
+ write_lock_bh(&sk->sk_callback_lock);
switch (tunnel->encap) {
case L2TP_ENCAPTYPE_UDP:
/* No longer an encapsulation socket. See net/ipv4/udp.c */
@@ -1229,7 +1220,8 @@ static void l2tp_tunnel_destruct(struct sock *sk)
/* Remove hooks into tunnel socket */
sk->sk_destruct = tunnel->old_sk_destruct;
- sk->sk_user_data = NULL;
+ rcu_assign_sk_user_data(sk, NULL);
+ write_unlock_bh(&sk->sk_callback_lock);
/* Remove the tunnel struct from the tunnel list */
pn = l2tp_pernet(tunnel->l2tp_net);
@@ -1237,12 +1229,11 @@ static void l2tp_tunnel_destruct(struct sock *sk)
list_del_rcu(&tunnel->list);
spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
- tunnel->sock = NULL;
- l2tp_tunnel_dec_refcount(tunnel);
-
/* Call the original destructor */
if (sk->sk_destruct)
(*sk->sk_destruct)(sk);
+
+ kfree_rcu(tunnel, rcu);
end:
return;
}
@@ -1262,38 +1253,10 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel)
tunnel->name);
write_lock_bh(&tunnel->hlist_lock);
- tunnel->acpt_newsess = false;
for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
-again:
hlist_for_each_safe(walk, tmp, &tunnel->session_hlist[hash]) {
session = hlist_entry(walk, struct l2tp_session, hlist);
-
- l2tp_info(session, L2TP_MSG_CONTROL,
- "%s: closing session\n", session->name);
-
- hlist_del_init(&session->hlist);
-
- if (test_and_set_bit(0, &session->dead))
- goto again;
-
- write_unlock_bh(&tunnel->hlist_lock);
-
- __l2tp_session_unhash(session);
- l2tp_session_queue_purge(session);
-
- if (session->session_close != NULL)
- (*session->session_close)(session);
-
- l2tp_session_dec_refcount(session);
-
- write_lock_bh(&tunnel->hlist_lock);
-
- /* Now restart from the beginning of this hash
- * chain. We always remove a session from the
- * list so we are guaranteed to make forward
- * progress.
- */
- goto again;
+ l2tp_session_delete(session);
}
}
write_unlock_bh(&tunnel->hlist_lock);
@@ -1303,30 +1266,21 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel)
/* Tunnel socket destroy hook for UDP encapsulation */
static void l2tp_udp_encap_destroy(struct sock *sk)
{
- struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk);
+ struct l2tp_tunnel *tunnel = l2tp_tunnel(sk);
if (tunnel) {
- l2tp_tunnel_closeall(tunnel);
- sock_put(sk);
+ l2tp_tunnel_delete(tunnel);
}
}
/* Workqueue tunnel deletion function */
static void l2tp_tunnel_del_work(struct work_struct *work)
{
- struct l2tp_tunnel *tunnel = NULL;
- struct socket *sock = NULL;
- struct sock *sk = NULL;
-
- tunnel = container_of(work, struct l2tp_tunnel, del_work);
+ struct l2tp_tunnel *tunnel = container_of(work, struct l2tp_tunnel, del_work);
+ struct sock *sk = tunnel->sock;
+ struct socket *sock = sk->sk_socket;
l2tp_tunnel_closeall(tunnel);
- sk = l2tp_tunnel_sock_lookup(tunnel);
- if (!sk)
- goto out;
-
- sock = sk->sk_socket;
-
/* If the tunnel socket was created by userspace, then go through the
* inet layer to shut the socket down, and let userspace close it.
* Otherwise, if we created the socket directly within the kernel, use
@@ -1335,7 +1289,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
* destructor when the tunnel socket goes away.
*/
if (tunnel->fd >= 0) {
- if (sock)
+ if (sock && sock->sk)
inet_shutdown(sock, 2);
} else {
if (sock) {
@@ -1344,8 +1298,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
}
}
- l2tp_tunnel_sock_put(sk);
-out:
+ /* drop initial ref */
+ l2tp_tunnel_dec_refcount(tunnel);
+
+ /* drop workqueue ref */
l2tp_tunnel_dec_refcount(tunnel);
}
@@ -1495,8 +1451,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
} else {
sock = sockfd_lookup(fd, &err);
if (!sock) {
- pr_err("tunl %u: sockfd_lookup(fd=%d) returned %d\n",
- tunnel_id, fd, err);
err = -EBADF;
goto err;
}
@@ -1534,14 +1488,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
break;
}
- /* Check if this socket has already been prepped */
- tunnel = l2tp_tunnel(sk);
- if (tunnel != NULL) {
- /* This socket has already been prepped */
- err = -EBUSY;
- goto err;
- }
-
tunnel = kzalloc(sizeof(struct l2tp_tunnel), GFP_KERNEL);
if (tunnel == NULL) {
err = -ENOMEM;
@@ -1555,8 +1501,8 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
tunnel->magic = L2TP_TUNNEL_MAGIC;
sprintf(&tunnel->name[0], "tunl %u", tunnel_id);
+ spin_lock_init(&tunnel->lock);
rwlock_init(&tunnel->hlist_lock);
- tunnel->acpt_newsess = true;
/* The net we belong to */
tunnel->l2tp_net = net;
@@ -1583,6 +1529,20 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
}
#endif
+ /* Assign socket sk_user_data. Must be done with
+ * sk_callback_lock. Bail if sk_user_data is already assigned.
+ */
+ write_lock_bh(&sk->sk_callback_lock);
+ if (sk->sk_user_data) {
+ err = -EALREADY;
+ write_unlock_bh(&sk->sk_callback_lock);
+ kfree(tunnel);
+ tunnel = NULL;
+ goto err;
+ }
+ rcu_assign_sk_user_data(sk, tunnel);
+ write_unlock_bh(&sk->sk_callback_lock);
+
/* Mark socket as an encapsulation socket. See net/ipv4/udp.c */
tunnel->encap = encap;
if (encap == L2TP_ENCAPTYPE_UDP) {
@@ -1594,8 +1554,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
udp_cfg.encap_destroy = l2tp_udp_encap_destroy;
setup_udp_tunnel_sock(net, sock, &udp_cfg);
- } else {
- sk->sk_user_data = tunnel;
}
/* Hook on the tunnel socket destructor so that we can cleanup
@@ -1603,6 +1561,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
*/
tunnel->old_sk_destruct = sk->sk_destruct;
sk->sk_destruct = &l2tp_tunnel_destruct;
+
tunnel->sock = sk;
tunnel->fd = fd;
lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class, "l2tp_sock");
@@ -1616,9 +1575,12 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
INIT_LIST_HEAD(&tunnel->list);
/* Bump the reference count. The tunnel context is deleted
- * only when this drops to zero. Must be done before list insertion
+ * only when this drops to zero. A reference is also held on
+ * the tunnel socket to ensure that it is not released while
+ * the tunnel is extant. Must be done before list insertion
*/
refcount_set(&tunnel->ref_count, 1);
+ sock_hold(sk);
spin_lock_bh(&pn->l2tp_tunnel_list_lock);
list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
@@ -1642,10 +1604,17 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
*/
void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel)
{
- if (!test_and_set_bit(0, &tunnel->dead)) {
- l2tp_tunnel_inc_refcount(tunnel);
- queue_work(l2tp_wq, &tunnel->del_work);
+ spin_lock_bh(&tunnel->lock);
+ if (tunnel->closing) {
+ spin_unlock_bh(&tunnel->lock);
+ return;
}
+ tunnel->closing = true;
+ spin_unlock_bh(&tunnel->lock);
+
+ /* Hold tunnel ref while queued work item is pending */
+ l2tp_tunnel_inc_refcount(tunnel);
+ queue_work(l2tp_wq, &tunnel->del_work);
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_delete);
@@ -1657,14 +1626,15 @@ void l2tp_session_free(struct l2tp_session *session)
BUG_ON(refcount_read(&session->ref_count) != 0);
+ if (session->session_free)
+ session->session_free(session);
+ else
+ kfree(session);
+
if (tunnel) {
BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC);
- sock_put(tunnel->sock);
- session->tunnel = NULL;
l2tp_tunnel_dec_refcount(tunnel);
}
-
- kfree(session);
}
EXPORT_SYMBOL_GPL(l2tp_session_free);
@@ -1673,7 +1643,7 @@ void l2tp_session_free(struct l2tp_session *session)
* shutdown via. l2tp_session_delete and a pseudowire-specific session_close
* callback.
*/
-void __l2tp_session_unhash(struct l2tp_session *session)
+static void l2tp_session_unhash(struct l2tp_session *session)
{
struct l2tp_tunnel *tunnel = session->tunnel;
@@ -1694,23 +1664,43 @@ void __l2tp_session_unhash(struct l2tp_session *session)
}
}
}
-EXPORT_SYMBOL_GPL(__l2tp_session_unhash);
-/* This function is used by the netlink SESSION_DELETE command and by
- pseudowire modules.
- */
-int l2tp_session_delete(struct l2tp_session *session)
+/* Workqueue session deletion function */
+static void l2tp_session_del_work(struct work_struct *work)
{
- if (test_and_set_bit(0, &session->dead))
- return 0;
+ struct l2tp_session *session = container_of(work, struct l2tp_session, del_work);
- __l2tp_session_unhash(session);
+ l2tp_info(session, L2TP_MSG_CONTROL,
+ "%s: closing session\n", session->name);
+
+ l2tp_session_unhash(session);
l2tp_session_queue_purge(session);
if (session->session_close != NULL)
(*session->session_close)(session);
+ /* drop initial ref */
+ l2tp_session_dec_refcount(session);
+
+ /* drop workqueue ref */
l2tp_session_dec_refcount(session);
+}
+
+/* This function is used by the netlink SESSION_DELETE command and by
+ pseudowire modules.
+ */
+int l2tp_session_delete(struct l2tp_session *session)
+{
+ spin_lock_bh(&session->lock);
+ if (session->closing) {
+ spin_unlock_bh(&session->lock);
+ return 0;
+ }
+ session->closing = true;
+ spin_unlock_bh(&session->lock);
+ /* Hold session ref while queued work item is pending */
+ l2tp_session_inc_refcount(session);
+ queue_work(l2tp_wq, &session->del_work);
return 0;
}
EXPORT_SYMBOL_GPL(l2tp_session_delete);
@@ -1738,6 +1728,13 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
{
struct l2tp_session *session;
+ spin_lock_bh(&tunnel->lock);
+ if (tunnel->closing) {
+ spin_unlock_bh(&tunnel->lock);
+ return ERR_PTR(-ENODEV);
+ }
+ spin_unlock_bh(&tunnel->lock);
+
session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
if (session != NULL) {
session->magic = L2TP_SESSION_MAGIC;
@@ -1763,6 +1760,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
INIT_HLIST_NODE(&session->hlist);
INIT_HLIST_NODE(&session->global_hlist);
+ spin_lock_init(&session->lock);
+
+ INIT_WORK(&session->del_work, l2tp_session_del_work);
/* Inherit debug options from tunnel */
session->debug = tunnel->debug;
@@ -74,7 +74,8 @@ struct l2tp_session_cfg {
struct l2tp_session {
int magic; /* should be
* L2TP_SESSION_MAGIC */
- long dead;
+ bool closing;
+ spinlock_t lock;
struct l2tp_tunnel *tunnel; /* back pointer to tunnel
* context */
@@ -121,9 +122,12 @@ struct l2tp_session {
struct l2tp_stats stats;
struct hlist_node global_hlist; /* Global hash list node */
+ struct work_struct del_work;
+
int (*build_header)(struct l2tp_session *session, void *buf);
void (*recv_skb)(struct l2tp_session *session, struct sk_buff *skb, int data_len);
void (*session_close)(struct l2tp_session *session);
+ void (*session_free)(struct l2tp_session *session);
#if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
void (*show)(struct seq_file *m, void *priv);
#endif
@@ -155,14 +159,11 @@ struct l2tp_tunnel_cfg {
struct l2tp_tunnel {
int magic; /* Should be L2TP_TUNNEL_MAGIC */
- unsigned long dead;
+ bool closing;
+ spinlock_t lock;
struct rcu_head rcu;
rwlock_t hlist_lock; /* protect session_hlist */
- bool acpt_newsess; /* Indicates whether this
- * tunnel accepts new sessions.
- * Protected by hlist_lock.
- */
struct hlist_head session_hlist[L2TP_HASH_SIZE];
/* hashed list of sessions,
* hashed by id */
@@ -214,27 +215,8 @@ static inline void *l2tp_session_priv(struct l2tp_session *session)
return &session->priv[0];
}
-static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
-{
- struct l2tp_tunnel *tunnel;
-
- if (sk == NULL)
- return NULL;
-
- sock_hold(sk);
- tunnel = (struct l2tp_tunnel *)(sk->sk_user_data);
- if (tunnel == NULL) {
- sock_put(sk);
- goto out;
- }
-
- BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC);
-
-out:
- return tunnel;
-}
-
struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id);
+void l2tp_tunnel_free(struct l2tp_tunnel *tunnel);
struct l2tp_session *l2tp_session_get(const struct net *net,
struct l2tp_tunnel *tunnel,
@@ -257,7 +239,6 @@ struct l2tp_session *l2tp_session_create(int priv_size,
int l2tp_session_register(struct l2tp_session *session,
struct l2tp_tunnel *tunnel);
-void __l2tp_session_unhash(struct l2tp_session *session);
int l2tp_session_delete(struct l2tp_session *session);
void l2tp_session_free(struct l2tp_session *session);
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
@@ -283,7 +264,7 @@ static inline void l2tp_tunnel_inc_refcount(struct l2tp_tunnel *tunnel)
static inline void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel)
{
if (refcount_dec_and_test(&tunnel->ref_count))
- kfree_rcu(tunnel, rcu);
+ l2tp_tunnel_free(tunnel);
}
/* Session reference counts. Incremented when code obtains a reference
@@ -234,17 +234,17 @@ static void l2tp_ip_close(struct sock *sk, long timeout)
static void l2tp_ip_destroy_sock(struct sock *sk)
{
struct sk_buff *skb;
- struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk);
+ struct l2tp_tunnel *tunnel;
while ((skb = __skb_dequeue_tail(&sk->sk_write_queue)) != NULL)
kfree_skb(skb);
+ rcu_read_lock();
+ tunnel = rcu_dereference_sk_user_data(sk);
if (tunnel) {
- l2tp_tunnel_closeall(tunnel);
- sock_put(sk);
+ l2tp_tunnel_delete(tunnel);
}
-
- sk_refcnt_debug_dec(sk);
+ rcu_read_unlock();
}
static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len)
@@ -248,16 +248,18 @@ static void l2tp_ip6_close(struct sock *sk, long timeout)
static void l2tp_ip6_destroy_sock(struct sock *sk)
{
- struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk);
+ struct l2tp_tunnel *tunnel;
lock_sock(sk);
ip6_flush_pending_frames(sk);
release_sock(sk);
+ rcu_read_lock();
+ tunnel = rcu_dereference_sk_user_data(sk);
if (tunnel) {
- l2tp_tunnel_closeall(tunnel);
- sock_put(sk);
+ l2tp_tunnel_delete(tunnel);
}
+ rcu_read_unlock();
inet6_destroy_sock(sk);
}
@@ -166,16 +166,25 @@ static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
if (sk == NULL)
return NULL;
- sock_hold(sk);
- session = (struct l2tp_session *)(sk->sk_user_data);
+ rcu_read_lock_bh();
+ session = rcu_dereference_bh(__sk_user_data((sk)));
if (session == NULL) {
- sock_put(sk);
- goto out;
+ rcu_read_unlock_bh();
+ return NULL;
}
+ spin_lock_bh(&session->lock);
+ if (session->closing) {
+ spin_unlock_bh(&session->lock);
+ rcu_read_unlock_bh();
+ return NULL;
+ }
+ l2tp_session_inc_refcount(session);
+ spin_unlock_bh(&session->lock);
+ rcu_read_unlock_bh();
+
BUG_ON(session->magic != L2TP_SESSION_MAGIC);
-out:
return session;
}
@@ -243,8 +252,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
/* If the socket is bound, send it in to PPP's input queue. Otherwise
* queue it on the session socket.
*/
- rcu_read_lock();
- sk = rcu_dereference(ps->sk);
+ rcu_read_lock_bh();
+ sk = rcu_dereference_bh(ps->sk);
if (sk == NULL)
goto no_sock;
@@ -267,12 +276,12 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
kfree_skb(skb);
}
}
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return;
no_sock:
- rcu_read_unlock();
+ rcu_read_unlock_bh();
l2tp_info(session, L2TP_MSG_DATA, "%s: no socket\n", session->name);
kfree_skb(skb);
}
@@ -341,12 +350,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
l2tp_xmit_skb(session, skb, session->hdr_len);
local_bh_enable();
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
return total_len;
error_put_sess:
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
error:
return error;
}
@@ -400,12 +409,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
l2tp_xmit_skb(session, skb, session->hdr_len);
local_bh_enable();
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
return 1;
abort_put_sess:
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
abort:
/* Free the original skb */
kfree_skb(skb);
@@ -416,18 +425,73 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
* Session (and tunnel control) socket create/destroy.
*****************************************************************************/
-/* Called by l2tp_core when a session socket is being closed.
+/* called with ps->sk_lock */
+static void pppol2tp_attach(struct l2tp_session *session, struct sock *sk)
+{
+ struct pppol2tp_session *ps = l2tp_session_priv(session);
+
+ write_lock_bh(&sk->sk_callback_lock);
+ rcu_assign_sk_user_data(sk, session);
+ write_unlock_bh(&sk->sk_callback_lock);
+ rcu_assign_pointer(ps->sk, sk);
+}
+
+/* called with ps->sk_lock */
+static void pppol2tp_detach(struct l2tp_session *session, struct sock *sk)
+{
+ struct pppol2tp_session *ps = l2tp_session_priv(session);
+
+ rcu_assign_pointer(ps->sk, NULL);
+ write_lock_bh(&sk->sk_callback_lock);
+ rcu_assign_sk_user_data(sk, NULL);
+ write_unlock_bh(&sk->sk_callback_lock);
+}
+
+static void pppol2tp_put_sk(struct rcu_head *head)
+{
+ struct pppol2tp_session *ps = container_of(head, typeof(*ps), rcu);
+ struct l2tp_session *session = container_of((void *)ps, typeof(*session), priv);
+
+ BUG_ON(session->magic != L2TP_SESSION_MAGIC);
+ sock_put(ps->__sk);
+ kfree(session);
+}
+
+/* Called by l2tp_core when a session is being freed.
+ */
+static void pppol2tp_session_free(struct l2tp_session *session)
+{
+ struct pppol2tp_session *ps = l2tp_session_priv(session);
+ struct sock *sk = ps->__sk;
+ BUG_ON(session->magic != L2TP_SESSION_MAGIC);
+
+ if (sk) {
+ struct socket *sock = sk->sk_socket;
+ if (sock && sock->sk)
+ inet_shutdown(sock, SEND_SHUTDOWN);
+ call_rcu(&ps->rcu, pppol2tp_put_sk);
+ } else {
+ synchronize_rcu();
+ kfree(session);
+ }
+}
+
+/* Called by l2tp_core when a session is being closed.
*/
static void pppol2tp_session_close(struct l2tp_session *session)
{
struct sock *sk;
BUG_ON(session->magic != L2TP_SESSION_MAGIC);
-
sk = pppol2tp_session_get_sock(session);
if (sk) {
- if (sk->sk_socket)
- inet_shutdown(sk->sk_socket, SEND_SHUTDOWN);
+ struct pppol2tp_session *ps = l2tp_session_priv(session);
+ mutex_lock(&ps->sk_lock);
+ ps->__sk = rcu_dereference_protected(ps->sk,
+ lockdep_is_held(&ps->sk_lock));
+ RCU_INIT_POINTER(ps->sk, NULL);
+ pppol2tp_detach(session, sk);
+ mutex_unlock(&ps->sk_lock);
sock_put(sk);
}
}
@@ -437,24 +501,8 @@ static void pppol2tp_session_close(struct l2tp_session *session)
*/
static void pppol2tp_session_destruct(struct sock *sk)
{
- struct l2tp_session *session = sk->sk_user_data;
-
skb_queue_purge(&sk->sk_receive_queue);
skb_queue_purge(&sk->sk_write_queue);
-
- if (session) {
- sk->sk_user_data = NULL;
- BUG_ON(session->magic != L2TP_SESSION_MAGIC);
- l2tp_session_dec_refcount(session);
- }
-}
-
-static void pppol2tp_put_sk(struct rcu_head *head)
-{
- struct pppol2tp_session *ps;
-
- ps = container_of(head, typeof(*ps), rcu);
- sock_put(ps->__sk);
}
/* Called when the PPPoX socket (session) is closed.
@@ -479,28 +527,14 @@ static int pppol2tp_release(struct socket *sock)
sk->sk_state = PPPOX_DEAD;
sock_orphan(sk);
sock->sk = NULL;
+ release_sock(sk);
- session = pppol2tp_sock_to_session(sk);
-
- if (session != NULL) {
- struct pppol2tp_session *ps;
-
+ rcu_read_lock_bh();
+ session = rcu_dereference_bh(__sk_user_data((sk)));
+ if (session) {
l2tp_session_delete(session);
-
- ps = l2tp_session_priv(session);
- mutex_lock(&ps->sk_lock);
- ps->__sk = rcu_dereference_protected(ps->sk,
- lockdep_is_held(&ps->sk_lock));
- RCU_INIT_POINTER(ps->sk, NULL);
- mutex_unlock(&ps->sk_lock);
- call_rcu(&ps->rcu, pppol2tp_put_sk);
-
- /* Rely on the sock_put() call at the end of the function for
- * dropping the reference held by pppol2tp_sock_to_session().
- * The last reference will be dropped by pppol2tp_put_sk().
- */
}
- release_sock(sk);
+ rcu_read_unlock_bh();
/* This will delete the session context via
* pppol2tp_session_destruct() if the socket's refcnt drops to
@@ -584,6 +618,7 @@ static void pppol2tp_session_init(struct l2tp_session *session)
session->recv_skb = pppol2tp_recv;
session->session_close = pppol2tp_session_close;
+ session->session_free = pppol2tp_session_free;
#if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
session->show = pppol2tp_show;
#endif
@@ -605,25 +640,142 @@ static void pppol2tp_session_init(struct l2tp_session *session)
}
}
-/* connect() handler. Attach a PPPoX socket to a tunnel UDP socket
+/* Prepare a tunnel. If a tunnel instance doesn't already exist,
+ * optionally create it. Return with a ref on the tunnel instance.
+ */
+static int pppol2tp_tunnel_prep(struct net *net, int fd, int ver, u32 tunnel_id, u32 peer_tunnel_id, bool can_create, struct l2tp_tunnel **tunnelp)
+{
+ struct l2tp_tunnel *tunnel;
+ int error;
+
+ tunnel = l2tp_tunnel_get(net, tunnel_id);
+ if (!tunnel && can_create) {
+ struct l2tp_tunnel_cfg tcfg = {
+ .encap = L2TP_ENCAPTYPE_UDP,
+ .debug = 0,
+ };
+ error = l2tp_tunnel_create(net, fd, ver, tunnel_id, peer_tunnel_id, &tcfg, &tunnel);
+ if (error < 0)
+ return error;
+
+ l2tp_tunnel_inc_refcount(tunnel);
+ }
+
+ /* Error if we can't find the tunnel */
+ if (tunnel == NULL)
+ return -ENOENT;
+
+ if (tunnel->recv_payload_hook == NULL)
+ tunnel->recv_payload_hook = pppol2tp_recv_payload_hook;
+
+ if (tunnel->peer_tunnel_id == 0)
+ tunnel->peer_tunnel_id = peer_tunnel_id;
+
+ *tunnelp = tunnel;
+ return 0;
+
+ l2tp_tunnel_dec_refcount(tunnel);
+ return error;
+}
+
+/* Prepare a session in a tunnel. If the session doesn't already
+ * exist, create it and add it to the tunnel's session list. Return
+ * with a ref on the session instance and its sk_lock held.
+ */
+static int pppol2tp_session_prep(struct sock *sk, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session **sessionp)
+{
+ struct l2tp_session *session;
+ struct pppol2tp_session *ps;
+ int error;
+ struct l2tp_session_cfg cfg = {};
+
+ session = l2tp_session_get(sock_net(sk), tunnel, session_id);
+ if (session) {
+ ps = l2tp_session_priv(session);
+
+ /* Using a pre-existing session is fine as long as it hasn't
+ * been connected yet.
+ */
+ mutex_lock(&ps->sk_lock);
+ if (rcu_dereference_protected(ps->sk,
+ lockdep_is_held(&ps->sk_lock))) {
+ mutex_unlock(&ps->sk_lock);
+ l2tp_session_dec_refcount(session);
+ return -EEXIST;
+ }
+ } else {
+ /* Default MTU must allow space for UDP/L2TP/PPP headers */
+ cfg.mtu = 1500 - PPPOL2TP_HEADER_OVERHEAD;
+ cfg.mru = cfg.mtu;
+
+ session = l2tp_session_create(sizeof(struct pppol2tp_session),
+ tunnel, session_id,
+ peer_session_id, &cfg);
+ if (IS_ERR(session)) {
+ error = PTR_ERR(session);
+ return error;
+ }
+
+ pppol2tp_session_init(session);
+ ps = l2tp_session_priv(session);
+
+ mutex_lock(&ps->sk_lock);
+ error = l2tp_session_register(session, tunnel);
+ if (error < 0) {
+ mutex_unlock(&ps->sk_lock);
+ kfree(session);
+ return error;
+ }
+ l2tp_session_inc_refcount(session);
+ }
+
+ *sessionp = session;
+ return 0;
+}
+
+static int pppol2tp_setup_ppp(struct l2tp_session *session, struct sock *sk)
+{
+ struct pppox_sock *po = pppox_sk(sk);
+
+ /* The only header we need to worry about is the L2TP
+ * header. This size is different depending on whether
+ * sequence numbers are enabled for the data channel.
+ */
+ po->chan.hdrlen = PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
+
+ po->chan.private = sk;
+ po->chan.ops = &pppol2tp_chan_ops;
+ po->chan.mtu = session->mtu;
+
+ return ppp_register_net_channel(sock_net(sk), &po->chan);
+}
+
+/* connect() handler. Attach a PPPoX socket to a tunnel socket.
+ * The PPPoX socket is associated with an l2tp_session and the tunnel
+ * socket is associated with an l2tp_tunnel. The l2tp_tunnel and
+ * l2tp_session are usually created by netlink before the PPPoX socket
+ * is connected. However, for L2TPv2 we support a legacy mode where
+ * netlink is not used and we create the l2tp_tunnel and l2tp_session
+ * when the PPPoX sockets are connected. In legacy mode, a per-tunnel
+ * PPPoX socket is used as a control socket for the tunnel and is
+ * identified by session_id 0. An l2tp_session is created to manage
+ * the control socket and an l2tp_tunnel is created for the tunnel if
+ * it doesn't already exist.
*/
static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
int sockaddr_len, int flags)
{
struct sock *sk = sock->sk;
struct sockaddr_pppol2tp *sp = (struct sockaddr_pppol2tp *) uservaddr;
- struct pppox_sock *po = pppox_sk(sk);
struct l2tp_session *session = NULL;
- struct l2tp_tunnel *tunnel;
+ struct l2tp_tunnel *tunnel = NULL;
struct pppol2tp_session *ps;
- struct l2tp_session_cfg cfg = { 0, };
int error = 0;
u32 tunnel_id, peer_tunnel_id;
u32 session_id, peer_session_id;
- bool drop_refcnt = false;
- bool drop_tunnel = false;
int ver = 2;
int fd;
+ bool is_ctrl_skt;
lock_sock(sk);
@@ -685,135 +837,54 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
goto end; /* bad socket address */
}
- /* Don't bind if tunnel_id is 0 */
error = -EINVAL;
- if (tunnel_id == 0)
+ if (tunnel_id == 0 || peer_tunnel_id == 0)
goto end;
- tunnel = l2tp_tunnel_get(sock_net(sk), tunnel_id);
- if (tunnel)
- drop_tunnel = true;
-
- /* Special case: create tunnel context if session_id and
- * peer_session_id is 0. Otherwise look up tunnel using supplied
- * tunnel id.
+ /* The socket is a control socket if session_id is 0. There is
+ * one control socket per tunnel. Control sockets do not have ppp.
*/
- if ((session_id == 0) && (peer_session_id == 0)) {
- if (tunnel == NULL) {
- struct l2tp_tunnel_cfg tcfg = {
- .encap = L2TP_ENCAPTYPE_UDP,
- .debug = 0,
- };
- error = l2tp_tunnel_create(sock_net(sk), fd, ver, tunnel_id, peer_tunnel_id, &tcfg, &tunnel);
- if (error < 0)
- goto end;
- }
- } else {
- /* Error if we can't find the tunnel */
- error = -ENOENT;
- if (tunnel == NULL)
- goto end;
-
- /* Error if socket is not prepped */
- if (tunnel->sock == NULL)
- goto end;
- }
-
- if (tunnel->recv_payload_hook == NULL)
- tunnel->recv_payload_hook = pppol2tp_recv_payload_hook;
-
- if (tunnel->peer_tunnel_id == 0)
- tunnel->peer_tunnel_id = peer_tunnel_id;
-
- session = l2tp_session_get(sock_net(sk), tunnel, session_id);
- if (session) {
- drop_refcnt = true;
- ps = l2tp_session_priv(session);
-
- /* Using a pre-existing session is fine as long as it hasn't
- * been connected yet.
- */
- mutex_lock(&ps->sk_lock);
- if (rcu_dereference_protected(ps->sk,
- lockdep_is_held(&ps->sk_lock))) {
- mutex_unlock(&ps->sk_lock);
- error = -EEXIST;
- goto end;
- }
- } else {
- /* Default MTU must allow space for UDP/L2TP/PPP headers */
- cfg.mtu = 1500 - PPPOL2TP_HEADER_OVERHEAD;
- cfg.mru = cfg.mtu;
+ is_ctrl_skt = (session_id == 0 && peer_session_id == 0);
- session = l2tp_session_create(sizeof(struct pppol2tp_session),
- tunnel, session_id,
- peer_session_id, &cfg);
- if (IS_ERR(session)) {
- error = PTR_ERR(session);
- goto end;
- }
+ /* prep and possibly create the l2tp tunnel instance */
+ error = pppol2tp_tunnel_prep(sock_net(sk), fd, ver, tunnel_id,
+ peer_tunnel_id, is_ctrl_skt, &tunnel);
+ if (error)
+ goto end;
- pppol2tp_session_init(session);
- ps = l2tp_session_priv(session);
- l2tp_session_inc_refcount(session);
+ /* prep and possibly create the l2tp session instance */
+ error = pppol2tp_session_prep(sk, tunnel, session_id,
+ peer_session_id, &session);
+ if (error)
+ goto end;
- mutex_lock(&ps->sk_lock);
- error = l2tp_session_register(session, tunnel);
- if (error < 0) {
+ /* setup ppp unless it's a control socket */
+ ps = l2tp_session_priv(session);
+ if (!is_ctrl_skt) {
+ error = pppol2tp_setup_ppp(session, sk);
+ if (error) {
mutex_unlock(&ps->sk_lock);
- kfree(session);
goto end;
}
- drop_refcnt = true;
}
- /* Special case: if source & dest session_id == 0x0000, this
- * socket is being created to manage the tunnel. Just set up
- * the internal context for use by ioctl() and sockopt()
- * handlers.
+ /* The session has now been added to the tunnel. Hold the
+ * socket to prevent it going away until the session is
+ * destroyed and attach it to the session such that we can get
+ * the session instance from the socket and vice versa.
*/
- if ((session->session_id == 0) &&
- (session->peer_session_id == 0)) {
- error = 0;
- goto out_no_ppp;
- }
-
- /* The only header we need to worry about is the L2TP
- * header. This size is different depending on whether
- * sequence numbers are enabled for the data channel.
- */
- po->chan.hdrlen = PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
-
- po->chan.private = sk;
- po->chan.ops = &pppol2tp_chan_ops;
- po->chan.mtu = session->mtu;
-
- error = ppp_register_net_channel(sock_net(sk), &po->chan);
- if (error) {
- mutex_unlock(&ps->sk_lock);
- goto end;
- }
-
-out_no_ppp:
- /* This is how we get the session context from the socket. */
- sk->sk_user_data = session;
- rcu_assign_pointer(ps->sk, sk);
+ sock_hold(sk);
+ pppol2tp_attach(session, sk);
mutex_unlock(&ps->sk_lock);
- /* Keep the reference we've grabbed on the session: sk doesn't expect
- * the session to disappear. pppol2tp_session_destruct() is responsible
- * for dropping it.
- */
- drop_refcnt = false;
-
sk->sk_state = PPPOX_CONNECTED;
l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n",
session->name);
end:
- if (drop_refcnt)
+ if (session)
l2tp_session_dec_refcount(session);
- if (drop_tunnel)
+ if (tunnel)
l2tp_tunnel_dec_refcount(tunnel);
release_sock(sk);
@@ -829,6 +900,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
{
int error;
struct l2tp_session *session;
+ struct pppol2tp_session *ps;
/* Error if tunnel socket is not prepped */
if (!tunnel->sock) {
@@ -852,10 +924,14 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
}
pppol2tp_session_init(session);
-
+ ps = l2tp_session_priv(session);
+ mutex_lock(&ps->sk_lock);
error = l2tp_session_register(session, tunnel);
- if (error < 0)
+ if (error < 0) {
+ mutex_unlock(&ps->sk_lock);
goto err_sess;
+ }
+ mutex_unlock(&ps->sk_lock);
return 0;
@@ -972,7 +1048,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
*usockaddr_len = len;
error = 0;
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
end:
return error;
}
@@ -1243,7 +1319,7 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd,
err = pppol2tp_session_ioctl(session, cmd, arg);
end_put_sess:
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
end:
return err;
}
@@ -1394,7 +1470,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
err = pppol2tp_session_setsockopt(sk, session, optname, val);
}
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
end:
return err;
}
@@ -1526,7 +1602,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
err = 0;
end_put_sess:
- sock_put(sk);
+ l2tp_session_dec_refcount(session);
end:
return err;
}