diff mbox series

[1/2,iptables] xtables: use libnftnl batch API

Message ID 20180526170713.5044-1-pablo@netfilter.org
State Accepted
Delegated to: Pablo Neira
Headers show
Series [1/2,iptables] xtables: use libnftnl batch API | expand

Commit Message

Pablo Neira Ayuso May 26, 2018, 5:07 p.m. UTC
Use existing batching API from library, the existing code relies on an
earlier implementation.

Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
---
 iptables/nft.c | 201 ++++++++++++++++++++++++++++-----------------------------
 iptables/nft.h |   3 +-
 2 files changed, 100 insertions(+), 104 deletions(-)
diff mbox series

Patch

diff --git a/iptables/nft.c b/iptables/nft.c
index 240e77bbab74..02e8ce9f4212 100644
--- a/iptables/nft.c
+++ b/iptables/nft.c
@@ -45,6 +45,7 @@ 
 #include <libnftnl/expr.h>
 #include <libnftnl/set.h>
 #include <libnftnl/udata.h>
+#include <libnftnl/batch.h>
 
 #include <netinet/in.h>	/* inet_ntoa */
 #include <arpa/inet.h>
@@ -81,13 +82,7 @@  int mnl_talk(struct nft_handle *h, struct nlmsghdr *nlh,
 	return 0;
 }
 
-static LIST_HEAD(batch_page_list);
-static int batch_num_pages;
-
-struct batch_page {
-	struct list_head	head;
-	struct mnl_nlmsg_batch	*batch;
-};
+#define NFT_NLMSG_MAXSIZE (UINT16_MAX + getpagesize())
 
 /* selected batch page is 256 Kbytes long to load ruleset of
  * half a million rules without hitting -EMSGSIZE due to large
@@ -95,44 +90,83 @@  struct batch_page {
  */
 #define BATCH_PAGE_SIZE getpagesize() * 32
 
-static struct mnl_nlmsg_batch *mnl_nftnl_batch_alloc(void)
+static struct nftnl_batch *mnl_batch_init(void)
 {
-	static char *buf;
+	struct nftnl_batch *batch;
 
-	/* libmnl needs higher buffer to handle batch overflows */
-	buf = malloc(BATCH_PAGE_SIZE + getpagesize());
-	if (buf == NULL)
+	batch = nftnl_batch_alloc(BATCH_PAGE_SIZE, NFT_NLMSG_MAXSIZE);
+	if (batch == NULL)
 		return NULL;
 
-	return mnl_nlmsg_batch_start(buf, BATCH_PAGE_SIZE);
+	return batch;
 }
 
-static struct mnl_nlmsg_batch *
-mnl_nftnl_batch_page_add(struct mnl_nlmsg_batch *batch)
+static void mnl_nft_batch_continue(struct nftnl_batch *batch)
 {
-	struct batch_page *batch_page;
+	assert(nftnl_batch_update(batch) >= 0);
+}
 
-	batch_page = malloc(sizeof(struct batch_page));
-	if (batch_page == NULL)
-		return NULL;
+static uint32_t mnl_batch_begin(struct nftnl_batch *batch, uint32_t seqnum)
+{
+	nftnl_batch_begin(nftnl_batch_buffer(batch), seqnum);
+	mnl_nft_batch_continue(batch);
+
+	return seqnum;
+}
+
+static void mnl_batch_end(struct nftnl_batch *batch, uint32_t seqnum)
+{
+	nftnl_batch_end(nftnl_batch_buffer(batch), seqnum);
+	mnl_nft_batch_continue(batch);
+}
+
+static bool mnl_batch_ready(struct nftnl_batch *batch)
+{
+	/* Check if the batch only contains the initial and trailing batch
+	 * messages. In that case, the batch is empty.
+	 */
+	return nftnl_batch_buffer_len(batch) !=
+	       (NLMSG_HDRLEN + sizeof(struct nfgenmsg)) * 2;
+}
+
+static void mnl_batch_reset(struct nftnl_batch *batch)
+{
+	nftnl_batch_free(batch);
+}
+
+struct mnl_err {
+	struct list_head	head;
+	int			err;
+	uint32_t		seqnum;
+};
+
+static void mnl_err_list_node_add(struct list_head *err_list, int error,
+				  int seqnum)
+{
+	struct mnl_err *err = malloc(sizeof(struct mnl_err));
 
-	batch_page->batch = batch;
-	list_add_tail(&batch_page->head, &batch_page_list);
-	batch_num_pages++;
+	err->seqnum = seqnum;
+	err->err = error;
+	list_add_tail(&err->head, err_list);
+}
 
