diff mbox

[10/14] netfilter: ipset: Introduce RCU locking in the list type

Message ID 1417373825-3734-11-git-send-email-kadlec@blackhole.kfki.hu
State Changes Requested
Delegated to: Pablo Neira
Headers show

Commit Message

Jozsef Kadlecsik Nov. 30, 2014, 6:57 p.m. UTC
Signed-off-by: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
---
 net/netfilter/ipset/ip_set_list_set.c | 386 ++++++++++++++++------------------
 1 file changed, 182 insertions(+), 204 deletions(-)

Comments

Pablo Neira Ayuso Dec. 2, 2014, 6:35 p.m. UTC | #1
On Sun, Nov 30, 2014 at 07:57:01PM +0100, Jozsef Kadlecsik wrote:
> Signed-off-by: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
> ---
>  net/netfilter/ipset/ip_set_list_set.c | 386 ++++++++++++++++------------------
>  1 file changed, 182 insertions(+), 204 deletions(-)
> 
> diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c
> index f8f6828..323115a 100644
> --- a/net/netfilter/ipset/ip_set_list_set.c
> +++ b/net/netfilter/ipset/ip_set_list_set.c
> @@ -9,6 +9,7 @@
>  
>  #include <linux/module.h>
>  #include <linux/ip.h>
> +#include <linux/rculist.h>
>  #include <linux/skbuff.h>
>  #include <linux/errno.h>
>  
> @@ -27,6 +28,8 @@ MODULE_ALIAS("ip_set_list:set");
>  
>  /* Member elements  */
>  struct set_elem {
> +	struct rcu_head rcu;
> +	struct list_head list;

I think rcu_barrier() in the module removal path is missing to make
sure call_rcu() is called before the module is gone.

>  	ip_set_id_t id;
>  };
>  
> @@ -41,12 +44,9 @@ struct list_set {
>  	u32 size;		/* size of set list array */
>  	struct timer_list gc;	/* garbage collection */
>  	struct net *net;	/* namespace */
> -	struct set_elem members[0]; /* the set members */
> +	struct list_head members; /* the set members */
>  };
>  
> -#define list_set_elem(set, map, id)	\
> -	(struct set_elem *)((void *)(map)->members + (id) * (set)->dsize)
> -
>  static int
>  list_set_ktest(struct ip_set *set, const struct sk_buff *skb,
>  	       const struct xt_action_param *par,
> @@ -54,17 +54,14 @@ list_set_ktest(struct ip_set *set, const struct sk_buff *skb,
>  {
>  	struct list_set *map = set->data;
>  	struct set_elem *e;
> -	u32 i, cmdflags = opt->cmdflags;
> +	u32 cmdflags = opt->cmdflags;
>  	int ret;
>  
>  	/* Don't lookup sub-counters at all */
>  	opt->cmdflags &= ~IPSET_FLAG_MATCH_COUNTERS;
>  	if (opt->cmdflags & IPSET_FLAG_SKIP_SUBCOUNTER_UPDATE)
>  		opt->cmdflags &= ~IPSET_FLAG_SKIP_COUNTER_UPDATE;
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			return 0;
> +	list_for_each_entry_rcu(e, &map->members, list) {
>  		if (SET_WITH_TIMEOUT(set) &&
>  		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> @@ -91,13 +88,9 @@ list_set_kadd(struct ip_set *set, const struct sk_buff *skb,
>  {
>  	struct list_set *map = set->data;
>  	struct set_elem *e;
> -	u32 i;
>  	int ret;
>  
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			return 0;
> +	list_for_each_entry_rcu(e, &map->members, list) {

From net/netfilter/ipset/ip_set_core.c I can see this kadd() will be
called under spin_lock_bh(), so you can just use
list_for_each_entry(). The _rcu() variant protects the reader side,
but this code is only invoked from the writer side (no changes are
guaranteed to happen there).

>  		if (SET_WITH_TIMEOUT(set) &&
>  		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> @@ -115,13 +108,9 @@ list_set_kdel(struct ip_set *set, const struct sk_buff *skb,
>  {
>  	struct list_set *map = set->data;
>  	struct set_elem *e;
> -	u32 i;
>  	int ret;
>  
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			return 0;
> +	list_for_each_entry_rcu(e, &map->members, list) {
>  		if (SET_WITH_TIMEOUT(set) &&
>  		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> @@ -138,110 +127,65 @@ list_set_kadt(struct ip_set *set, const struct sk_buff *skb,
>  	      enum ipset_adt adt, struct ip_set_adt_opt *opt)
>  {
>  	struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set);
> +	int ret = -EINVAL;
>  
> +	rcu_read_lock();
>  	switch (adt) {
>  	case IPSET_TEST:
> -		return list_set_ktest(set, skb, par, opt, &ext);
> +		ret = list_set_ktest(set, skb, par, opt, &ext);
> +		break;
>  	case IPSET_ADD:
> -		return list_set_kadd(set, skb, par, opt, &ext);
> +		ret = list_set_kadd(set, skb, par, opt, &ext);
> +		break;
>  	case IPSET_DEL:
> -		return list_set_kdel(set, skb, par, opt, &ext);
> +		ret = list_set_kdel(set, skb, par, opt, &ext);
> +		break;
>  	default:
>  		break;
>  	}
> -	return -EINVAL;
> -}
> +	rcu_read_unlock();
>  
> -static bool
> -id_eq(const struct ip_set *set, u32 i, ip_set_id_t id)
> -{
> -	const struct list_set *map = set->data;
> -	const struct set_elem *e;
> -
> -	if (i >= map->size)
> -		return 0;
> -
> -	e = list_set_elem(set, map, i);
> -	return !!(e->id == id &&
> -		 !(SET_WITH_TIMEOUT(set) &&
> -		   ip_set_timeout_expired(ext_timeout(e, set))));
> +	return ret;
>  }
>  
> -static int
> -list_set_add(struct ip_set *set, u32 i, struct set_adt_elem *d,
> -	     const struct ip_set_ext *ext)
> -{
> -	struct list_set *map = set->data;
> -	struct set_elem *e = list_set_elem(set, map, i);
> -
> -	if (e->id != IPSET_INVALID_ID) {
> -		if (i == map->size - 1) {
> -			/* Last element replaced: e.g. add new,before,last */
> -			ip_set_put_byindex(map->net, e->id);
> -			ip_set_ext_destroy(set, e);
> -		} else {
> -			struct set_elem *x = list_set_elem(set, map,
> -							   map->size - 1);
> -
> -			/* Last element pushed off */
> -			if (x->id != IPSET_INVALID_ID) {
> -				ip_set_put_byindex(map->net, x->id);
> -				ip_set_ext_destroy(set, x);
> -			}
> -			memmove(list_set_elem(set, map, i + 1), e,
> -				set->dsize * (map->size - (i + 1)));
> -			/* Extensions must be initialized to zero */
> -			memset(e, 0, set->dsize);
> -		}
> -	}
> -
> -	e->id = d->id;
> -	if (SET_WITH_TIMEOUT(set))
> -		ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
> -	if (SET_WITH_COUNTER(set))
> -		ip_set_init_counter(ext_counter(e, set), ext);
> -	if (SET_WITH_COMMENT(set))
> -		ip_set_init_comment(ext_comment(e, set), ext);
> -	if (SET_WITH_SKBINFO(set))
> -		ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
> -	return 0;
> -}
> +/* Userspace interfaces: we are protected by the nfnl mutex */
>  
> -static int
> -list_set_del(struct ip_set *set, u32 i)
> +static void
> +__list_set_del(struct ip_set *set, struct set_elem *e)
>  {
>  	struct list_set *map = set->data;
> -	struct set_elem *e = list_set_elem(set, map, i);
>  
>  	ip_set_put_byindex(map->net, e->id);
> +	/* We may call it, because we don't have a to be destroyed
> +	 * extension which is used by the kernel.
> +	 */
>  	ip_set_ext_destroy(set, e);
> +	kfree_rcu(e, rcu);
> +}
>  
> -	if (i < map->size - 1)
> -		memmove(e, list_set_elem(set, map, i + 1),
> -			set->dsize * (map->size - (i + 1)));
> +static inline void
> +list_set_del(struct ip_set *set, struct set_elem *e)
> +{
> +	list_del_rcu(&e->list);
> +	__list_set_del(set, e);
> +}
>  
> -	/* Last element */
> -	e = list_set_elem(set, map, map->size - 1);
> -	e->id = IPSET_INVALID_ID;
> -	return 0;
> +static inline void
> +list_set_replace(struct ip_set *set, struct set_elem *e, struct set_elem *old)
> +{
> +	list_replace_rcu(&old->list, &e->list);
> +	__list_set_del(set, old);
>  }
>  
>  static void
>  set_cleanup_entries(struct ip_set *set)
>  {
>  	struct list_set *map = set->data;
> -	struct set_elem *e;
> -	u32 i = 0;
> +	struct set_elem *e, *n;
>  
> -	while (i < map->size) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id != IPSET_INVALID_ID &&
> -		    ip_set_timeout_expired(ext_timeout(e, set)))
> -			list_set_del(set, i);
> -			/* Check element moved to position i in next loop */
> -		else
> -			i++;
> -	}
> +	list_for_each_entry_safe(e, n, &map->members, list)
> +		if (ip_set_timeout_expired(ext_timeout(e, set)))
> +			list_set_del(set, e);
>  }
>  
>  static int
> @@ -250,31 +194,45 @@ list_set_utest(struct ip_set *set, void *value, const struct ip_set_ext *ext,
>  {
>  	struct list_set *map = set->data;
>  	struct set_adt_elem *d = value;
> -	struct set_elem *e;
> -	u32 i;
> +	struct set_elem *e, *next, *prev = NULL;
>  	int ret;
>  
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			return 0;
> -		else if (SET_WITH_TIMEOUT(set) &&
> -			 ip_set_timeout_expired(ext_timeout(e, set)))
> +	list_for_each_entry(e, &map->members, list) {
> +		if (SET_WITH_TIMEOUT(set) &&
> +		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> -		else if (e->id != d->id)
> +		else if (e->id != d->id) {
> +			prev = e;
>  			continue;
> +		}
>  
>  		if (d->before == 0)
> -			return 1;
> -		else if (d->before > 0)
> -			ret = id_eq(set, i + 1, d->refid);
> -		else
> -			ret = i > 0 && id_eq(set, i - 1, d->refid);
> +			ret = 1;
> +		else if (d->before > 0) {
> +			next = list_next_entry(e, list);
> +			ret = !list_is_last(&e->list, &map->members) &&
> +			      next->id == d->refid;
> +		} else
> +			ret = prev != NULL && prev->id == d->refid;
>  		return ret;
>  	}
>  	return 0;
>  }
>  
> +static void
> +list_set_init_extensions(struct ip_set *set, const struct ip_set_ext *ext,
> +			 struct set_elem *e)
> +{
> +	if (SET_WITH_COUNTER(set))
> +		ip_set_init_counter(ext_counter(e, set), ext);
> +	if (SET_WITH_COMMENT(set))
> +		ip_set_init_comment(ext_comment(e, set), ext);
> +	if (SET_WITH_SKBINFO(set))
> +		ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
> +	/* Update timeout last */
> +	if (SET_WITH_TIMEOUT(set))
> +		ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
> +}
>  
>  static int
>  list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
> @@ -282,60 +240,82 @@ list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
>  {
>  	struct list_set *map = set->data;
>  	struct set_adt_elem *d = value;
> -	struct set_elem *e;
> +	struct set_elem *e, *n, *prev, *next;
>  	bool flag_exist = flags & IPSET_FLAG_EXIST;
> -	u32 i, ret = 0;
>  
>  	if (SET_WITH_TIMEOUT(set))
>  		set_cleanup_entries(set);
>  
> -	/* Check already added element */
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			goto insert;
> -		else if (e->id != d->id)
> +	/* Find where to add the new entry */
> +	n = prev = next = NULL;
> +	list_for_each_entry(e, &map->members, list) {
> +		if (SET_WITH_TIMEOUT(set) &&
> +		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> -
> -		if ((d->before > 1 && !id_eq(set, i + 1, d->refid)) ||
> -		    (d->before < 0 &&
> -		     (i == 0 || !id_eq(set, i - 1, d->refid))))
> -			/* Before/after doesn't match */
> +		else if (d->id == e->id)
> +			n = e;
> +		else if (d->before == 0 || e->id != d->refid)
> +			continue;
> +		else if (d->before > 0)
> +			next = e;
> +		else
> +			prev = e;
> +	}
> +	/* Re-add already existing element */
> +	if (n) {
> +		if ((d->before > 0 && !next) ||
> +		    (d->before < 0 && !prev))
>  			return -IPSET_ERR_REF_EXIST;
>  		if (!flag_exist)
> -			/* Can't re-add */
>  			return -IPSET_ERR_EXIST;
>  		/* Update extensions */
> -		ip_set_ext_destroy(set, e);
> +		ip_set_ext_destroy(set, n);
> +		list_set_init_extensions(set, ext, n);
>  
> -		if (SET_WITH_TIMEOUT(set))
> -			ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
> -		if (SET_WITH_COUNTER(set))
> -			ip_set_init_counter(ext_counter(e, set), ext);
> -		if (SET_WITH_COMMENT(set))
> -			ip_set_init_comment(ext_comment(e, set), ext);
> -		if (SET_WITH_SKBINFO(set))
> -			ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
>  		/* Set is already added to the list */
>  		ip_set_put_byindex(map->net, d->id);
>  		return 0;
>  	}
> -insert:
> -	ret = -IPSET_ERR_LIST_FULL;
> -	for (i = 0; i < map->size && ret == -IPSET_ERR_LIST_FULL; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			ret = d->before != 0 ? -IPSET_ERR_REF_EXIST
> -				: list_set_add(set, i, d, ext);
> -		else if (e->id != d->refid)
> -			continue;
> -		else if (d->before > 0)
> -			ret = list_set_add(set, i, d, ext);
> -		else if (i + 1 < map->size)
> -			ret = list_set_add(set, i + 1, d, ext);
> +	/* Add new entry */
> +	if (d->before == 0) {
> +		/* Append  */
> +		n = list_empty(&map->members) ? NULL :
> +		    list_last_entry(&map->members, struct set_elem, list);
> +	} else if (d->before > 0) {
> +		/* Insert after next element */
> +		if (!list_is_last(&next->list, &map->members))
> +			n = list_next_entry(next, list);
> +	} else {
> +		/* Insert before prev element */
> +		if (prev->list.prev != &map->members)
> +			n = list_prev_entry(prev, list);
>  	}
> -
> -	return ret;
> +	/* Can we replace a timed out entry? */
> +	if (n != NULL &&
> +	    !(SET_WITH_TIMEOUT(set) &&
> +	      ip_set_timeout_expired(ext_timeout(n, set))))
> +		n =  NULL;
> +
> +	e = kzalloc(set->dsize, GFP_KERNEL);
> +	if (!e)
> +		return -ENOMEM;
> +	e->id = d->id;
> +	INIT_LIST_HEAD(&e->list);
> +	list_set_init_extensions(set, ext, e);
> +	if (n)
> +		list_set_replace(set, e, n);
> +	else if (next)
> +		list_add_tail_rcu(&e->list, &next->list);
> +	else if (prev)
> +		list_add_rcu(&e->list, &prev->list);
> +	else
> +		list_add_tail_rcu(&e->list, &map->members);
> +	spin_unlock_bh(&set->lock);
> +
> +	synchronize_rcu_bh();

I suspect you don't need this. What is your intention here?

> +
> +	spin_lock_bh(&set->lock);
> +	return 0;
>  }
>  
>  static int
> @@ -344,32 +324,30 @@ list_set_udel(struct ip_set *set, void *value, const struct ip_set_ext *ext,
>  {
>  	struct list_set *map = set->data;
>  	struct set_adt_elem *d = value;
> -	struct set_elem *e;
> -	u32 i;
> -
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			return d->before != 0 ? -IPSET_ERR_REF_EXIST
> -					      : -IPSET_ERR_EXIST;
> -		else if (SET_WITH_TIMEOUT(set) &&
> -			 ip_set_timeout_expired(ext_timeout(e, set)))
> +	struct set_elem *e, *next, *prev = NULL;
> +
> +	list_for_each_entry(e, &map->members, list) {
> +		if (SET_WITH_TIMEOUT(set) &&
> +		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> -		else if (e->id != d->id)
> +		else if (e->id != d->id) {
> +			prev = e;
>  			continue;
> +		}
>  
> -		if (d->before == 0)
> -			return list_set_del(set, i);
> -		else if (d->before > 0) {
> -			if (!id_eq(set, i + 1, d->refid))
> +		if (d->before > 0) {
> +			next = list_next_entry(e, list);
> +			if (list_is_last(&e->list, &map->members) ||
> +			    next->id != d->refid)
>  				return -IPSET_ERR_REF_EXIST;
> -			return list_set_del(set, i);
> -		} else if (i == 0 || !id_eq(set, i - 1, d->refid))
> -			return -IPSET_ERR_REF_EXIST;
> -		else
> -			return list_set_del(set, i);
> +		} else if (d->before < 0) {
> +			if (prev == NULL || prev->id != d->refid)
> +				return -IPSET_ERR_REF_EXIST;
> +		}
> +		list_set_del(set, e);
> +		return 0;
>  	}
> -	return -IPSET_ERR_EXIST;
> +	return d->before != 0 ? -IPSET_ERR_REF_EXIST : -IPSET_ERR_EXIST;
>  }
>  
>  static int
> @@ -410,6 +388,7 @@ list_set_uadt(struct ip_set *set, struct nlattr *tb[],
>  
>  	if (tb[IPSET_ATTR_CADT_FLAGS]) {
>  		u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
> +
>  		e.before = f & IPSET_FLAG_BEFORE;
>  	}
>  
> @@ -447,27 +426,26 @@ static void
>  list_set_flush(struct ip_set *set)
>  {
>  	struct list_set *map = set->data;
> -	struct set_elem *e;
> -	u32 i;
> -
> -	for (i = 0; i < map->size; i++) {
> -		e = list_set_elem(set, map, i);
> -		if (e->id != IPSET_INVALID_ID) {
> -			ip_set_put_byindex(map->net, e->id);
> -			ip_set_ext_destroy(set, e);
> -			e->id = IPSET_INVALID_ID;
> -		}
> -	}
> +	struct set_elem *e, *n;
> +
> +	list_for_each_entry_safe(e, n, &map->members, list)
> +		list_set_del(set, e);
>  }
>  
>  static void
>  list_set_destroy(struct ip_set *set)
>  {
>  	struct list_set *map = set->data;
> +	struct set_elem *e, *n;
>  
>  	if (SET_WITH_TIMEOUT(set))
>  		del_timer_sync(&map->gc);
> -	list_set_flush(set);
> +	list_for_each_entry_safe(e, n, &map->members, list) {
> +		list_del(&e->list);
> +		ip_set_put_byindex(map->net, e->id);
> +		ip_set_ext_destroy(set, e);
> +		kfree(e);
> +	}
>  	kfree(map);
>  
>  	set->data = NULL;
> @@ -478,6 +456,11 @@ list_set_head(struct ip_set *set, struct sk_buff *skb)
>  {
>  	const struct list_set *map = set->data;
>  	struct nlattr *nested;
> +	struct set_elem *e;
> +	u32 n = 0;
> +
> +	list_for_each_entry(e, &map->members, list)
> +		n++;
>  
>  	nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
>  	if (!nested)
> @@ -485,7 +468,7 @@ list_set_head(struct ip_set *set, struct sk_buff *skb)
>  	if (nla_put_net32(skb, IPSET_ATTR_SIZE, htonl(map->size)) ||
>  	    nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref - 1)) ||
>  	    nla_put_net32(skb, IPSET_ATTR_MEMSIZE,
> -			  htonl(sizeof(*map) + map->size * set->dsize)))
> +			  htonl(sizeof(*map) + n * set->dsize)))
>  		goto nla_put_failure;
>  	if (unlikely(ip_set_put_flags(skb, set)))
>  		goto nla_put_failure;
> @@ -502,18 +485,20 @@ list_set_list(const struct ip_set *set,
>  {
>  	const struct list_set *map = set->data;
>  	struct nlattr *atd, *nested;
> -	u32 i, first = cb->args[IPSET_CB_ARG0];
> -	const struct set_elem *e;
> +	u32 i = 0, first = cb->args[IPSET_CB_ARG0];
> +	struct set_elem *e;
>  
>  	atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
>  	if (!atd)
>  		return -EMSGSIZE;
> -	for (; cb->args[IPSET_CB_ARG0] < map->size;
> -	     cb->args[IPSET_CB_ARG0]++) {
> -		i = cb->args[IPSET_CB_ARG0];
> -		e = list_set_elem(set, map, i);
> -		if (e->id == IPSET_INVALID_ID)
> -			goto finish;
> +	list_for_each_entry(e, &map->members, list) {
> +		if (i == first)
> +			break;
> +		i++;
> +	}
> +
> +	list_for_each_entry_from(e, &map->members, list) {
> +		i++;
>  		if (SET_WITH_TIMEOUT(set) &&
>  		    ip_set_timeout_expired(ext_timeout(e, set)))
>  			continue;
> @@ -532,7 +517,7 @@ list_set_list(const struct ip_set *set,
>  			goto nla_put_failure;
>  		ipset_nest_end(skb, nested);
>  	}
> -finish:
> +
>  	ipset_nest_end(skb, atd);
>  	/* Set listing finished */
>  	cb->args[IPSET_CB_ARG0] = 0;
> @@ -544,6 +529,7 @@ nla_put_failure:
>  		cb->args[IPSET_CB_ARG0] = 0;
>  		return -EMSGSIZE;
>  	}
> +	cb->args[IPSET_CB_ARG0] = i - 1;
>  	ipset_nest_end(skb, atd);
>  	return 0;
>  }
> @@ -580,9 +566,9 @@ list_set_gc(unsigned long ul_set)
>  	struct ip_set *set = (struct ip_set *) ul_set;
>  	struct list_set *map = set->data;
>  
> -	write_lock_bh(&set->lock);
> +	spin_lock_bh(&set->lock);
>  	set_cleanup_entries(set);
> -	write_unlock_bh(&set->lock);
> +	spin_unlock_bh(&set->lock);
>  
>  	map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
>  	add_timer(&map->gc);
> @@ -606,24 +592,16 @@ static bool
>  init_list_set(struct net *net, struct ip_set *set, u32 size)
>  {
>  	struct list_set *map;
> -	struct set_elem *e;
> -	u32 i;
>  
> -	map = kzalloc(sizeof(*map) +
> -		      min_t(u32, size, IP_SET_LIST_MAX_SIZE) * set->dsize,
> -		      GFP_KERNEL);
> +	map = kzalloc(sizeof(*map), GFP_KERNEL);
>  	if (!map)
>  		return false;
>  
>  	map->size = size;
>  	map->net = net;
> +	INIT_LIST_HEAD(&map->members);
>  	set->data = map;
>  
> -	for (i = 0; i < size; i++) {
> -		e = list_set_elem(set, map, i);
> -		e->id = IPSET_INVALID_ID;
> -	}
> -
>  	return true;
>  }
>  
> -- 
> 1.8.5.1
> 
--
To unsubscribe from this list: send the line "unsubscribe netfilter-devel" in
the body of a message to majordomo@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html
Pablo Neira Ayuso Dec. 2, 2014, 6:52 p.m. UTC | #2
On Tue, Dec 02, 2014 at 07:35:39PM +0100, Pablo Neira Ayuso wrote:
> On Sun, Nov 30, 2014 at 07:57:01PM +0100, Jozsef Kadlecsik wrote:
> > Signed-off-by: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
> > ---
> >  net/netfilter/ipset/ip_set_list_set.c | 386 ++++++++++++++++------------------
> >  1 file changed, 182 insertions(+), 204 deletions(-)
> > 
> > diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c
> > index f8f6828..323115a 100644
> > --- a/net/netfilter/ipset/ip_set_list_set.c
> > +++ b/net/netfilter/ipset/ip_set_list_set.c
> > @@ -9,6 +9,7 @@
> >  
> >  #include <linux/module.h>
> >  #include <linux/ip.h>
> > +#include <linux/rculist.h>
> >  #include <linux/skbuff.h>
> >  #include <linux/errno.h>
> >  
> > @@ -27,6 +28,8 @@ MODULE_ALIAS("ip_set_list:set");
> >  
> >  /* Member elements  */
> >  struct set_elem {
> > +	struct rcu_head rcu;
> > +	struct list_head list;
> 
> I think rcu_barrier() in the module removal path is missing to make
> sure call_rcu() is called before the module is gone.

I mean, we make sure rcu softirq runs to release these objects before
the module is gone.
--
To unsubscribe from this list: send the line "unsubscribe netfilter-devel" in
the body of a message to majordomo@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html
Jozsef Kadlecsik Dec. 3, 2014, 11:17 a.m. UTC | #3
On Tue, 2 Dec 2014, Pablo Neira Ayuso wrote:

> On Sun, Nov 30, 2014 at 07:57:01PM +0100, Jozsef Kadlecsik wrote:
> > Signed-off-by: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
> > ---
> >  net/netfilter/ipset/ip_set_list_set.c | 386 ++++++++++++++++------------------
> >  1 file changed, 182 insertions(+), 204 deletions(-)
> > 
> > diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c
> > index f8f6828..323115a 100644
> > --- a/net/netfilter/ipset/ip_set_list_set.c
> > +++ b/net/netfilter/ipset/ip_set_list_set.c
> > @@ -9,6 +9,7 @@
> >  
> >  #include <linux/module.h>
> >  #include <linux/ip.h>
> > +#include <linux/rculist.h>
> >  #include <linux/skbuff.h>
> >  #include <linux/errno.h>
> >  
> > @@ -27,6 +28,8 @@ MODULE_ALIAS("ip_set_list:set");
> >  
> >  /* Member elements  */
> >  struct set_elem {
> > +	struct rcu_head rcu;
> > +	struct list_head list;
> 
> I think rcu_barrier() in the module removal path is missing to make
> sure call_rcu() is called before the module is gone.

The module can be removed only when there isn't a single set of the given 
type. That means there are no elements to be removed by kfree_rcu(). 
Therefore I think rcu_barrier() is not required in the module removal 
path.

> > @@ -91,13 +88,9 @@ list_set_kadd(struct ip_set *set, const struct sk_buff *skb,
> >  {
> >  	struct list_set *map = set->data;
> >  	struct set_elem *e;
> > -	u32 i;
> >  	int ret;
> >  
> > -	for (i = 0; i < map->size; i++) {
> > -		e = list_set_elem(set, map, i);
> > -		if (e->id == IPSET_INVALID_ID)
> > -			return 0;
> > +	list_for_each_entry_rcu(e, &map->members, list) {
> 
> >From net/netfilter/ipset/ip_set_core.c I can see this kadd() will be
> called under spin_lock_bh(), so you can just use
> list_for_each_entry(). The _rcu() variant protects the reader side,
> but this code is only invoked from the writer side (no changes are
> guaranteed to happen there).

Yes, you are right! I'll correct it.

> >  static int
> >  list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
> > @@ -282,60 +240,82 @@ list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
> >  {
> >  	struct list_set *map = set->data;
> >  	struct set_adt_elem *d = value;
> > -	struct set_elem *e;
> > +	struct set_elem *e, *n, *prev, *next;
> >  	bool flag_exist = flags & IPSET_FLAG_EXIST;
> > -	u32 i, ret = 0;
> >  
> >  	if (SET_WITH_TIMEOUT(set))
> >  		set_cleanup_entries(set);
> >  
> > -	/* Check already added element */
> > -	for (i = 0; i < map->size; i++) {
> > -		e = list_set_elem(set, map, i);
> > -		if (e->id == IPSET_INVALID_ID)
> > -			goto insert;
> > -		else if (e->id != d->id)
> > +	/* Find where to add the new entry */
> > +	n = prev = next = NULL;
> > +	list_for_each_entry(e, &map->members, list) {
> > +		if (SET_WITH_TIMEOUT(set) &&
> > +		    ip_set_timeout_expired(ext_timeout(e, set)))
> >  			continue;
> > -
> > -		if ((d->before > 1 && !id_eq(set, i + 1, d->refid)) ||
> > -		    (d->before < 0 &&
> > -		     (i == 0 || !id_eq(set, i - 1, d->refid))))
> > -			/* Before/after doesn't match */
> > +		else if (d->id == e->id)
> > +			n = e;
> > +		else if (d->before == 0 || e->id != d->refid)
> > +			continue;
> > +		else if (d->before > 0)
> > +			next = e;
> > +		else
> > +			prev = e;
> > +	}
> > +	/* Re-add already existing element */
> > +	if (n) {
> > +		if ((d->before > 0 && !next) ||
> > +		    (d->before < 0 && !prev))
> >  			return -IPSET_ERR_REF_EXIST;
> >  		if (!flag_exist)
> > -			/* Can't re-add */
> >  			return -IPSET_ERR_EXIST;
> >  		/* Update extensions */
> > -		ip_set_ext_destroy(set, e);
> > +		ip_set_ext_destroy(set, n);
> > +		list_set_init_extensions(set, ext, n);
> >  
> > -		if (SET_WITH_TIMEOUT(set))
> > -			ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
> > -		if (SET_WITH_COUNTER(set))
> > -			ip_set_init_counter(ext_counter(e, set), ext);
> > -		if (SET_WITH_COMMENT(set))
> > -			ip_set_init_comment(ext_comment(e, set), ext);
> > -		if (SET_WITH_SKBINFO(set))
> > -			ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
> >  		/* Set is already added to the list */
> >  		ip_set_put_byindex(map->net, d->id);
> >  		return 0;
> >  	}
> > -insert:
> > -	ret = -IPSET_ERR_LIST_FULL;
> > -	for (i = 0; i < map->size && ret == -IPSET_ERR_LIST_FULL; i++) {
> > -		e = list_set_elem(set, map, i);
> > -		if (e->id == IPSET_INVALID_ID)
> > -			ret = d->before != 0 ? -IPSET_ERR_REF_EXIST
> > -				: list_set_add(set, i, d, ext);
> > -		else if (e->id != d->refid)
> > -			continue;
> > -		else if (d->before > 0)
> > -			ret = list_set_add(set, i, d, ext);
> > -		else if (i + 1 < map->size)
> > -			ret = list_set_add(set, i + 1, d, ext);
> > +	/* Add new entry */
> > +	if (d->before == 0) {
> > +		/* Append  */
> > +		n = list_empty(&map->members) ? NULL :
> > +		    list_last_entry(&map->members, struct set_elem, list);
> > +	} else if (d->before > 0) {
> > +		/* Insert after next element */
> > +		if (!list_is_last(&next->list, &map->members))
> > +			n = list_next_entry(next, list);
> > +	} else {
> > +		/* Insert before prev element */
> > +		if (prev->list.prev != &map->members)
> > +			n = list_prev_entry(prev, list);
> >  	}
> > -
> > -	return ret;
> > +	/* Can we replace a timed out entry? */
> > +	if (n != NULL &&
> > +	    !(SET_WITH_TIMEOUT(set) &&
> > +	      ip_set_timeout_expired(ext_timeout(n, set))))
> > +		n =  NULL;
> > +
> > +	e = kzalloc(set->dsize, GFP_KERNEL);
> > +	if (!e)
> > +		return -ENOMEM;
> > +	e->id = d->id;
> > +	INIT_LIST_HEAD(&e->list);
> > +	list_set_init_extensions(set, ext, e);
> > +	if (n)
> > +		list_set_replace(set, e, n);
> > +	else if (next)
> > +		list_add_tail_rcu(&e->list, &next->list);
> > +	else if (prev)
> > +		list_add_rcu(&e->list, &prev->list);
> > +	else
> > +		list_add_tail_rcu(&e->list, &map->members);
> > +	spin_unlock_bh(&set->lock);
> > +
> > +	synchronize_rcu_bh();
> 
> I suspect you don't need this. What is your intention here?

Here the userspace adds/deletes/replaces an element in the list type of 
set and in the meantime the kernel module can traverse the same linked 
list. In the replace case we remove and delete the old entry, therefore 
the call to synchronize_rcu_bh(). That could be called from a condition 
then, to express the case.

Best regards,
Jozsef
-
E-mail  : kadlec@blackhole.kfki.hu, kadlecsik.jozsef@wigner.mta.hu
PGP key : http://www.kfki.hu/~kadlec/pgp_public_key.txt
Address : Wigner Research Centre for Physics, Hungarian Academy of Sciences
          H-1525 Budapest 114, POB. 49, Hungary
--
To unsubscribe from this list: send the line "unsubscribe netfilter-devel" in
the body of a message to majordomo@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html
Pablo Neira Ayuso Dec. 3, 2014, 11:36 a.m. UTC | #4
On Wed, Dec 03, 2014 at 12:17:36PM +0100, Jozsef Kadlecsik wrote:
> On Tue, 2 Dec 2014, Pablo Neira Ayuso wrote:
> 
> > On Sun, Nov 30, 2014 at 07:57:01PM +0100, Jozsef Kadlecsik wrote:
> > > Signed-off-by: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
> > > ---
> > >  net/netfilter/ipset/ip_set_list_set.c | 386 ++++++++++++++++------------------
> > >  1 file changed, 182 insertions(+), 204 deletions(-)
> > > 
> > > diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c
> > > index f8f6828..323115a 100644
> > > --- a/net/netfilter/ipset/ip_set_list_set.c
> > > +++ b/net/netfilter/ipset/ip_set_list_set.c
> > > @@ -9,6 +9,7 @@
> > >  
> > >  #include <linux/module.h>
> > >  #include <linux/ip.h>
> > > +#include <linux/rculist.h>
> > >  #include <linux/skbuff.h>
> > >  #include <linux/errno.h>
> > >  
> > > @@ -27,6 +28,8 @@ MODULE_ALIAS("ip_set_list:set");
> > >  
> > >  /* Member elements  */
> > >  struct set_elem {
> > > +	struct rcu_head rcu;
> > > +	struct list_head list;
> > 
> > I think rcu_barrier() in the module removal path is missing to make
> > sure call_rcu() is called before the module is gone.
> 
> The module can be removed only when there isn't a single set of the given 
> type. That means there are no elements to be removed by kfree_rcu(). 
> Therefore I think rcu_barrier() is not required in the module removal 
> path.

I think this can race with the rcu callback execution. See this:

https://www.kernel.org/doc/Documentation/RCU/rcubarrier.txt

specifically: Unloading Modules That Use call_rcu()

[...]
> > > +	/* Can we replace a timed out entry? */
> > > +	if (n != NULL &&
> > > +	    !(SET_WITH_TIMEOUT(set) &&
> > > +	      ip_set_timeout_expired(ext_timeout(n, set))))
> > > +		n =  NULL;
> > > +
> > > +	e = kzalloc(set->dsize, GFP_KERNEL);
> > > +	if (!e)
> > > +		return -ENOMEM;
> > > +	e->id = d->id;
> > > +	INIT_LIST_HEAD(&e->list);
> > > +	list_set_init_extensions(set, ext, e);
> > > +	if (n)
> > > +		list_set_replace(set, e, n);
> > > +	else if (next)
> > > +		list_add_tail_rcu(&e->list, &next->list);
> > > +	else if (prev)
> > > +		list_add_rcu(&e->list, &prev->list);
> > > +	else
> > > +		list_add_tail_rcu(&e->list, &map->members);
> > > +	spin_unlock_bh(&set->lock);
> > > +
> > > +	synchronize_rcu_bh();
> > 
> > I suspect you don't need this. What is your intention here?
> 
> Here the userspace adds/deletes/replaces an element in the list type of 
> set and in the meantime the kernel module can traverse the same linked 
> list. In the replace case we remove and delete the old entry, therefore 
> the call to synchronize_rcu_bh(). That could be called from a condition 
> then, to express the case.

But you're releasing objects via kfree_rcu(), right? Then you don't
need to wait for the rcu grace state.
--
To unsubscribe from this list: send the line "unsubscribe netfilter-devel" in
the body of a message to majordomo@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html
diff mbox

Patch

diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c
index f8f6828..323115a 100644
--- a/net/netfilter/ipset/ip_set_list_set.c
+++ b/net/netfilter/ipset/ip_set_list_set.c
@@ -9,6 +9,7 @@ 
 
 #include <linux/module.h>
 #include <linux/ip.h>
+#include <linux/rculist.h>
 #include <linux/skbuff.h>
 #include <linux/errno.h>
 
@@ -27,6 +28,8 @@  MODULE_ALIAS("ip_set_list:set");
 
 /* Member elements  */
 struct set_elem {
+	struct rcu_head rcu;
+	struct list_head list;
 	ip_set_id_t id;
 };
 
@@ -41,12 +44,9 @@  struct list_set {
 	u32 size;		/* size of set list array */
 	struct timer_list gc;	/* garbage collection */
 	struct net *net;	/* namespace */
-	struct set_elem members[0]; /* the set members */
+	struct list_head members; /* the set members */
 };
 
-#define list_set_elem(set, map, id)	\
-	(struct set_elem *)((void *)(map)->members + (id) * (set)->dsize)
-
 static int
 list_set_ktest(struct ip_set *set, const struct sk_buff *skb,
 	       const struct xt_action_param *par,
@@ -54,17 +54,14 @@  list_set_ktest(struct ip_set *set, const struct sk_buff *skb,
 {
 	struct list_set *map = set->data;
 	struct set_elem *e;
-	u32 i, cmdflags = opt->cmdflags;
+	u32 cmdflags = opt->cmdflags;
 	int ret;
 
 	/* Don't lookup sub-counters at all */
 	opt->cmdflags &= ~IPSET_FLAG_MATCH_COUNTERS;
 	if (opt->cmdflags & IPSET_FLAG_SKIP_SUBCOUNTER_UPDATE)
 		opt->cmdflags &= ~IPSET_FLAG_SKIP_COUNTER_UPDATE;
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			return 0;
+	list_for_each_entry_rcu(e, &map->members, list) {
 		if (SET_WITH_TIMEOUT(set) &&
 		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
@@ -91,13 +88,9 @@  list_set_kadd(struct ip_set *set, const struct sk_buff *skb,
 {
 	struct list_set *map = set->data;
 	struct set_elem *e;
-	u32 i;
 	int ret;
 
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			return 0;
+	list_for_each_entry_rcu(e, &map->members, list) {
 		if (SET_WITH_TIMEOUT(set) &&
 		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
@@ -115,13 +108,9 @@  list_set_kdel(struct ip_set *set, const struct sk_buff *skb,
 {
 	struct list_set *map = set->data;
 	struct set_elem *e;
-	u32 i;
 	int ret;
 
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			return 0;
+	list_for_each_entry_rcu(e, &map->members, list) {
 		if (SET_WITH_TIMEOUT(set) &&
 		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
@@ -138,110 +127,65 @@  list_set_kadt(struct ip_set *set, const struct sk_buff *skb,
 	      enum ipset_adt adt, struct ip_set_adt_opt *opt)
 {
 	struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set);
+	int ret = -EINVAL;
 
+	rcu_read_lock();
 	switch (adt) {
 	case IPSET_TEST:
-		return list_set_ktest(set, skb, par, opt, &ext);
+		ret = list_set_ktest(set, skb, par, opt, &ext);
+		break;
 	case IPSET_ADD:
-		return list_set_kadd(set, skb, par, opt, &ext);
+		ret = list_set_kadd(set, skb, par, opt, &ext);
+		break;
 	case IPSET_DEL:
-		return list_set_kdel(set, skb, par, opt, &ext);
+		ret = list_set_kdel(set, skb, par, opt, &ext);
+		break;
 	default:
 		break;
 	}
-	return -EINVAL;
-}
+	rcu_read_unlock();
 
-static bool
-id_eq(const struct ip_set *set, u32 i, ip_set_id_t id)
-{
-	const struct list_set *map = set->data;
-	const struct set_elem *e;
-
-	if (i >= map->size)
-		return 0;
-
-	e = list_set_elem(set, map, i);
-	return !!(e->id == id &&
-		 !(SET_WITH_TIMEOUT(set) &&
-		   ip_set_timeout_expired(ext_timeout(e, set))));
+	return ret;
 }
 
-static int
-list_set_add(struct ip_set *set, u32 i, struct set_adt_elem *d,
-	     const struct ip_set_ext *ext)
-{
-	struct list_set *map = set->data;
-	struct set_elem *e = list_set_elem(set, map, i);
-
-	if (e->id != IPSET_INVALID_ID) {
-		if (i == map->size - 1) {
-			/* Last element replaced: e.g. add new,before,last */
-			ip_set_put_byindex(map->net, e->id);
-			ip_set_ext_destroy(set, e);
-		} else {
-			struct set_elem *x = list_set_elem(set, map,
-							   map->size - 1);
-
-			/* Last element pushed off */
-			if (x->id != IPSET_INVALID_ID) {
-				ip_set_put_byindex(map->net, x->id);
-				ip_set_ext_destroy(set, x);
-			}
-			memmove(list_set_elem(set, map, i + 1), e,
-				set->dsize * (map->size - (i + 1)));
-			/* Extensions must be initialized to zero */
-			memset(e, 0, set->dsize);
-		}
-	}
-
-	e->id = d->id;
-	if (SET_WITH_TIMEOUT(set))
-		ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
-	if (SET_WITH_COUNTER(set))
-		ip_set_init_counter(ext_counter(e, set), ext);
-	if (SET_WITH_COMMENT(set))
-		ip_set_init_comment(ext_comment(e, set), ext);
-	if (SET_WITH_SKBINFO(set))
-		ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
-	return 0;
-}
+/* Userspace interfaces: we are protected by the nfnl mutex */
 
-static int
-list_set_del(struct ip_set *set, u32 i)
+static void
+__list_set_del(struct ip_set *set, struct set_elem *e)
 {
 	struct list_set *map = set->data;
-	struct set_elem *e = list_set_elem(set, map, i);
 
 	ip_set_put_byindex(map->net, e->id);
+	/* We may call it, because we don't have a to be destroyed
+	 * extension which is used by the kernel.
+	 */
 	ip_set_ext_destroy(set, e);
+	kfree_rcu(e, rcu);
+}
 
-	if (i < map->size - 1)
-		memmove(e, list_set_elem(set, map, i + 1),
-			set->dsize * (map->size - (i + 1)));
+static inline void
+list_set_del(struct ip_set *set, struct set_elem *e)
+{
+	list_del_rcu(&e->list);
+	__list_set_del(set, e);
+}
 
-	/* Last element */
-	e = list_set_elem(set, map, map->size - 1);
-	e->id = IPSET_INVALID_ID;
-	return 0;
+static inline void
+list_set_replace(struct ip_set *set, struct set_elem *e, struct set_elem *old)
+{
+	list_replace_rcu(&old->list, &e->list);
+	__list_set_del(set, old);
 }
 
 static void
 set_cleanup_entries(struct ip_set *set)
 {
 	struct list_set *map = set->data;
-	struct set_elem *e;
-	u32 i = 0;
+	struct set_elem *e, *n;
 
-	while (i < map->size) {
-		e = list_set_elem(set, map, i);
-		if (e->id != IPSET_INVALID_ID &&
-		    ip_set_timeout_expired(ext_timeout(e, set)))
-			list_set_del(set, i);
-			/* Check element moved to position i in next loop */
-		else
-			i++;
-	}
+	list_for_each_entry_safe(e, n, &map->members, list)
+		if (ip_set_timeout_expired(ext_timeout(e, set)))
+			list_set_del(set, e);
 }
 
 static int
@@ -250,31 +194,45 @@  list_set_utest(struct ip_set *set, void *value, const struct ip_set_ext *ext,
 {
 	struct list_set *map = set->data;
 	struct set_adt_elem *d = value;
-	struct set_elem *e;
-	u32 i;
+	struct set_elem *e, *next, *prev = NULL;
 	int ret;
 
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			return 0;
-		else if (SET_WITH_TIMEOUT(set) &&
-			 ip_set_timeout_expired(ext_timeout(e, set)))
+	list_for_each_entry(e, &map->members, list) {
+		if (SET_WITH_TIMEOUT(set) &&
+		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
-		else if (e->id != d->id)
+		else if (e->id != d->id) {
+			prev = e;
 			continue;
+		}
 
 		if (d->before == 0)
-			return 1;
-		else if (d->before > 0)
-			ret = id_eq(set, i + 1, d->refid);
-		else
-			ret = i > 0 && id_eq(set, i - 1, d->refid);
+			ret = 1;
+		else if (d->before > 0) {
+			next = list_next_entry(e, list);
+			ret = !list_is_last(&e->list, &map->members) &&
+			      next->id == d->refid;
+		} else
+			ret = prev != NULL && prev->id == d->refid;
 		return ret;
 	}
 	return 0;
 }
 
+static void
+list_set_init_extensions(struct ip_set *set, const struct ip_set_ext *ext,
+			 struct set_elem *e)
+{
+	if (SET_WITH_COUNTER(set))
+		ip_set_init_counter(ext_counter(e, set), ext);
+	if (SET_WITH_COMMENT(set))
+		ip_set_init_comment(ext_comment(e, set), ext);
+	if (SET_WITH_SKBINFO(set))
+		ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
+	/* Update timeout last */
+	if (SET_WITH_TIMEOUT(set))
+		ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
+}
 
 static int
 list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
@@ -282,60 +240,82 @@  list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
 {
 	struct list_set *map = set->data;
 	struct set_adt_elem *d = value;
-	struct set_elem *e;
+	struct set_elem *e, *n, *prev, *next;
 	bool flag_exist = flags & IPSET_FLAG_EXIST;
-	u32 i, ret = 0;
 
 	if (SET_WITH_TIMEOUT(set))
 		set_cleanup_entries(set);
 
-	/* Check already added element */
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			goto insert;
-		else if (e->id != d->id)
+	/* Find where to add the new entry */
+	n = prev = next = NULL;
+	list_for_each_entry(e, &map->members, list) {
+		if (SET_WITH_TIMEOUT(set) &&
+		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
-
-		if ((d->before > 1 && !id_eq(set, i + 1, d->refid)) ||
-		    (d->before < 0 &&
-		     (i == 0 || !id_eq(set, i - 1, d->refid))))
-			/* Before/after doesn't match */
+		else if (d->id == e->id)
+			n = e;
+		else if (d->before == 0 || e->id != d->refid)
+			continue;
+		else if (d->before > 0)
+			next = e;
+		else
+			prev = e;
+	}
+	/* Re-add already existing element */
+	if (n) {
+		if ((d->before > 0 && !next) ||
+		    (d->before < 0 && !prev))
 			return -IPSET_ERR_REF_EXIST;
 		if (!flag_exist)
-			/* Can't re-add */
 			return -IPSET_ERR_EXIST;
 		/* Update extensions */
-		ip_set_ext_destroy(set, e);
+		ip_set_ext_destroy(set, n);
+		list_set_init_extensions(set, ext, n);
 
-		if (SET_WITH_TIMEOUT(set))
-			ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
-		if (SET_WITH_COUNTER(set))
-			ip_set_init_counter(ext_counter(e, set), ext);
-		if (SET_WITH_COMMENT(set))
-			ip_set_init_comment(ext_comment(e, set), ext);
-		if (SET_WITH_SKBINFO(set))
-			ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
 		/* Set is already added to the list */
 		ip_set_put_byindex(map->net, d->id);
 		return 0;
 	}
-insert:
-	ret = -IPSET_ERR_LIST_FULL;
-	for (i = 0; i < map->size && ret == -IPSET_ERR_LIST_FULL; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			ret = d->before != 0 ? -IPSET_ERR_REF_EXIST
-				: list_set_add(set, i, d, ext);
-		else if (e->id != d->refid)
-			continue;
-		else if (d->before > 0)
-			ret = list_set_add(set, i, d, ext);
-		else if (i + 1 < map->size)
-			ret = list_set_add(set, i + 1, d, ext);
+	/* Add new entry */
+	if (d->before == 0) {
+		/* Append  */
+		n = list_empty(&map->members) ? NULL :
+		    list_last_entry(&map->members, struct set_elem, list);
+	} else if (d->before > 0) {
+		/* Insert after next element */
+		if (!list_is_last(&next->list, &map->members))
+			n = list_next_entry(next, list);
+	} else {
+		/* Insert before prev element */
+		if (prev->list.prev != &map->members)
+			n = list_prev_entry(prev, list);
 	}
-
-	return ret;
+	/* Can we replace a timed out entry? */
+	if (n != NULL &&
+	    !(SET_WITH_TIMEOUT(set) &&
+	      ip_set_timeout_expired(ext_timeout(n, set))))
+		n =  NULL;
+
+	e = kzalloc(set->dsize, GFP_KERNEL);
+	if (!e)
+		return -ENOMEM;
+	e->id = d->id;
+	INIT_LIST_HEAD(&e->list);
+	list_set_init_extensions(set, ext, e);
+	if (n)
+		list_set_replace(set, e, n);
+	else if (next)
+		list_add_tail_rcu(&e->list, &next->list);
+	else if (prev)
+		list_add_rcu(&e->list, &prev->list);
+	else
+		list_add_tail_rcu(&e->list, &map->members);
+	spin_unlock_bh(&set->lock);
+
+	synchronize_rcu_bh();
+
+	spin_lock_bh(&set->lock);
+	return 0;
 }
 
 static int
@@ -344,32 +324,30 @@  list_set_udel(struct ip_set *set, void *value, const struct ip_set_ext *ext,
 {
 	struct list_set *map = set->data;
 	struct set_adt_elem *d = value;
-	struct set_elem *e;
-	u32 i;
-
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			return d->before != 0 ? -IPSET_ERR_REF_EXIST
-					      : -IPSET_ERR_EXIST;
-		else if (SET_WITH_TIMEOUT(set) &&
-			 ip_set_timeout_expired(ext_timeout(e, set)))
+	struct set_elem *e, *next, *prev = NULL;
+
+	list_for_each_entry(e, &map->members, list) {
+		if (SET_WITH_TIMEOUT(set) &&
+		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
-		else if (e->id != d->id)
+		else if (e->id != d->id) {
+			prev = e;
 			continue;
+		}
 
-		if (d->before == 0)
-			return list_set_del(set, i);
-		else if (d->before > 0) {
-			if (!id_eq(set, i + 1, d->refid))
+		if (d->before > 0) {
+			next = list_next_entry(e, list);
+			if (list_is_last(&e->list, &map->members) ||
+			    next->id != d->refid)
 				return -IPSET_ERR_REF_EXIST;
-			return list_set_del(set, i);
-		} else if (i == 0 || !id_eq(set, i - 1, d->refid))
-			return -IPSET_ERR_REF_EXIST;
-		else
-			return list_set_del(set, i);
+		} else if (d->before < 0) {
+			if (prev == NULL || prev->id != d->refid)
+				return -IPSET_ERR_REF_EXIST;
+		}
+		list_set_del(set, e);
+		return 0;
 	}
-	return -IPSET_ERR_EXIST;
+	return d->before != 0 ? -IPSET_ERR_REF_EXIST : -IPSET_ERR_EXIST;
 }
 
 static int
@@ -410,6 +388,7 @@  list_set_uadt(struct ip_set *set, struct nlattr *tb[],
 
 	if (tb[IPSET_ATTR_CADT_FLAGS]) {
 		u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
+
 		e.before = f & IPSET_FLAG_BEFORE;
 	}
 
@@ -447,27 +426,26 @@  static void
 list_set_flush(struct ip_set *set)
 {
 	struct list_set *map = set->data;
-	struct set_elem *e;
-	u32 i;
-
-	for (i = 0; i < map->size; i++) {
-		e = list_set_elem(set, map, i);
-		if (e->id != IPSET_INVALID_ID) {
-			ip_set_put_byindex(map->net, e->id);
-			ip_set_ext_destroy(set, e);
-			e->id = IPSET_INVALID_ID;
-		}
-	}
+	struct set_elem *e, *n;
+
+	list_for_each_entry_safe(e, n, &map->members, list)
+		list_set_del(set, e);
 }
 
 static void
 list_set_destroy(struct ip_set *set)
 {
 	struct list_set *map = set->data;
+	struct set_elem *e, *n;
 
 	if (SET_WITH_TIMEOUT(set))
 		del_timer_sync(&map->gc);
-	list_set_flush(set);
+	list_for_each_entry_safe(e, n, &map->members, list) {
+		list_del(&e->list);
+		ip_set_put_byindex(map->net, e->id);
+		ip_set_ext_destroy(set, e);
+		kfree(e);
+	}
 	kfree(map);
 
 	set->data = NULL;
@@ -478,6 +456,11 @@  list_set_head(struct ip_set *set, struct sk_buff *skb)
 {
 	const struct list_set *map = set->data;
 	struct nlattr *nested;
+	struct set_elem *e;
+	u32 n = 0;
+
+	list_for_each_entry(e, &map->members, list)
+		n++;
 
 	nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
 	if (!nested)
@@ -485,7 +468,7 @@  list_set_head(struct ip_set *set, struct sk_buff *skb)
 	if (nla_put_net32(skb, IPSET_ATTR_SIZE, htonl(map->size)) ||
 	    nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref - 1)) ||
 	    nla_put_net32(skb, IPSET_ATTR_MEMSIZE,
-			  htonl(sizeof(*map) + map->size * set->dsize)))
+			  htonl(sizeof(*map) + n * set->dsize)))
 		goto nla_put_failure;
 	if (unlikely(ip_set_put_flags(skb, set)))
 		goto nla_put_failure;
@@ -502,18 +485,20 @@  list_set_list(const struct ip_set *set,
 {
 	const struct list_set *map = set->data;
 	struct nlattr *atd, *nested;
-	u32 i, first = cb->args[IPSET_CB_ARG0];
-	const struct set_elem *e;
+	u32 i = 0, first = cb->args[IPSET_CB_ARG0];
+	struct set_elem *e;
 
 	atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
 	if (!atd)
 		return -EMSGSIZE;
-	for (; cb->args[IPSET_CB_ARG0] < map->size;
-	     cb->args[IPSET_CB_ARG0]++) {
-		i = cb->args[IPSET_CB_ARG0];
-		e = list_set_elem(set, map, i);
-		if (e->id == IPSET_INVALID_ID)
-			goto finish;
+	list_for_each_entry(e, &map->members, list) {
+		if (i == first)
+			break;
+		i++;
+	}
+
+	list_for_each_entry_from(e, &map->members, list) {
+		i++;
 		if (SET_WITH_TIMEOUT(set) &&
 		    ip_set_timeout_expired(ext_timeout(e, set)))
 			continue;
@@ -532,7 +517,7 @@  list_set_list(const struct ip_set *set,
 			goto nla_put_failure;
 		ipset_nest_end(skb, nested);
 	}
-finish:
+
 	ipset_nest_end(skb, atd);
 	/* Set listing finished */
 	cb->args[IPSET_CB_ARG0] = 0;
@@ -544,6 +529,7 @@  nla_put_failure:
 		cb->args[IPSET_CB_ARG0] = 0;
 		return -EMSGSIZE;
 	}
+	cb->args[IPSET_CB_ARG0] = i - 1;
 	ipset_nest_end(skb, atd);
 	return 0;
 }
@@ -580,9 +566,9 @@  list_set_gc(unsigned long ul_set)
 	struct ip_set *set = (struct ip_set *) ul_set;
 	struct list_set *map = set->data;
 
-	write_lock_bh(&set->lock);
+	spin_lock_bh(&set->lock);
 	set_cleanup_entries(set);
-	write_unlock_bh(&set->lock);
+	spin_unlock_bh(&set->lock);
 
 	map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
 	add_timer(&map->gc);
@@ -606,24 +592,16 @@  static bool
 init_list_set(struct net *net, struct ip_set *set, u32 size)
 {
 	struct list_set *map;
-	struct set_elem *e;
-	u32 i;
 
-	map = kzalloc(sizeof(*map) +
-		      min_t(u32, size, IP_SET_LIST_MAX_SIZE) * set->dsize,
-		      GFP_KERNEL);
+	map = kzalloc(sizeof(*map), GFP_KERNEL);
 	if (!map)
 		return false;
 
 	map->size = size;
 	map->net = net;
+	INIT_LIST_HEAD(&map->members);
 	set->data = map;
 
-	for (i = 0; i < size; i++) {
-		e = list_set_elem(set, map, i);
-		e->id = IPSET_INVALID_ID;
-	}
-
 	return true;
 }