diff mbox series

[RFC,bpf-next,3/5] bpf, sockmap: Don't let child socket inherit psock or its ops on copy

Message ID 20191022113730.29303-4-jakub@cloudflare.com
State RFC
Delegated to: BPF Maintainers
Headers show
Series Extend SOCKMAP to store listening sockets | expand

Commit Message

Jakub Sitnicki Oct. 22, 2019, 11:37 a.m. UTC
New sockets cloned from listening sockets that are in a sockmap must not
inherit the psock that has the link to the sockmap. Otherwise child sockets
unintentionally share the sockmap entry with the listening socket, which
leads to double-free on socket close.

Prevent it by overloading the accept callback. In it we restore the
protocol and write buffer callbacks and clear the pointer to psock.

Signed-off-by: Jakub Sitnicki <jakub@cloudflare.com>
---
 net/ipv4/tcp_bpf.c | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
diff mbox series

Patch

diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 8a56e09cfb0e..5838aaba4ce0 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -582,6 +582,35 @@  static void tcp_bpf_close(struct sock *sk, long timeout)
 	saved_close(sk, timeout);
 }
 
+static struct sock *tcp_bpf_accept(struct sock *sk, int flags, int *err,
+				   bool kern)
+{
+	void (*saved_write_space)(struct sock *sk);
+	struct proto *saved_proto;
+	struct sk_psock *psock;
+	struct sock *child;
+
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (unlikely(!psock)) {
+		rcu_read_unlock();
+		return sk->sk_prot->accept(sk, flags, err, kern);
+	}
+	saved_proto = psock->sk_proto;
+	saved_write_space = psock->saved_write_space;
+	rcu_read_unlock();
+
+	child = saved_proto->accept(sk, flags, err, kern);
+	if (!child)
+		return NULL;
+
+	/* Child must not inherit psock or its ops. */
+	rcu_assign_sk_user_data(child, NULL);
+	child->sk_prot = saved_proto;
+	child->sk_write_space = saved_write_space;
+	return child;
+}
+
 enum {
 	TCP_BPF_IPV4,
 	TCP_BPF_IPV6,
@@ -606,6 +635,7 @@  static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 	prot[TCP_BPF_BASE].close		= tcp_bpf_close;
 	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
 	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;
+	prot[TCP_BPF_BASE].accept		= tcp_bpf_accept;
 
 	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
 	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;