diff mbox series

[bpf-next,07/16] bpf: sockmap, add msg_cork_bytes() helper

Message ID 20180305195132.6612.89749.stgit@john-Precision-Tower-5810
State Changes Requested, archived
Delegated to: BPF Maintainers
Headers show
Series bpf,sockmap: sendmsg/sendfile ULP | expand

Commit Message

John Fastabend March 5, 2018, 7:51 p.m. UTC
In the case where we need a specific number of bytes before a
verdict can be assigned, even if the data spans multiple sendmsg
or sendfile calls. The BPF program may use msg_apply_bytes().

The extreme case is a user can call sendmsg repeatedly with
1-byte msg segments. Obviously, this is bad for performance but
is still valid. If the BPF program needs N bytes to validate
a header it can use msg_cork_bytes to specify N bytes and the
BPF program will not be called again until N bytes have been
accumulated.

Signed-off-by: John Fastabend <john.fastabend@gmail.com>
---
 include/linux/filter.h   |    2 
 include/uapi/linux/bpf.h |    3 
 kernel/bpf/sockmap.c     |  334 ++++++++++++++++++++++++++++++++++++++++------
 net/core/filter.c        |   16 ++
 4 files changed, 310 insertions(+), 45 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/filter.h b/include/linux/filter.h
index 805a566..6058a1b 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -511,6 +511,8 @@  struct sk_msg_buff {
 	void *data;
 	void *data_end;
 	int apply_bytes;
+	int cork_bytes;
+	int sg_copybreak;
 	int sg_start;
 	int sg_curr;
 	int sg_end;
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index e50c61f..cfcc002 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -770,7 +770,8 @@  enum bpf_attach_type {
 	FN(override_return),		\
 	FN(sock_ops_cb_flags_set),	\
 	FN(msg_redirect_map),		\
-	FN(msg_apply_bytes),
+	FN(msg_apply_bytes),		\
+	FN(msg_cork_bytes),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 98c6a3b..f637a83 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -78,8 +78,10 @@  struct smap_psock {
 	/* datapath variables for tx_msg ULP */
 	struct sock *sk_redir;
 	int apply_bytes;
+	int cork_bytes;
 	int sg_size;
 	int eval;
+	struct sk_msg_buff *cork;
 
 	struct strparser strp;
 	struct bpf_prog *bpf_tx_msg;
@@ -140,22 +142,30 @@  static int bpf_tcp_init(struct sock *sk)
 	return 0;
 }
 
+static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+
 static void bpf_tcp_release(struct sock *sk)
 {
 	struct smap_psock *psock;
 
 	rcu_read_lock();
 	psock = smap_psock_sk(sk);
+	if (unlikely(!psock))
+		goto out;
 
-	if (likely(psock)) {
-		sk->sk_prot = psock->sk_proto;
-		psock->sk_proto = NULL;
+	if (psock->cork) {
+		free_start_sg(psock->sock, psock->cork);
+		kfree(psock->cork);
+		psock->cork = NULL;
 	}
+
+	sk->sk_prot = psock->sk_proto;
+	psock->sk_proto = NULL;
+out:
 	rcu_read_unlock();
 }
 
-static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-
 static void bpf_tcp_close(struct sock *sk, long timeout)
 {
 	void (*close_fun)(struct sock *sk, long timeout);
@@ -211,14 +221,25 @@  static int memcopy_from_iter(struct sock *sk,
 			     struct iov_iter *from, int bytes)
 {
 	struct scatterlist *sg = md->sg_data;
-	int i = md->sg_curr, rc = 0;
+	int i = md->sg_curr, rc = -ENOSPC;
 
 	do {
 		int copy;
 		char *to;
 
-		copy = sg[i].length;
-		to = sg_virt(&sg[i]);
+		if (md->sg_copybreak >= sg[i].length) {
+			md->sg_copybreak = 0;
+
+			if (++i == MAX_SKB_FRAGS)
+				i = 0;
+
+			if (i == md->sg_end)
+				break;
+		}
+
+		copy = sg[i].length - md->sg_copybreak;
+		to = sg_virt(&sg[i]) + md->sg_copybreak;
+		md->sg_copybreak += copy;
 
 		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
 			rc = copy_from_iter_nocache(to, copy, from);
@@ -234,6 +255,7 @@  static int memcopy_from_iter(struct sock *sk,
 		if (!bytes)
 			break;
 
+		md->sg_copybreak = 0;
 		if (++i == MAX_SKB_FRAGS)
 			i = 0;
 	} while (i != md->sg_end);
@@ -328,6 +350,33 @@  static void return_mem_sg(struct sock *sk, int bytes,  struct sk_msg_buff *md)
 	} while (i != md->sg_end);
 }
 
+static void free_bytes_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+{
+	struct scatterlist *sg = md->sg_data;
+	int i = md->sg_start, free;
+
+	while (bytes && sg[i].length) {
+		free = sg[i].length;
+		if (bytes < free) {
+			sg[i].length -= bytes;
+			sg[i].offset += bytes;
+			sk_mem_uncharge(sk, bytes);
+			break;
+		}
+
+		sk_mem_uncharge(sk, sg[i].length);
+		put_page(sg_page(&sg[i]));
+		bytes -= sg[i].length;
+		sg[i].length = 0;
+		sg[i].page_link = 0;
+		sg[i].offset = 0;
+		i++;
+
+		if (i == MAX_SKB_FRAGS)
+			i = 0;
+	}
+}
+
 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
 {
 	struct scatterlist *sg = md->sg_data;
@@ -510,6 +559,9 @@  static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 
 	while (msg_data_left(msg)) {
+		bool cork = false, enospc = false;
+		struct sk_msg_buff *m;
+
 		if (sk->sk_err) {
 			err = sk->sk_err;
 			goto out_err;
@@ -519,32 +571,76 @@  static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
 
-		md.sg_curr = md.sg_end;
-		err = sk_alloc_sg(sk, copy, sg,
-				  md.sg_start, &md.sg_end, &sg_copy,
-				  md.sg_end);
+		m = psock->cork_bytes ? psock->cork : &md;
+		m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
+		err = sk_alloc_sg(sk, copy, m->sg_data,
+				  m->sg_start, &m->sg_end, &sg_copy,
+				  m->sg_end - 1);
 		if (err) {
 			if (err != -ENOSPC)
 				goto wait_for_memory;
+			enospc = true;
 			copy = sg_copy;
 		}
 
-		err = memcopy_from_iter(sk, &md, &msg->msg_iter, copy);
+		err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
 		if (err < 0) {
-			free_curr_sg(sk, &md);
+			free_curr_sg(sk, m);
 			goto out_err;
 		}
 
 		psock->sg_size += copy;
 		copied += copy;
 		sg_copy = 0;
+
+		/* When bytes are being corked skip running BPF program and
+		 * applying verdict unless there is no more buffer space. In
+		 * the ENOSPC case simply run BPF prorgram with currently
+		 * accumulated data. We don't have much choice at this point
+		 * we could try extending the page frags or chaining complex
+		 * frags but even in these cases _eventually_ we will hit an
+		 * OOM scenario. More complex recovery schemes may be
+		 * implemented in the future, but BPF programs must handle
+		 * the case where apply_cork requests are not honored. The
+		 * canonical method to verify this is to check data length.
+		 */
+		if (psock->cork_bytes) {
+			if (copy > psock->cork_bytes)
+				psock->cork_bytes = 0;
+			else
+				psock->cork_bytes -= copy;
+
+			if (psock->cork_bytes && !enospc)
+				goto out_cork;
+
+			/* All cork bytes accounted for re-run filter */
+			psock->eval = __SK_NONE;
+			psock->cork_bytes = 0;
+		}
 more_data:
 		/* If msg is larger than MAX_SKB_FRAGS we can send multiple
 		 * scatterlists per msg. However BPF decisions apply to the
 		 * entire msg.
 		 */
 		if (psock->eval == __SK_NONE)
-			psock->eval = smap_do_tx_msg(sk, psock, &md);
+			psock->eval = smap_do_tx_msg(sk, psock, m);
+
+		if (m->cork_bytes &&
+		    m->cork_bytes > psock->sg_size && !enospc) {
+			psock->cork_bytes = m->cork_bytes - psock->sg_size;
+			if (!psock->cork) {
+				psock->cork = kcalloc(1,
+						sizeof(struct sk_msg_buff),
+						GFP_ATOMIC | __GFP_NOWARN);
+
+				if (!psock->cork) {
+					err = -ENOMEM;
+					goto out_err;
+				}
+			}
+			memcpy(psock->cork, m, sizeof(*m));
+			goto out_cork;
+		}
 
 		send = psock->sg_size;
 		if (psock->apply_bytes && psock->apply_bytes < send)
@@ -552,9 +648,9 @@  static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 
 		switch (psock->eval) {
 		case __SK_PASS:
-			err = bpf_tcp_push(sk, send, &md, flags, true);
+			err = bpf_tcp_push(sk, send, m, flags, true);
 			if (unlikely(err)) {
-				copied -= free_start_sg(sk, &md);
+				copied -= free_start_sg(sk, m);
 				goto out_err;
 			}
 
@@ -576,13 +672,23 @@  static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 					psock->apply_bytes -= send;
 			}
 
-			return_mem_sg(sk, send, &md);
+			if (psock->cork) {
+				cork = true;
+				psock->cork = NULL;
+			}
+
+			return_mem_sg(sk, send, m);
 			release_sock(sk);
 
 			err = bpf_tcp_sendmsg_do_redirect(redir, send,
-							  &md, flags);
+							  m, flags);
 			lock_sock(sk);
 
+			if (cork) {
+				free_start_sg(sk, m);
+				kfree(m);
+				m = NULL;
+			}
 			if (unlikely(err)) {
 				copied -= err;
 				goto out_redir;
@@ -592,21 +698,23 @@  static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			break;
 		case __SK_DROP:
 		default:
-			copied -= free_start_sg(sk, &md);
-
+			free_bytes_sg(sk, send, m);
 			if (psock->apply_bytes) {
 				if (psock->apply_bytes < send)
 					psock->apply_bytes = 0;
 				else
 					psock->apply_bytes -= send;
 			}
-			psock->sg_size -= copied;
+			copied -= send;
+			psock->sg_size -= send;
 			err = -EACCES;
 			break;
 		}
 
 		bpf_md_init(psock);
-		if (sg[md.sg_start].page_link && sg[md.sg_start].length)
+		if (m &&
+		    m->sg_data[m->sg_start].page_link &&
+		    m->sg_data[m->sg_start].length)
 			goto more_data;
 		continue;
 wait_for_sndbuf:
@@ -623,6 +731,47 @@  static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 	release_sock(sk);
 	smap_release_sock(psock, sk);
 	return copied ? copied : err;
+out_cork:
+	release_sock(sk);
+	smap_release_sock(psock, sk);
+	return copied;
+}
+
+static int bpf_tcp_sendpage_sg_locked(struct sock *sk,
+				      struct sk_msg_buff *m,
+				      int send,
+				      int flags)
+{
+	int copied = 0;
+
+	do {
+		struct scatterlist *sg = &m->sg_data[m->sg_start];
+		struct page *p = sg_page(sg);
+		int off = sg->offset;
+		int len = sg->length;
+		int err;
+
+		if (len > send)
+			len = send;
+
+		err = tcp_sendpage_locked(sk, p, off, len, flags);
+		if (err < 0)
+			break;
+
+		sg->length -= len;
+		sg->offset += len;
+		copied += len;
+		send -= len;
+		if (!sg->length) {
+			sg->page_link = 0;
+			put_page(p);
+			m->sg_start++;
+			if (m->sg_start == MAX_SKB_FRAGS)
+				m->sg_start = 0;
+		}
+	} while (send && m->sg_start != m->sg_end);
+
+	return copied;
 }
 
 static int bpf_tcp_sendpage_do_redirect(struct sock *sk,
@@ -644,7 +793,10 @@  static int bpf_tcp_sendpage_do_redirect(struct sock *sk,
 	rcu_read_unlock();
 
 	lock_sock(sk);
-	rc = tcp_sendpage_locked(sk, page, offset, size, flags);
+	if (md)
+		rc = bpf_tcp_sendpage_sg_locked(sk, md, size, flags);
+	else
+		rc = tcp_sendpage_locked(sk, page, offset, size, flags);
 	release_sock(sk);
 
 	smap_release_sock(psock, sk);
@@ -657,10 +809,10 @@  static int bpf_tcp_sendpage_do_redirect(struct sock *sk,
 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 			    int offset, size_t size, int flags)
 {
-	struct sk_msg_buff md = {0};
+	struct sk_msg_buff md = {0}, *m = NULL;
+	bool cork = false, enospc = false;
 	struct smap_psock *psock;
-	int send, total = 0, rc = __SK_NONE;
-	int orig_size = size;
+	int send, total = 0, rc;
 	struct bpf_prog *prog;
 	struct sock *redir;
 
@@ -686,19 +838,90 @@  static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 	preempt_enable();
 
 	lock_sock(sk);
+
+	psock->sg_size += size;
+do_cork:
+	if (psock->cork_bytes) {
+		struct scatterlist *sg;
+
+		m = psock->cork;
+		sg = &m->sg_data[m->sg_end];
+		sg_set_page(sg, page, send, offset);
+		get_page(page);
+		sk_mem_charge(sk, send);
+		m->sg_end++;
+		cork = true;
+
+		if (send > psock->cork_bytes)
+			psock->cork_bytes = 0;
+		else
+			psock->cork_bytes -= send;
+
+		if (m->sg_end == MAX_SKB_FRAGS)
+			m->sg_end = 0;
+
+		if (m->sg_end == m->sg_start) {
+			enospc = true;
+			psock->cork_bytes = 0;
+		}
+
+		if (!psock->cork_bytes)
+			psock->eval = __SK_NONE;
+
+		if (!enospc && psock->cork_bytes) {
+			total = send;
+			goto out_err;
+		}
+	}
 more_sendpage_data:
 	if (psock->eval == __SK_NONE)
 		psock->eval = smap_do_tx_msg(sk, psock, &md);
 
+	if (md.cork_bytes && !enospc && md.cork_bytes > psock->sg_size) {
+		psock->cork_bytes = md.cork_bytes;
+		if (!psock->cork) {
+			psock->cork = kzalloc(sizeof(struct sk_msg_buff),
+					GFP_ATOMIC | __GFP_NOWARN);
+
+			if (!psock->cork) {
+				psock->sg_size -= size;
+				total = -ENOMEM;
+				goto out_err;
+			}
+		}
+
+		if (!cork) {
+			send = psock->sg_size;
+			goto do_cork;
+		}
+	}
+
+	send = psock->sg_size;
 	if (psock->apply_bytes && psock->apply_bytes < send)
 		send = psock->apply_bytes;
 
-	switch (rc) {
+	switch (psock->eval) {
 	case __SK_PASS:
-		rc = tcp_sendpage_locked(sk, page, offset, send, flags);
-		if (rc < 0) {
-			total = total ? : rc;
-			goto out_err;
+		/* When data is corked once cork bytes limit is reached
+		 * we may send more data then the current sendfile call
+		 * is expecting. To handle this we have to fixup return
+		 * codes. However, if there is an error there is nothing
+		 * to do but continue. We can not go back in time and
+		 * give errors to data we have already consumed.
+		 */
+		if (m) {
+			rc = bpf_tcp_sendpage_sg_locked(sk, m, send, flags);
+			if (rc < 0) {
+				total = total ? : rc;
+				goto out_err;
+			}
+			sk_mem_uncharge(sk, rc);
+		} else {
+			rc = tcp_sendpage_locked(sk, page, offset, send, flags);
+			if (rc < 0) {
+				total = total ? : rc;
+				goto out_err;
+			}
 		}
 
 		if (psock->apply_bytes) {
@@ -711,7 +934,7 @@  static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 		total += rc;
 		psock->sg_size -= rc;
 		offset += rc;
-		size -= rc;
+		send -= rc;
 		break;
 	case __SK_REDIRECT:
 		redir = psock->sk_redir;
@@ -728,12 +951,30 @@  static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 		/* sock lock dropped must not dereference psock below */
 		rc = bpf_tcp_sendpage_do_redirect(redir,
 						  page, offset, send,
-						  flags, &md);
+						  flags, m);
 		lock_sock(sk);
-		if (rc > 0) {
-			offset += rc;
-			psock->sg_size -= rc;
-			send -= rc;
+		if (m) {
+			int free = free_start_sg(sk, m);
+
+			if (rc > 0) {
+				sk_mem_uncharge(sk, rc);
+				free = rc + free;
+			}
+			psock->sg_size -= free;
+			psock->cork_bytes = 0;
+			send = 0;
+			if (psock->apply_bytes) {
+				if (psock->apply_bytes > free)
+					psock->apply_bytes -= free;
+				else
+					psock->apply_bytes = 0;
+			}
+		} else {
+			if (rc > 0) {
+				offset += rc;
+				psock->sg_size -= rc;
+				send -= rc;
+			}
 		}
 
 		if ((total && rc > 0) || (!total && rc < 0))
@@ -741,7 +982,8 @@  static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 		break;
 	case __SK_DROP:
 	default:
-		return_mem_sg(sk, send, &md);
+		if (m)
+			free_bytes_sg(sk, send, m);
 		if (psock->apply_bytes) {
 			if (psock->apply_bytes > send)
 				psock->apply_bytes -= send;
@@ -749,18 +991,17 @@  static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 				psock->apply_bytes -= 0;
 		}
 		psock->sg_size -= send;
-		size -= send;
-		total += send;
-		rc = -EACCES;
+		total = total ? : -EACCES;
+		goto out_err;
 	}
 
 	bpf_md_init(psock);
-	if (size)
+	if (psock->sg_size)
 		goto more_sendpage_data;
 out_err:
 	release_sock(sk);
 	smap_release_sock(psock, sk);
-	return total <= orig_size ? total : orig_size;
+	return total <= size ? total : size;
 }
 
 static void bpf_tcp_msg_add(struct smap_psock *psock,
@@ -1077,6 +1318,11 @@  static void smap_gc_work(struct work_struct *w)
 	if (psock->bpf_tx_msg)
 		bpf_prog_put(psock->bpf_tx_msg);
 
+	if (psock->cork) {
+		free_start_sg(psock->sock, psock->cork);
+		kfree(psock->cork);
+	}
+
 	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
 		list_del(&e->list);
 		kfree(e);
diff --git a/net/core/filter.c b/net/core/filter.c
index df2a8f4..2c73af0 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1942,6 +1942,20 @@  struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
 	.arg2_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u64, bytes)
+{
+	msg->cork_bytes = bytes;
+	return 0;
+}
+
+static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
+	.func           = bpf_msg_cork_bytes,
+	.gpl_only       = false,
+	.ret_type       = RET_INTEGER,
+	.arg1_type	= ARG_PTR_TO_CTX,
+	.arg2_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 {
 	return task_get_classid(skb);
@@ -3650,6 +3664,8 @@  static const struct bpf_func_proto *sk_msg_func_proto(enum bpf_func_id func_id)
 		return &bpf_msg_redirect_map_proto;
 	case BPF_FUNC_msg_apply_bytes:
 		return &bpf_msg_apply_bytes_proto;
+	case BPF_FUNC_msg_cork_bytes:
+		return &bpf_msg_cork_bytes_proto;
 	default:
 		return bpf_base_func_proto(func_id);
 	}