diff mbox series

[v3,mptcp-next] mptcp: retransmit ADD_ADDR when timeout

Message ID 25ffeeb7631ebe2285b2449b076d2a424523c721.1600258235.git.geliangtang@gmail.com
State Superseded, archived
Headers show
Series [v3,mptcp-next] mptcp: retransmit ADD_ADDR when timeout | expand

Commit Message

Geliang Tang Sept. 16, 2020, 12:17 p.m. UTC
This patch implemented the retransmition of ADD_ADDR when no ADD_ADDR echo
is received. It added a timer with the announced address. When timeout
occurs, ADD_ADDR will be retransmitted.

Suggested-by: Mat Martineau <mathew.j.martineau@linux.intel.com>
Suggested-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Geliang Tang <geliangtang@gmail.com>
---
v3:
 - use timer.
 - drop bh_lock_sock.
 - move __sock_put at the end of mptcp_pm_add_timer.

v2:
 - Use delayed_work instead of timer.
 - This patch depends on my another patch named 'Squash-to: "mptcp: remove
   addr and subflow in PM netlink"'.
---
 net/mptcp/options.c    |  1 +
 net/mptcp/pm_netlink.c | 71 +++++++++++++++++++++++++++++++++++++-----
 net/mptcp/protocol.h   |  1 +
 3 files changed, 66 insertions(+), 7 deletions(-)

Comments

