diff mbox series

[stable,4.4,v2,05/11] ip: use rb trees for IP frag queue.

Message ID 1548384524-174152-6-git-send-email-maowenan@huawei.com
State Not Applicable
Delegated to: David Miller
Headers show
Series fix FragmentSmack in stable branch (CVE-2018-5391) | expand

Commit Message

maowenan Jan. 25, 2019, 2:48 a.m. UTC
From: Peter Oskolkov <posk@google.com>

[ Upstream commit fa0f527358bd900ef92f925878ed6bfbd51305cc ]

Similar to TCP OOO RX queue, it makes sense to use rb trees to store
IP fragments, so that OOO fragments are inserted faster.

Tested:

- a follow-up patch contains a rather comprehensive ip defrag
  self-test (functional)
- ran neper `udp_stream -c -H <host> -F 100 -l 300 -T 20`:
    netstat --statistics
    Ip:
        282078937 total packets received
        0 forwarded
        0 incoming packets discarded
        946760 incoming packets delivered
        18743456 requests sent out
        101 fragments dropped after timeout
        282077129 reassemblies required
        944952 packets reassembled ok
        262734239 packet reassembles failed
   (The numbers/stats above are somewhat better re:
    reassemblies vs a kernel without this patchset. More
    comprehensive performance testing TBD).

Reported-by: Jann Horn <jannh@google.com>
Reported-by: Juha-Matti Tilli <juha-matti.tilli@iki.fi>
Suggested-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Peter Oskolkov <posk@google.com>
Signed-off-by: Eric Dumazet <edumazet@google.com>
Cc: Florian Westphal <fw@strlen.de>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Mao Wenan <maowenan@huawei.com>
---
 include/linux/skbuff.h                  |   2 +-
 include/net/inet_frag.h                 |   3 +-
 net/ipv4/inet_fragment.c                |  16 ++-
 net/ipv4/ip_fragment.c                  | 190 ++++++++++++++++++--------------
 net/ipv6/netfilter/nf_conntrack_reasm.c |   1 +
 net/ipv6/reassembly.c                   |   1 +
 6 files changed, 121 insertions(+), 92 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h
index 8bfefdd..5c73c79 100644
--- a/include/linux/skbuff.h
+++ b/include/linux/skbuff.h
@@ -556,7 +556,7 @@  struct sk_buff {
 				struct skb_mstamp skb_mstamp;
 			};
 		};
-		struct rb_node	rbnode; /* used in netem & tcp stack */
+		struct rb_node	rbnode; /* used in netem, ip4 defrag, and tcp stack */
 	};
 	struct sock		*sk;
 	struct net_device	*dev;
