@@ -132,7 +132,7 @@ l2tp_session_id_hash_2(struct l2tp_net *pn, u32 session_id)
/* Lookup a session by id in the global session list
*/
-static struct l2tp_session *l2tp_session_find_2(struct net *net, u32 session_id)
+static struct l2tp_session *l2tp_session_find_2(struct net *net, u32 session_id, int ref)
{
struct l2tp_net *pn = l2tp_pernet(net);
struct hlist_head *session_list =
@@ -143,6 +143,8 @@ static struct l2tp_session *l2tp_session_find_2(struct net *net, u32 session_id)
rcu_read_lock_bh();
hlist_for_each_entry_rcu(session, walk, session_list, global_hlist) {
if (session->session_id == session_id) {
+ if (ref)
+ l2tp_session_inc_refcount(session);
rcu_read_unlock_bh();
return session;
}
@@ -166,7 +168,7 @@ l2tp_session_id_hash(struct l2tp_tunnel *tunnel, u32 session_id)
/* Lookup a session by id
*/
-struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id)
+struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id, int ref)
{
struct hlist_head *session_list;
struct l2tp_session *session;
@@ -177,12 +179,14 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
* tunnel.
*/
if (tunnel == NULL)
- return l2tp_session_find_2(net, session_id);
+ return l2tp_session_find_2(net, session_id, ref);
session_list = l2tp_session_id_hash(tunnel, session_id);
read_lock_bh(&tunnel->hlist_lock);
hlist_for_each_entry(session, walk, session_list, hlist) {
if (session->session_id == session_id) {
+ if (ref)
+ l2tp_session_inc_refcount(session);
read_unlock_bh(&tunnel->hlist_lock);
return session;
}
@@ -193,7 +197,7 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
}
EXPORT_SYMBOL_GPL(l2tp_session_find);
-struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
+struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth, int ref)
{
int hash;
struct hlist_node *walk;
@@ -204,6 +208,8 @@ struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
hlist_for_each_entry(session, walk, &tunnel->session_hlist[hash], hlist) {
if (++count > nth) {
+ if (ref)
+ l2tp_session_inc_refcount(session);
read_unlock_bh(&tunnel->hlist_lock);
return session;
}
@@ -244,7 +250,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_find_by_ifname);
/* Lookup a tunnel by id
*/
-struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id)
+struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id, int ref)
{
struct l2tp_tunnel *tunnel;
struct l2tp_net *pn = l2tp_pernet(net);
@@ -252,6 +258,8 @@ struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id)
rcu_read_lock_bh();
list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
if (tunnel->tunnel_id == tunnel_id) {
+ if (ref)
+ l2tp_tunnel_inc_refcount(tunnel);
rcu_read_unlock_bh();
return tunnel;
}
@@ -500,11 +508,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
int offset;
u32 ns, nr;
- /* The ref count is increased since we now hold a pointer to
- * the session. Take care to decrement the refcnt when exiting
- * this function from now on...
- */
- l2tp_session_inc_refcount(session);
if (session->ref)
(*session->ref)(session);
@@ -785,7 +788,7 @@ int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
}
/* Find the session context */
- session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
+ session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id, 1);
if (!session || !session->recv_skb) {
/* Not found? Pass to userspace to deal with */
PRINTK(tunnel->debug, L2TP_MSG_DATA, KERN_INFO,
@@ -221,10 +221,10 @@ out:
return tunnel;
}
-extern struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id);
-extern struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
+extern struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id, int ref);
+extern struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth, int ref);
extern struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname);
-extern struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id);
+extern struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id, int ref);
extern struct l2tp_tunnel *l2tp_tunnel_find_nth(struct net *net, int nth);
extern int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 peer_tunnel_id, struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp);
@@ -51,7 +51,7 @@ static void l2tp_dfs_next_tunnel(struct l2tp_dfs_seq_data *pd)
static void l2tp_dfs_next_session(struct l2tp_dfs_seq_data *pd)
{
- pd->session = l2tp_session_find_nth(pd->tunnel, pd->session_idx);
+ pd->session = l2tp_session_find_nth(pd->tunnel, pd->session_idx, 0);
pd->session_idx++;
if (pd->session == NULL) {
@@ -194,13 +194,13 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p
int rc;
struct l2tp_eth_net *pn;
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (!tunnel) {
rc = -ENODEV;
goto out;
}
- session = l2tp_session_find(net, tunnel, session_id);
+ session = l2tp_session_find(net, tunnel, session_id, 0);
if (session) {
rc = -EEXIST;
goto out;
@@ -126,7 +126,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
u32 session_id;
u32 tunnel_id;
unsigned char *ptr, *optr;
- struct l2tp_session *session;
+ struct l2tp_session *session = NULL;
struct l2tp_tunnel *tunnel = NULL;
int length;
int offset;
@@ -150,7 +150,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
}
/* Ok, this is a data packet. Lookup the session. */
- session = l2tp_session_find(&init_net, NULL, session_id);
+ session = l2tp_session_find(&init_net, NULL, session_id, 1);
if (session == NULL)
goto discard;
@@ -187,7 +187,7 @@ pass_up:
goto discard;
tunnel_id = ntohl(*(__be32 *) &skb->data[4]);
- tunnel = l2tp_tunnel_find(&init_net, tunnel_id);
+ tunnel = l2tp_tunnel_find(&init_net, tunnel_id, 0);
if (tunnel != NULL)
sk = tunnel->sock;
else {
@@ -214,6 +214,8 @@ discard_put:
sock_put(sk);
discard:
+ if (session)
+ l2tp_session_dec_refcount(session);
kfree_skb(skb);
return 0;
}
@@ -40,7 +40,7 @@ static struct genl_family l2tp_nl_family = {
/* Accessed under genl lock */
static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX];
-static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
+static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info, int ref)
{
u32 tunnel_id;
u32 session_id;
@@ -56,9 +56,9 @@ static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
(info->attrs[L2TP_ATTR_CONN_ID])) {
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (tunnel)
- session = l2tp_session_find(net, tunnel, session_id);
+ session = l2tp_session_find(net, tunnel, session_id, ref);
}
return session;
@@ -148,7 +148,7 @@ static int l2tp_nl_cmd_tunnel_create(struct sk_buff *skb, struct genl_info *info
if (info->attrs[L2TP_ATTR_DEBUG])
cfg.debug = nla_get_u32(info->attrs[L2TP_ATTR_DEBUG]);
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (tunnel != NULL) {
ret = -EEXIST;
goto out;
@@ -180,7 +180,7 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info
}
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (tunnel == NULL) {
ret = -ENODEV;
goto out;
@@ -205,7 +205,7 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info
}
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (tunnel == NULL) {
ret = -ENODEV;
goto out;
@@ -292,7 +292,7 @@ static int l2tp_nl_cmd_tunnel_get(struct sk_buff *skb, struct genl_info *info)
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (tunnel == NULL) {
ret = -ENODEV;
goto out;
@@ -359,7 +359,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
goto out;
}
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
- tunnel = l2tp_tunnel_find(net, tunnel_id);
+ tunnel = l2tp_tunnel_find(net, tunnel_id, 0);
if (!tunnel) {
ret = -ENODEV;
goto out;
@@ -370,7 +370,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
goto out;
}
session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
- session = l2tp_session_find(net, tunnel, session_id);
+ session = l2tp_session_find(net, tunnel, session_id, 0);
if (session) {
ret = -EEXIST;
goto out;
@@ -495,7 +495,7 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
struct l2tp_session *session;
u16 pw_type;
- session = l2tp_nl_session_find(info);
+ session = l2tp_nl_session_find(info, 0);
if (session == NULL) {
ret = -ENODEV;
goto out;
@@ -515,7 +515,7 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
int ret = 0;
struct l2tp_session *session;
- session = l2tp_nl_session_find(info);
+ session = l2tp_nl_session_find(info, 1);
if (session == NULL) {
ret = -ENODEV;
goto out;
@@ -615,7 +615,7 @@ static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info)
struct sk_buff *msg;
int ret;
- session = l2tp_nl_session_find(info);
+ session = l2tp_nl_session_find(info, 0);
if (session == NULL) {
ret = -ENODEV;
goto out;
@@ -656,7 +656,7 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback
goto out;
}
- session = l2tp_session_find_nth(tunnel, si);
+ session = l2tp_session_find_nth(tunnel, si, 0);
if (session == NULL) {
ti++;
tunnel = NULL;