Paolo Abeni Sept. 16, 2020, 2:03 p.m. UTC | #1
On Wed, 2020-09-16 at 20:17 +0800, Geliang Tang wrote:
[...] 
> +static void mptcp_pm_add_timer(struct timer_list *timer)
> +{
> +	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
> +	struct mptcp_sock *msk = entry->sock;
> +	struct sock *sk = (struct sock *)msk;
> +
> +	pr_debug("msk=%p\n", msk);
> +
> +	if (!msk)
> +		return;
> +
> +	spin_lock_bh(&msk->pm.lock);
> +
> +	if (entry->addr.id > 0) {
> +		pr_debug("retransmit ADD_ADDR id=%d\n", entry->addr.id);
> +		mptcp_pm_announce_addr(msk, &entry->addr, false);
> +		entry->retrans_times++;
> +	}
> +
> +	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
> +		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX);
> +
> +	spin_unlock_bh(&msk->pm.lock);
> +
> +	__sock_put(sk);
> +}
> +
> +void mptcp_pm_del_add_timer(struct mptcp_sock *msk, struct mptcp_addr_info *addr)
> +{
> +	struct mptcp_pm_add_entry *entry;
> +
> +	spin_lock_bh(&msk->pm.lock);
> +	list_for_each_entry(entry, &msk->pm.anno_list, list) {
> +		if (addresses_equal(&entry->addr, addr, false))
> +			sk_stop_timer((struct sock *)msk, &entry->add_timer);
> +	}
> +	spin_unlock_bh(&msk->pm.lock);

I fear this is still racy:

when sk_stop_timer() is called, mptcp_pm_add_timer() could be running,
and waiting for the pm spinlock.

After the above loop is completed, mptcp_pm_add_timer() may re-schedule 
itself.

> +}
> +
>  static bool lookup_anno_list_by_saddr(struct mptcp_sock *msk,
>  				      struct mptcp_addr_info *addr)
>  {
> -	struct mptcp_pm_addr_entry *entry;
> +	struct mptcp_pm_add_entry *entry;
>  
>  	list_for_each_entry(entry, &msk->pm.anno_list, list) {
>  		if (addresses_equal(&entry->addr, addr, false))
> @@ -194,28 +242,36 @@ static bool lookup_anno_list_by_saddr(struct mptcp_sock *msk,
>  static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
>  				     struct mptcp_pm_addr_entry *entry)
>  {
> -	struct mptcp_pm_addr_entry *clone = NULL;
> +	struct mptcp_pm_add_entry *add_entry = NULL;
>  
>  	if (lookup_anno_list_by_saddr(msk, &entry->addr))
>  		return false;
>  
> -	clone = kmemdup(entry, sizeof(*entry), GFP_ATOMIC);
> -	if (!clone)
> +	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
> +	if (!add_entry)
>  		return false;
>  
> -	list_add(&clone->list, &msk->pm.anno_list);
> +	list_add(&add_entry->list, &msk->pm.anno_list);
> +
> +	add_entry->addr = entry->addr;
> +	add_entry->sock = msk;
> +	add_entry->retrans_times = 0;
> +
> +	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
> +	sk_reset_timer((struct sock *)msk, &add_entry->add_timer, jiffies + TCP_RTO_MAX);
>  
>  	return true;
>  }
>  
>  void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
>  {
> -	struct mptcp_pm_addr_entry *entry, *tmp;
> +	struct mptcp_pm_add_entry *entry, *tmp;
>  
>  	pr_debug("msk=%p\n", msk);
>  
>  	spin_lock_bh(&msk->pm.lock);
>  	list_for_each_entry_safe(entry, tmp, &msk->pm.anno_list, list) {
> +		sk_stop_timer((struct sock *)msk, &entry->add_timer);
>  		list_del(&entry->list);
>  		kfree(entry);

Like the above, but additionally mptcp_pm_add_timer() can hit use-
after-free on the anno list entry, if kfree is executed
before mptcp_pm_add_timer() completes.

Sorry for not noticing the above before.

Note that TCP timers do not have the above issues, because the timers
structs are embedded into the sock, and the timer keep a reference to
the sock itself.

I think the a solution could be somethink alike:

---
	// in mptcp_pm_free_anno_list()
	LIST_HEAD(to_free);

	spin_lock_bh(&msk->pm.lock);
  	list_for_each_entry_safe(entry, tmp, &msk->pm.anno_list, list) {
		/* avoid timer to rescheuling 
		 * as an alternative, here we could simply splicing
		 * 'anno_list' into 'to_free'
		 * and check msk sk_state in mptcp_pm_add_timer()
		 * avoiding re-scheduling if  TCP_CLOSE
		 */
		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
		list_del(&entry->list);
		list_add(&entry->list, &to_free);
	}
	spin_unlock_bh(&msk->pm.lock);
	
  	list_for_each_entry_safe(entry, tmp, &to_free, list) {
		if (del_timer_sync(&entry->add_timer))
			__sock_put(msk);
		kfree(entry);
	}

void mptcp_pm_del_add_timer(struct mptcp_sock *msk, struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	spin_lock_bh(&msk->pm.lock);
	entry = lookup_anno_list_by_saddr(msk, addr);
	if (entry)
		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
	spin_unlock_bh(&msk->pm.lock);
	
	if (entry && del_timer_sync(&entry->add_timer))
		__sock_put(msk);
}
---
beware: completely untested!!! Likely some bugs present && something
simpler is possible.

Anyway this looks less complex than the work-based solution (IMHO).
Other opinons more than welcome!

Cheers,

Paolo
diff mbox series

Patch

diff --git a/net/mptcp/options.c b/net/mptcp/options.c
index 171039cbe9c4..14a290fae767 100644
--- a/net/mptcp/options.c
+++ b/net/mptcp/options.c
@@ -893,6 +893,7 @@  void mptcp_incoming_options(struct sock *sk, struct sk_buff *skb,
 			mptcp_pm_add_addr_received(msk, &addr);
 			MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_ADDADDR);
 		} else {
+			mptcp_pm_del_add_timer(msk, &addr);
 			MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_ECHOADD);
 		}
 		mp_opt.add_addr = 0;
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c
index 8780d618a05a..b06070d81b6c 100644
--- a/net/mptcp/pm_netlink.c
+++ b/net/mptcp/pm_netlink.c
@@ -28,6 +28,14 @@  struct mptcp_pm_addr_entry {
 	struct rcu_head		rcu;
 };
 
+struct mptcp_pm_add_entry {
+	struct list_head	list;
+	struct mptcp_addr_info	addr;
+	struct timer_list	add_timer;
+	struct mptcp_sock	*sock;
+	u8			retrans_times;
+};
+
 struct pm_nl_pernet {
 	/* protects pernet updates */
 	spinlock_t		lock;
@@ -41,6 +49,7 @@  struct pm_nl_pernet {
 };
 
 #define MPTCP_PM_ADDR_MAX	8
+#define ADD_ADDR_RETRANS_MAX	3
 
 static bool addresses_equal(const struct mptcp_addr_info *a,
 			    struct mptcp_addr_info *b, bool use_port)
@@ -178,10 +187,49 @@  static void check_work_pending(struct mptcp_sock *msk)
 		WRITE_ONCE(msk->pm.work_pending, false);
 }
 
+static void mptcp_pm_add_timer(struct timer_list *timer)
+{
+	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
+	struct mptcp_sock *msk = entry->sock;
+	struct sock *sk = (struct sock *)msk;
+
+	pr_debug("msk=%p\n", msk);
+
+	if (!msk)
+		return;
+
+	spin_lock_bh(&msk->pm.lock);
+
+	if (entry->addr.id > 0) {
+		pr_debug("retransmit ADD_ADDR id=%d\n", entry->addr.id);
+		mptcp_pm_announce_addr(msk, &entry->addr, false);
+		entry->retrans_times++;
+	}
+
+	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
+		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX);
+
+	spin_unlock_bh(&msk->pm.lock);
+
+	__sock_put(sk);
+}
+
+void mptcp_pm_del_add_timer(struct mptcp_sock *msk, struct mptcp_addr_info *addr)
+{
+	struct mptcp_pm_add_entry *entry;
+
+	spin_lock_bh(&msk->pm.lock);
+	list_for_each_entry(entry, &msk->pm.anno_list, list) {
+		if (addresses_equal(&entry->addr, addr, false))
+			sk_stop_timer((struct sock *)msk, &entry->add_timer);
+	}
+	spin_unlock_bh(&msk->pm.lock);
+}
+
 static bool lookup_anno_list_by_saddr(struct mptcp_sock *msk,
 				      struct mptcp_addr_info *addr)
 {
-	struct mptcp_pm_addr_entry *entry;
+	struct mptcp_pm_add_entry *entry;
 
 	list_for_each_entry(entry, &msk->pm.anno_list, list) {
 		if (addresses_equal(&entry->addr, addr, false))
@@ -194,28 +242,36 @@  static bool lookup_anno_list_by_saddr(struct mptcp_sock *msk,
 static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
 				     struct mptcp_pm_addr_entry *entry)
 {
-	struct mptcp_pm_addr_entry *clone = NULL;
+	struct mptcp_pm_add_entry *add_entry = NULL;
 
 	if (lookup_anno_list_by_saddr(msk, &entry->addr))
 		return false;
 
-	clone = kmemdup(entry, sizeof(*entry), GFP_ATOMIC);
-	if (!clone)
+	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
+	if (!add_entry)
 		return false;
 
-	list_add(&clone->list, &msk->pm.anno_list);
+	list_add(&add_entry->list, &msk->pm.anno_list);
+
+	add_entry->addr = entry->addr;
+	add_entry->sock = msk;
+	add_entry->retrans_times = 0;
+
+	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
+	sk_reset_timer((struct sock *)msk, &add_entry->add_timer, jiffies + TCP_RTO_MAX);
 
 	return true;
 }
 
 void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
 {
-	struct mptcp_pm_addr_entry *entry, *tmp;
+	struct mptcp_pm_add_entry *entry, *tmp;
 
 	pr_debug("msk=%p\n", msk);
 
 	spin_lock_bh(&msk->pm.lock);
 	list_for_each_entry_safe(entry, tmp, &msk->pm.anno_list, list) {
+		sk_stop_timer((struct sock *)msk, &entry->add_timer);
 		list_del(&entry->list);
 		kfree(entry);
 	}
@@ -654,10 +710,11 @@  __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
 				      struct mptcp_addr_info *addr)
 {
-	struct mptcp_pm_addr_entry *entry, *tmp;
+	struct mptcp_pm_add_entry *entry, *tmp;
 
 	list_for_each_entry_safe(entry, tmp, &msk->pm.anno_list, list) {
 		if (addresses_equal(&entry->addr, addr, false)) {
+			sk_stop_timer((struct sock *)msk, &entry->add_timer);
 			list_del(&entry->list);
 			kfree(entry);
 			return true;
diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h
index db1e5de2fee7..031ae106746d 100644
--- a/net/mptcp/protocol.h
+++ b/net/mptcp/protocol.h
@@ -444,6 +444,7 @@  void mptcp_pm_add_addr_received(struct mptcp_sock *msk,
 				const struct mptcp_addr_info *addr);
 void mptcp_pm_rm_addr_received(struct mptcp_sock *msk, u8 rm_id);
 void mptcp_pm_free_anno_list(struct mptcp_sock *msk);
+void mptcp_pm_del_add_timer(struct mptcp_sock *msk, struct mptcp_addr_info *addr);
 
 int mptcp_pm_announce_addr(struct mptcp_sock *msk,
 			   const struct mptcp_addr_info *addr,