diff --git a/include/net/inet_frag.h b/include/net/inet_frag.h
index 09472b8..861d24c 100644
--- a/include/net/inet_frag.h
+++ b/include/net/inet_frag.h
@@ -45,7 +45,8 @@  struct inet_frag_queue {
 	struct timer_list	timer;
 	struct hlist_node	list;
 	atomic_t		refcnt;
-	struct sk_buff		*fragments;
+	struct sk_buff		*fragments;  /* Used in IPv6. */
+	struct rb_root		rb_fragments; /* Used in IPv4. */
 	struct sk_buff		*fragments_tail;
 	ktime_t			stamp;
 	int			len;
diff --git a/net/ipv4/inet_fragment.c b/net/ipv4/inet_fragment.c
index b2001b2..2b3a926 100644
--- a/net/ipv4/inet_fragment.c
+++ b/net/ipv4/inet_fragment.c
@@ -306,12 +306,16 @@  void inet_frag_destroy(struct inet_frag_queue *q, struct inet_frags *f)
 	/* Release all fragment data. */
 	fp = q->fragments;
 	nf = q->net;
-	while (fp) {
-		struct sk_buff *xp = fp->next;
-
-		sum_truesize += fp->truesize;
-		frag_kfree_skb(nf, f, fp);
-		fp = xp;
+	if (fp) {
+		do {
+			struct sk_buff *xp = fp->next;
+
+			sum_truesize += fp->truesize;
+			kfree_skb(fp);
+			fp = xp;
+		} while (fp);
+	} else {
+		sum_truesize = skb_rbtree_purge(&q->rb_fragments);
 	}
 	sum = sum_truesize + f->qsize;
 
diff --git a/net/ipv4/ip_fragment.c b/net/ipv4/ip_fragment.c
index 264f382..e820eb9 100644
--- a/net/ipv4/ip_fragment.c
+++ b/net/ipv4/ip_fragment.c
@@ -194,7 +194,7 @@  static bool frag_expire_skip_icmp(u32 user)
  */
 static void ip_expire(unsigned long arg)
 {
-	struct sk_buff *clone, *head;
+	struct sk_buff *head = NULL;
 	const struct iphdr *iph;
 	struct net *net;
 	struct ipq *qp;
@@ -211,14 +211,31 @@  static void ip_expire(unsigned long arg)
 
 	ipq_kill(qp);
 	IP_INC_STATS_BH(net, IPSTATS_MIB_REASMFAILS);
-
-	head = qp->q.fragments;
-
 	IP_INC_STATS_BH(net, IPSTATS_MIB_REASMTIMEOUT);
 
-	if (!(qp->q.flags & INET_FRAG_FIRST_IN) || !head)
+	if (!qp->q.flags & INET_FRAG_FIRST_IN)
 		goto out;
 
+	/* sk_buff::dev and sk_buff::rbnode are unionized. So we
+	 * pull the head out of the tree in order to be able to
+	 * deal with head->dev.
+	 */
+	if (qp->q.fragments) {
+		head = qp->q.fragments;
+		qp->q.fragments = head->next;
+	} else {
+		head = skb_rb_first(&qp->q.rb_fragments);
+		if (!head)
+			goto out;
+		rb_erase(&head->rbnode, &qp->q.rb_fragments);
+		memset(&head->rbnode, 0, sizeof(head->rbnode));
+		barrier();
+	}
+	if (head == qp->q.fragments_tail)
+		qp->q.fragments_tail = NULL;
+
+	sub_frag_mem_limit(qp->q.net, head->truesize);
+
 	head->dev = dev_get_by_index_rcu(net, qp->iif);
 	if (!head->dev)
 		goto out;
@@ -237,20 +254,17 @@  static void ip_expire(unsigned long arg)
 	    (skb_rtable(head)->rt_type != RTN_LOCAL))
 		goto out;
 
-	clone = skb_clone(head, GFP_ATOMIC);
-
 	/* Send an ICMP "Fragment Reassembly Timeout" message. */
-	if (clone) {
-		spin_unlock(&qp->q.lock);
-		icmp_send(clone, ICMP_TIME_EXCEEDED,
-			  ICMP_EXC_FRAGTIME, 0);
-		consume_skb(clone);
-		goto out_rcu_unlock;
-	}
+	spin_unlock(&qp->q.lock);
+	icmp_send(head, ICMP_TIME_EXCEEDED, ICMP_EXC_FRAGTIME, 0);
+	goto out_rcu_unlock;
+
 out:
 	spin_unlock(&qp->q.lock);
 out_rcu_unlock:
 	rcu_read_unlock();
+	if (head)
+		kfree_skb(head);
 	ipq_put(qp);
 }
 
@@ -294,7 +308,7 @@  static int ip_frag_too_far(struct ipq *qp)
 	end = atomic_inc_return(&peer->rid);
 	qp->rid = end;
 
-	rc = qp->q.fragments && (end - start) > max;
+	rc = qp->q.fragments_tail && (end - start) > max;
 
 	if (rc) {
 		struct net *net;
@@ -308,7 +322,6 @@  static int ip_frag_too_far(struct ipq *qp)
 
 static int ip_frag_reinit(struct ipq *qp)
 {
-	struct sk_buff *fp;
 	unsigned int sum_truesize = 0;
 
 	if (!mod_timer(&qp->q.timer, jiffies + qp->q.net->timeout)) {
@@ -316,20 +329,14 @@  static int ip_frag_reinit(struct ipq *qp)
 		return -ETIMEDOUT;
 	}
 
-	fp = qp->q.fragments;
-	do {
-		struct sk_buff *xp = fp->next;
-
-		sum_truesize += fp->truesize;
-		kfree_skb(fp);
-		fp = xp;
-	} while (fp);
+	sum_truesize = skb_rbtree_purge(&qp->q.rb_fragments);
 	sub_frag_mem_limit(qp->q.net, sum_truesize);
 
 	qp->q.flags = 0;
 	qp->q.len = 0;
 	qp->q.meat = 0;
 	qp->q.fragments = NULL;
+	qp->q.rb_fragments = RB_ROOT;
 	qp->q.fragments_tail = NULL;
 	qp->iif = 0;
 	qp->ecn = 0;
@@ -341,7 +348,8 @@  static int ip_frag_reinit(struct ipq *qp)
 static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb)
 {
 	struct net *net = container_of(qp->q.net, struct net, ipv4.frags);
-	struct sk_buff *prev, *next;
+	struct rb_node **rbn, *parent;
+	struct sk_buff *skb1;
 	struct net_device *dev;
 	unsigned int fragsize;
 	int flags, offset;
@@ -404,56 +412,60 @@  static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb)
 	if (err)
 		goto err;
 
-	/* Find out which fragments are in front and at the back of us
-	 * in the chain of fragments so far.  We must know where to put
-	 * this fragment, right?
-	 */
-	prev = qp->q.fragments_tail;
-	if (!prev || FRAG_CB(prev)->offset < offset) {
-		next = NULL;
-		goto found;
-	}
-	prev = NULL;
-	for (next = qp->q.fragments; next != NULL; next = next->next) {
-		if (FRAG_CB(next)->offset >= offset)
-			break;	/* bingo! */
-		prev = next;
-	}
+	/* Note : skb->rbnode and skb->dev share the same location. */
+	dev = skb->dev;
+	/* Makes sure compiler wont do silly aliasing games */
+	barrier();
 
-found:
 	/* RFC5722, Section 4, amended by Errata ID : 3089
 	 *                          When reassembling an IPv6 datagram, if
 	 *   one or more its constituent fragments is determined to be an
 	 *   overlapping fragment, the entire datagram (and any constituent
 	 *   fragments) MUST be silently discarded.
 	 *
-	 * We do the same here for IPv4.
+	 * We do the same here for IPv4 (and increment an snmp counter).
 	 */
-	/* Is there an overlap with the previous fragment? */
-	if (prev &&
-	    (FRAG_CB(prev)->offset + prev->len) > offset)
-		goto discard_qp;
-
-	/* Is there an overlap with the next fragment? */
-	if (next && FRAG_CB(next)->offset < end)
-		goto discard_qp;
-
-	FRAG_CB(skb)->offset = offset;
 
-	/* Insert this fragment in the chain of fragments. */
-	skb->next = next;
-	if (!next)
+	/* Find out where to put this fragment.  */
+	skb1 = qp->q.fragments_tail;
+	if (!skb1) {
+		/* This is the first fragment we've received. */
+		rb_link_node(&skb->rbnode, NULL, &qp->q.rb_fragments.rb_node);
 		qp->q.fragments_tail = skb;
-	if (prev)
-		prev->next = skb;
-	else
-		qp->q.fragments = skb;
+	} else if ((FRAG_CB(skb1)->offset + skb1->len) < end) {
+		/* This is the common/special case: skb goes to the end. */
+		/* Detect and discard overlaps. */
+		if (offset < (FRAG_CB(skb1)->offset + skb1->len))
+			goto discard_qp;
+		/* Insert after skb1. */
+		rb_link_node(&skb->rbnode, &skb1->rbnode, &skb1->rbnode.rb_right);
+		qp->q.fragments_tail = skb;
+	} else {
+		/* Binary search. Note that skb can become the first fragment, but
+		 * not the last (covered above). */
+		rbn = &qp->q.rb_fragments.rb_node;
+		do {
+			parent = *rbn;
+			skb1 = rb_to_skb(parent);
+			if (end <= FRAG_CB(skb1)->offset)
+				rbn = &parent->rb_left;
+			else if (offset >= FRAG_CB(skb1)->offset + skb1->len)
+				rbn = &parent->rb_right;
+			else /* Found an overlap with skb1. */
+				goto discard_qp;
+		} while (*rbn);
+		/* Here we have parent properly set, and rbn pointing to
+		 * one of its NULL left/right children. Insert skb. */
+		rb_link_node(&skb->rbnode, parent, rbn);
+	}
+	rb_insert_color(&skb->rbnode, &qp->q.rb_fragments);
 
-	dev = skb->dev;
 	if (dev) {
 		qp->iif = dev->ifindex;
 		skb->dev = NULL;
 	}
+	FRAG_CB(skb)->offset = offset;
+
 	qp->q.stamp = skb->tstamp;
 	qp->q.meat += skb->len;
 	qp->ecn |= ecn;
@@ -475,7 +487,7 @@  found:
 		unsigned long orefdst = skb->_skb_refdst;
 
 		skb->_skb_refdst = 0UL;
-		err = ip_frag_reasm(qp, prev, dev);
+		err = ip_frag_reasm(qp, skb, dev);
 		skb->_skb_refdst = orefdst;
 		return err;
 	}
@@ -492,15 +504,15 @@  err:
 	return err;
 }
 