-	return mnl_nftnl_batch_alloc();
+static void mnl_err_list_free(struct mnl_err *err)
+{
+	list_del(&err->head);
+	free(err);
 }
 
 static int nlbuffsiz;
 
-static void mnl_nft_set_sndbuffer(const struct mnl_socket *nl)
+static void mnl_set_sndbuffer(const struct mnl_socket *nl,
+			      struct nftnl_batch *batch)
 {
 	int newbuffsiz;
 
-	if (batch_num_pages * BATCH_PAGE_SIZE <= nlbuffsiz)
+	if (nftnl_batch_iovec_len(batch) * BATCH_PAGE_SIZE <= nlbuffsiz)
 		return;
 
-	newbuffsiz = batch_num_pages * BATCH_PAGE_SIZE;
+	newbuffsiz = nftnl_batch_iovec_len(batch) * BATCH_PAGE_SIZE;
 
 	/* Rise sender buffer length to avoid hitting -EMSGSIZE */
 	if (setsockopt(mnl_socket_get_fd(nl), SOL_SOCKET, SO_SNDBUFFORCE,
@@ -142,58 +176,33 @@  static void mnl_nft_set_sndbuffer(const struct mnl_socket *nl)
 	nlbuffsiz = newbuffsiz;
 }
 
-static void mnl_nftnl_batch_reset(void)
-{
-	struct batch_page *batch_page, *next;
-
-	list_for_each_entry_safe(batch_page, next, &batch_page_list, head) {
-		list_del(&batch_page->head);
-		free(mnl_nlmsg_batch_head(batch_page->batch));
-		mnl_nlmsg_batch_stop(batch_page->batch);
-		free(batch_page);
-		batch_num_pages--;
-	}
-}
-
-static ssize_t mnl_nft_socket_sendmsg(const struct mnl_socket *nl)
+static ssize_t mnl_nft_socket_sendmsg(const struct mnl_socket *nf_sock,
+				      struct nftnl_batch *batch)
 {
 	static const struct sockaddr_nl snl = {
 		.nl_family = AF_NETLINK
 	};
-	struct iovec iov[batch_num_pages];
+	uint32_t iov_len = nftnl_batch_iovec_len(batch);
+	struct iovec iov[iov_len];
 	struct msghdr msg = {
 		.msg_name	= (struct sockaddr *) &snl,
 		.msg_namelen	= sizeof(snl),
 		.msg_iov	= iov,
-		.msg_iovlen	= batch_num_pages,
+		.msg_iovlen	= iov_len,
 	};
-	struct batch_page *batch_page;
-	int i = 0, ret;
-
-	mnl_nft_set_sndbuffer(nl);
-
-	list_for_each_entry(batch_page, &batch_page_list, head) {
-		iov[i].iov_base = mnl_nlmsg_batch_head(batch_page->batch);
-		iov[i].iov_len = mnl_nlmsg_batch_size(batch_page->batch);
-		i++;
-#ifdef NL_DEBUG
-		mnl_nlmsg_fprintf(stdout,
-				  mnl_nlmsg_batch_head(batch_page->batch),
-				  mnl_nlmsg_batch_size(batch_page->batch),
-				  sizeof(struct nfgenmsg));
-#endif
-	}
 
-	ret = sendmsg(mnl_socket_get_fd(nl), &msg, 0);
-	mnl_nftnl_batch_reset();
+	mnl_set_sndbuffer(nf_sock, batch);
+	nftnl_batch_iovec(batch, iov, iov_len);
 
-	return ret;
+	return sendmsg(mnl_socket_get_fd(nf_sock), &msg, 0);
 }
 
-static int mnl_nftnl_batch_talk(struct nft_handle *h)
+static int mnl_batch_talk(const struct mnl_socket *nf_sock,
+			  struct nftnl_batch *batch, struct list_head *err_list)
 {
-	int ret, fd = mnl_socket_get_fd(h->nl);
-	char rcv_buf[16536];
+	const struct mnl_socket *nl = nf_sock;
+	int ret, fd = mnl_socket_get_fd(nl), portid = mnl_socket_get_portid(nl);
+	char rcv_buf[MNL_SOCKET_BUFFER_SIZE];
 	fd_set readfds;
 	struct timeval tv = {
 		.tv_sec		= 0,
@@ -201,7 +210,7 @@  static int mnl_nftnl_batch_talk(struct nft_handle *h)
 	};
 	int err = 0;
 
-	ret = mnl_nft_socket_sendmsg(h->nl);
+	ret = mnl_nft_socket_sendmsg(nf_sock, batch);
 	if (ret == -1)
 		return -1;
 
@@ -214,16 +223,18 @@  static int mnl_nftnl_batch_talk(struct nft_handle *h)
 		return -1;
 
 	while (ret > 0 && FD_ISSET(fd, &readfds)) {
-		ret = mnl_socket_recvfrom(h->nl, rcv_buf, sizeof(rcv_buf));
+		struct nlmsghdr *nlh = (struct nlmsghdr *)rcv_buf;
+
+		ret = mnl_socket_recvfrom(nl, rcv_buf, sizeof(rcv_buf));
 		if (ret == -1)
 			return -1;
 
-		ret = mnl_cb_run(rcv_buf, ret, 0, h->portid, NULL, NULL);
-		/* Annotate first error and continue, make sure we get all
-		 * acknoledgments.
-		 */
-		if (!err && ret == -1)
-			err = errno;
+		ret = mnl_cb_run(rcv_buf, ret, 0, portid, NULL, NULL);
+		/* Continue on error, make sure we get all acknowledgments */
+		if (ret == -1) {
+			mnl_err_list_node_add(err_list, errno, nlh->nlmsg_seq);
+			err = -1;
+		}
 
 		ret = select(fd+1, &readfds, NULL, NULL, &tv);
 		if (ret == -1)
@@ -232,22 +243,7 @@  static int mnl_nftnl_batch_talk(struct nft_handle *h)
 		FD_ZERO(&readfds);
 		FD_SET(fd, &readfds);
 	}
-	errno = err;
-	return err ? -1 : 0;
-}
-
-static void mnl_nftnl_batch_begin(struct mnl_nlmsg_batch *batch, uint32_t seq)
-{
-	nftnl_batch_begin(mnl_nlmsg_batch_current(batch), seq);
-	if (!mnl_nlmsg_batch_next(batch))
-		mnl_nftnl_batch_page_add(batch);
-}
-
-static void mnl_nftnl_batch_end(struct mnl_nlmsg_batch *batch, uint32_t seq)
-{
-	nftnl_batch_end(mnl_nlmsg_batch_current(batch), seq);
-	if (!mnl_nlmsg_batch_next(batch))
-		mnl_nftnl_batch_page_add(batch);
+	return err;
 }
 
 enum obj_update_type {
@@ -693,8 +689,7 @@  int nft_init(struct nft_handle *h, struct builtin_table *t)
 	h->tables = t;
 
 	INIT_LIST_HEAD(&h->obj_list);
-
-	h->batch = mnl_nftnl_batch_alloc();
+	INIT_LIST_HEAD(&h->err_list);
 
 	return 0;
 }
@@ -712,8 +707,6 @@  void nft_fini(struct nft_handle *h)
 {
 	flush_rule_cache(h);
 	mnl_socket_close(h->nl);
-	free(mnl_nlmsg_batch_head(h->batch));
-	mnl_nlmsg_batch_stop(h->batch);
 }
 
 static void nft_chain_print_debug(struct nftnl_chain *c, struct nlmsghdr *nlh)
@@ -2270,7 +2263,7 @@  static void nft_compat_table_batch_add(struct nft_handle *h, uint16_t type,
 {
 	struct nlmsghdr *nlh;
 
-	nlh = nftnl_table_nlmsg_build_hdr(mnl_nlmsg_batch_current(h->batch),
+	nlh = nftnl_table_nlmsg_build_hdr(nftnl_batch_buffer(h->batch),
 					type, h->family, flags, seq);
 	nftnl_table_nlmsg_build_payload(nlh, table);
 	nftnl_table_free(table);
@@ -2282,7 +2275,7 @@  static void nft_compat_chain_batch_add(struct nft_handle *h, uint16_t type,
 {
 	struct nlmsghdr *nlh;
 
-	nlh = nftnl_chain_nlmsg_build_hdr(mnl_nlmsg_batch_current(h->batch),
+	nlh = nftnl_chain_nlmsg_build_hdr(nftnl_batch_buffer(h->batch),
 					type, h->family, flags, seq);
 	nftnl_chain_nlmsg_build_payload(nlh, chain);
 	nft_chain_print_debug(chain, nlh);
@@ -2295,7 +2288,7 @@  static void nft_compat_rule_batch_add(struct nft_handle *h, uint16_t type,
 {
 	struct nlmsghdr *nlh;
 
-	nlh = nftnl_rule_nlmsg_build_hdr(mnl_nlmsg_batch_current(h->batch),
+	nlh = nftnl_rule_nlmsg_build_hdr(nftnl_batch_buffer(h->batch),
 				       type, h->family, flags, seq);
 	nftnl_rule_nlmsg_build_payload(nlh, rule);
 	nft_rule_print_debug(rule, nlh);
@@ -2305,10 +2298,13 @@  static void nft_compat_rule_batch_add(struct nft_handle *h, uint16_t type,
 static int nft_action(struct nft_handle *h, int action)
 {
 	struct obj_update *n, *tmp;
+	struct mnl_err *err, *ne;
 	uint32_t seq = 1;
 	int ret = 0;
 
-	mnl_nftnl_batch_begin(h->batch, seq++);
+	h->batch = mnl_batch_init();
+
+	mnl_batch_begin(h->batch, seq++);
 
 	list_for_each_entry_safe(n, tmp, &h->obj_list, head) {
 		switch (n->type) {
@@ -2378,24 +2374,23 @@  static int nft_action(struct nft_handle *h, int action)
 		list_del(&n->head);
 		free(n);
 
-		if (!mnl_nlmsg_batch_next(h->batch))
-			h->batch = mnl_nftnl_batch_page_add(h->batch);
+		mnl_nft_batch_continue(h->batch);
 	}
 
 	switch (action) {
 	case NFT_COMPAT_COMMIT:
-		mnl_nftnl_batch_end(h->batch, seq++);
+		mnl_batch_end(h->batch, seq++);
 		break;
 	case NFT_COMPAT_ABORT:
 		break;
 	}
 
-	if (!mnl_nlmsg_batch_is_empty(h->batch))
-		h->batch = mnl_nftnl_batch_page_add(h->batch);
+	ret = mnl_batch_talk(h->nl, h->batch, &h->err_list);
 
-	ret = mnl_nftnl_batch_talk(h);
+	list_for_each_entry_safe(err, ne, &h->err_list, head)
+		mnl_err_list_free(err);
 
-	mnl_nlmsg_batch_reset(h->batch);
+	mnl_batch_reset(h->batch);
 
 	return ret == 0 ? 1 : 0;
 }
diff --git a/iptables/nft.h b/iptables/nft.h
index 0c4beb998de8..91a90355b5a5 100644
--- a/iptables/nft.h
+++ b/iptables/nft.h
@@ -32,7 +32,8 @@  struct nft_handle {
 	uint32_t		seq;
 	struct list_head	obj_list;
 	int			obj_list_num;
-	struct mnl_nlmsg_batch	*batch;
+	struct nftnl_batch	*batch;
+	struct list_head	err_list;
 	struct nft_family_ops	*ops;
 	struct builtin_table	*tables;
 	struct nftnl_rule_list	*rule_cache;