-
 /* Build a new IP datagram from all its fragments. */
-
-static int ip_frag_reasm(struct ipq *qp, struct sk_buff *prev,
+static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb,
 			 struct net_device *dev)
 {
 	struct net *net = container_of(qp->q.net, struct net, ipv4.frags);
 	struct iphdr *iph;
-	struct sk_buff *fp, *head = qp->q.fragments;
+	struct sk_buff *fp, *head = skb_rb_first(&qp->q.rb_fragments);
+	struct sk_buff **nextp; /* To build frag_list. */
+	struct rb_node *rbn;
 	int len;
 	int ihlen;
 	int err;
@@ -514,25 +526,21 @@  static int ip_frag_reasm(struct ipq *qp, struct sk_buff *prev,
 		goto out_fail;
 	}
 	/* Make the one we just received the head. */
-	if (prev) {
-		head = prev->next;
-		fp = skb_clone(head, GFP_ATOMIC);
+	if (head != skb) {
+		fp = skb_clone(skb, GFP_ATOMIC);
 		if (!fp)
 			goto out_nomem;
 
-		fp->next = head->next;
-		if (!fp->next)
+		rb_replace_node(&skb->rbnode, &fp->rbnode, &qp->q.rb_fragments);
+		if (qp->q.fragments_tail == skb)
 			qp->q.fragments_tail = fp;
-		prev->next = fp;
-
-		skb_morph(head, qp->q.fragments);
-		head->next = qp->q.fragments->next;
-
-		consume_skb(qp->q.fragments);
-		qp->q.fragments = head;
+		skb_morph(skb, head);
+		rb_replace_node(&head->rbnode, &skb->rbnode,
+				&qp->q.rb_fragments);
+		consume_skb(head);
+		head = skb;
 	}
 
-	WARN_ON(!head);
 	WARN_ON(FRAG_CB(head)->offset != 0);
 
 	/* Allocate a new buffer for the datagram. */
@@ -557,24 +565,35 @@  static int ip_frag_reasm(struct ipq *qp, struct sk_buff *prev,
 		clone = alloc_skb(0, GFP_ATOMIC);
 		if (!clone)
 			goto out_nomem;
-		clone->next = head->next;
-		head->next = clone;
 		skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list;
 		skb_frag_list_init(head);
 		for (i = 0; i < skb_shinfo(head)->nr_frags; i++)
 			plen += skb_frag_size(&skb_shinfo(head)->frags[i]);
 		clone->len = clone->data_len = head->data_len - plen;
-		head->data_len -= clone->len;
-		head->len -= clone->len;
+		skb->truesize += clone->truesize;
 		clone->csum = 0;
 		clone->ip_summed = head->ip_summed;
 		add_frag_mem_limit(qp->q.net, clone->truesize);
+		skb_shinfo(head)->frag_list = clone;
+		nextp = &clone->next;
+	} else {
+		nextp = &skb_shinfo(head)->frag_list;
 	}
 
-	skb_shinfo(head)->frag_list = head->next;
 	skb_push(head, head->data - skb_network_header(head));
 
-	for (fp=head->next; fp; fp = fp->next) {
+	/* Traverse the tree in order, to build frag_list. */
+	rbn = rb_next(&head->rbnode);
+	rb_erase(&head->rbnode, &qp->q.rb_fragments);
+	while (rbn) {
+		struct rb_node *rbnext = rb_next(rbn);
+		fp = rb_to_skb(rbn);
+		rb_erase(rbn, &qp->q.rb_fragments);
+		rbn = rbnext;
+		*nextp = fp;
+		nextp = &fp->next;
+		fp->prev = NULL;
+		memset(&fp->rbnode, 0, sizeof(fp->rbnode));
 		head->data_len += fp->len;
 		head->len += fp->len;
 		if (head->ip_summed != fp->ip_summed)
@@ -585,7 +604,9 @@  static int ip_frag_reasm(struct ipq *qp, struct sk_buff *prev,
 	}
 	sub_frag_mem_limit(qp->q.net, head->truesize);
 
+	*nextp = NULL;
 	head->next = NULL;
+	head->prev = NULL;
 	head->dev = dev;
 	head->tstamp = qp->q.stamp;
 	IPCB(head)->frag_max_size = max(qp->max_df_size, qp->q.max_size);
@@ -613,6 +634,7 @@  static int ip_frag_reasm(struct ipq *qp, struct sk_buff *prev,
 
 	IP_INC_STATS_BH(net, IPSTATS_MIB_REASMOKS);
 	qp->q.fragments = NULL;
+	qp->q.rb_fragments = RB_ROOT;
 	qp->q.fragments_tail = NULL;
 	return 0;
 
diff --git a/net/ipv6/netfilter/nf_conntrack_reasm.c b/net/ipv6/netfilter/nf_conntrack_reasm.c
index 5a9ae56..9cd8863 100644
--- a/net/ipv6/netfilter/nf_conntrack_reasm.c
+++ b/net/ipv6/netfilter/nf_conntrack_reasm.c
@@ -472,6 +472,7 @@  nf_ct_frag6_reasm(struct frag_queue *fq, struct net_device *dev)
 					  head->csum);
 
 	fq->q.fragments = NULL;
+	fq->q.rb_fragments = RB_ROOT;
 	fq->q.fragments_tail = NULL;
 
 	/* all original skbs are linked into the NFCT_FRAG6_CB(head).orig */
diff --git a/net/ipv6/reassembly.c b/net/ipv6/reassembly.c
index ee4789b..adc7512 100644
--- a/net/ipv6/reassembly.c
+++ b/net/ipv6/reassembly.c
@@ -499,6 +499,7 @@  static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *prev,
 	IP6_INC_STATS_BH(net, __in6_dev_get(dev), IPSTATS_MIB_REASMOKS);
 	rcu_read_unlock();
 	fq->q.fragments = NULL;
+	fq->q.rb_fragments = RB_ROOT;
 	fq->q.fragments_tail = NULL;
 	return 1;