diff mbox series

[3/3] ipsec: Add ESP over TCP encapsulation support

Message ID E1eZcnm-0002VW-Ms@gondolin.hengli.com.au
State Awaiting Upstream, archived
Delegated to: David Miller
Headers show
Series [1/3] skbuff: Avoid sleeping in skb_send_sock_locked | expand

Commit Message

Herbert Xu Jan. 11, 2018, 1:21 p.m. UTC
This patch adds support for ESP over TCP encapsulation per RFC8229.

Most of the input processing is done in the TCP stack and not in
this patch, which is similar to UDP encapsulation.

On the output side, there are two potential levels of indirection.
Firstly all packets are fed through a tasklet in order to avoid
TCP socket lock recursion.  They're then processed directly if
the TCP socket is not owned by user-space.  If it is owned then
we'll place the packet in a queue (tp->encap_out) for processing
when the socket lock is released.

The first outbound packet will trigger a socket lockup for a
matching TCP socket.  If the TCP connection drops we will repeat
the lookup as needed.  The TCP socket is cached in the xfrm state
and is read using RCU.

Note that unlike normal IPsec packets, once we hit a TCP xfrm
state, the xfrm stack is short-circuited and its journey will
continue through the TCP stack, after which a new IPsec lookup
will be done.  This is different from how UDP encapsulation is
done.  This means that if you're doing nested IPsec then you
will need to construct the policies with this in mind.  That is,
start with a new policy whenever TCP encapsulation is done.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
---

 include/net/xfrm.h    |    7 +
 net/ipv4/esp4.c       |  208 ++++++++++++++++++++++++++++++++++++++++++++++++--
 net/xfrm/xfrm_input.c |   21 +++--
 net/xfrm/xfrm_state.c |    3 
 4 files changed, 228 insertions(+), 11 deletions(-)
diff mbox series

Patch

diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index ae35991..3694536 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -180,6 +180,7 @@  struct xfrm_state {
 
 	/* Data for encapsulator */
 	struct xfrm_encap_tmpl	*encap;
+	struct sock __rcu	*encap_sk;
 
 	/* Data for care-of address */
 	xfrm_address_t	*coaddr;
@@ -210,6 +211,9 @@  struct xfrm_state {
 	u32			replay_maxage;
 	u32			replay_maxdiff;
 
+	/* Copy of encap_type from encap to avoid locking. */
+	u16			encap_type;
+
 	/* Replay detection notification timer */
 	struct timer_list	rtimer;
 
@@ -1570,6 +1574,9 @@  struct xfrmk_spdinfo {
 int xfrm_prepare_input(struct xfrm_state *x, struct sk_buff *skb);
 int xfrm_input(struct sk_buff *skb, int nexthdr, __be32 spi, int encap_type);
 int xfrm_input_resume(struct sk_buff *skb, int nexthdr);
+int xfrm_trans_queue_net(struct net *net, struct sk_buff *skb,
+			 int (*finish)(struct net *, struct sock *,
+				       struct sk_buff *));
 int xfrm_trans_queue(struct sk_buff *skb,
 		     int (*finish)(struct net *, struct sock *,
 				   struct sk_buff *));
diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c
index 61fe6e4..0544e4e 100644
--- a/net/ipv4/esp4.c
+++ b/net/ipv4/esp4.c
@@ -9,13 +9,16 @@ 
 #include <net/esp.h>
 #include <linux/scatterlist.h>
 #include <linux/kernel.h>
+#include <linux/netdevice.h>
 #include <linux/pfkeyv2.h>
+#include <linux/rcupdate.h>
 #include <linux/rtnetlink.h>
 #include <linux/slab.h>
 #include <linux/spinlock.h>
 #include <linux/in6.h>
 #include <net/icmp.h>
 #include <net/protocol.h>
+#include <net/tcp.h>
 #include <net/udp.h>
 
 #include <linux/highmem.h>
@@ -30,6 +33,11 @@  struct esp_output_extra {
 	u32 esphoff;
 };
 
+struct esp_tcp_sk {
+	struct sock *sk;
+	struct rcu_head rcu;
+};
+
 #define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
 
 static u32 esp4_get_mtu(struct xfrm_state *x, int mtu);
@@ -118,6 +126,143 @@  static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
 			put_page(sg_page(sg));
 }
 
+static void esp_free_tcp_sk(struct rcu_head *head)
+{
+	struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu);
+
+	sock_put(esk->sk);
+	kfree(esk);
+}
+
+static struct sock *esp_find_tcp_sk(struct xfrm_state *x)
+{
+	struct xfrm_encap_tmpl *encap = x->encap;
+	struct esp_tcp_sk *esk;
+	__be16 sport, dport;
+	struct sock *nsk;
+	struct sock *sk;
+
+	sk = rcu_dereference(x->encap_sk);
+	if (sk && sk->sk_state == TCP_ESTABLISHED)
+		return sk;
+
+	spin_lock_bh(&x->lock);
+	sport = encap->encap_sport;
+	dport = encap->encap_dport;
+	nsk = rcu_dereference_protected(x->encap_sk,
+					lockdep_is_held(&x->lock));
+	if (sk && sk == nsk) {
+		esk = kmalloc(sizeof(*esk), GFP_ATOMIC);
+		if (!esk) {
+			spin_unlock_bh(&x->lock);
+			return ERR_PTR(-ENOMEM);
+		}
+		RCU_INIT_POINTER(x->encap_sk, NULL);
+		esk->sk = sk;
+		call_rcu(&esk->rcu, esp_free_tcp_sk);
+	}
+	spin_unlock_bh(&x->lock);
+
+	/* XXX We don't support bound_dev_if. */
+	sk = inet_lookup_established(xs_net(x), &tcp_hashinfo, x->id.daddr.a4,
+				     dport, x->props.saddr.a4, sport, 0);
+
+	if (!sk)
+		return ERR_PTR(-ENOENT);
+
+	if (!tcp_sk(sk)->encap) {
+		sock_put(sk);
+		return ERR_PTR(-EINVAL);
+	}
+
+	spin_lock_bh(&x->lock);
+	nsk = rcu_dereference_protected(x->encap_sk,
+					lockdep_is_held(&x->lock));
+	if (encap->encap_sport != sport ||
+	    encap->encap_dport != dport) {
+		sock_put(sk);
+		sk = nsk ?: ERR_PTR(-EREMCHG);
+	} else if (sk == nsk)
+		sock_put(sk);
+	else
+		rcu_assign_pointer(x->encap_sk, sk);
+	spin_unlock_bh(&x->lock);
+
+	return sk;
+}
+
+static int esp_output_tcp_encap2(struct xfrm_state *x, struct sk_buff *skb)
+{
+	struct tcp_sock *tp;
+	struct sock *sk;
+	int err;
+
+	rcu_read_lock();
+
+	sk = esp_find_tcp_sk(x);
+	err = PTR_ERR(sk);
+	if (IS_ERR(sk))
+		goto out;
+
+	err = -ENOBUFS;
+	bh_lock_sock(sk);
+	if (sock_owned_by_user(sk)) {
+		tp = tcp_sk(sk);
+		if (skb_queue_len(&tp->encap_out) >= netdev_max_backlog)
+			goto unlock_sock;
+
+		__skb_queue_tail(&tp->encap_out, skb);
+		set_bit(TCP_ESP_DEFERRED, &sk->sk_tsq_flags);
+
+		err = 0;
+		goto unlock_sock;
+	}
+
+	err = tcp_encap_output(sk, skb);
+
+unlock_sock:
+	bh_unlock_sock(sk);
+
+out:
+	rcu_read_unlock();
+
+	return err;
+}
+
+static int esp_output_tcp_encap_cb(struct net *net, struct sock *sk,
+				   struct sk_buff *skb)
+{
+	struct dst_entry *dst = skb_dst(skb);
+	struct xfrm_state *x = dst->xfrm;
+	int err;
+
+	err = esp_output_tcp_encap2(x, skb);
+
+	if (err)
+		xfrm_output_resume(skb, err);
+
+	return 0;
+}
+
+static int esp_output_tcp_encap(struct xfrm_state *x, struct sk_buff *skb)
+{
+	int err;
+
+	if (x->encap_type != TCP_ENCAP_ESPINTCP)
+		return 0;
+
+	/* Batch packets in interrupt mode to prevent TCP encap nesting. */
+	preempt_disable();
+	err = xfrm_trans_queue_net(xs_net(x), skb, esp_output_tcp_encap_cb);
+	preempt_enable();
+
+	/* EINPROGRESS just happens to do the right thing.  It
+	 * actually means that the skb has been consumed and
+	 * isn't coming back.
+	 */
+	return err ?: -EINPROGRESS;
+}
+
 static void esp_output_done(struct crypto_async_request *base, int err)
 {
 	struct sk_buff *skb = base->data;
@@ -128,6 +273,13 @@  static void esp_output_done(struct crypto_async_request *base, int err)
 	tmp = ESP_SKB_CB(skb)->tmp;
 	esp_ssg_unref(x, tmp);
 	kfree(tmp);
+
+	if (!err) {
+		err = esp_output_tcp_encap(x, skb);
+		if (err == -EINPROGRESS)
+			return;
+	}
+
 	xfrm_output_resume(skb, err);
 }
 
@@ -205,7 +357,8 @@  static void esp_output_fill_trailer(u8 *tail, int tfclen, int plen, __u8 proto)
 	tail[plen - 1] = proto;
 }
 
-static void esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
+static void esp_output_encap(struct xfrm_state *x, struct sk_buff *skb,
+			     struct esp_info *esp)
 {
 	int encap_type;
 	struct udphdr *uh;
@@ -213,6 +366,9 @@  static void esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, stru
 	__be16 sport, dport;
 	struct xfrm_encap_tmpl *encap = x->encap;
 	struct ip_esp_hdr *esph = esp->esph;
+	unsigned len;
+
+	len = skb->len + esp->tailen - skb_transport_offset(skb);
 
 	spin_lock_bh(&x->lock);
 	sport = encap->encap_sport;
@@ -220,6 +376,14 @@  static void esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, stru
 	encap_type = encap->encap_type;
 	spin_unlock_bh(&x->lock);
 
+	if (encap_type == TCP_ENCAP_ESPINTCP) {
+		__be16 *lenp = (void *)esph;
+
+		*lenp = htons(len);
+		esph = (struct ip_esp_hdr *)(lenp + 1);
+		goto out;
+	}
+
 	uh = (struct udphdr *)esph;
 	uh->source = sport;
 	uh->dest = dport;
@@ -240,6 +404,8 @@  static void esp_output_udp_encap(struct xfrm_state *x, struct sk_buff *skb, stru
 	}
 
 	*skb_mac_header(skb) = IPPROTO_UDP;
+
+out:
 	esp->esph = esph;
 }
 
@@ -253,9 +419,8 @@  int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *
 	struct sk_buff *trailer;
 	int tailen = esp->tailen;
 
-	/* this is non-NULL only with UDP Encapsulation */
 	if (x->encap)
-		esp_output_udp_encap(x, skb, esp);
+		esp_output_encap(x, skb, esp);
 
 	if (!skb_cloned(skb)) {
 		if (tailen <= skb_tailroom(skb)) {
@@ -447,7 +612,7 @@  int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *
 error_free:
 	kfree(tmp);
 error:
-	return err;
+	return err ?: esp_output_tcp_encap(x, skb);
 }
 EXPORT_SYMBOL_GPL(esp_output_tail);
 
@@ -570,7 +735,19 @@  int esp_input_done2(struct sk_buff *skb, int err)
 
 	if (x->encap) {
 		struct xfrm_encap_tmpl *encap = x->encap;
+		struct tcphdr *th = (void *)(skb_network_header(skb) + ihl);
 		struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
+		__be16 source;
+
+		switch (x->encap_type) {
+		case TCP_ENCAP_ESPINTCP:
+			source = th->source;
+			break;
+
+		default:
+			source = uh->source;
+			break;
+		}
 
 		/*
 		 * 1) if the NAT-T peer's IP or port changed then
@@ -579,11 +756,11 @@  int esp_input_done2(struct sk_buff *skb, int err)
 		 *    SRC ports.
 		 */
 		if (iph->saddr != x->props.saddr.a4 ||
-		    uh->source != encap->encap_sport) {
+		    source != encap->encap_sport) {
 			xfrm_address_t ipaddr;
 
 			ipaddr.a4 = iph->saddr;
-			km_new_mapping(x, &ipaddr, uh->source);
+			km_new_mapping(x, &ipaddr, source);
 
 			/* XXX: perhaps add an extra
 			 * policy check here, to see
@@ -762,6 +939,7 @@  static u32 esp4_get_mtu(struct xfrm_state *x, int mtu)
 	struct crypto_aead *aead = x->data;
 	u32 blksize = ALIGN(crypto_aead_blocksize(aead), 4);
 	unsigned int net_adj;
+	unsigned int props;
 
 	switch (x->props.mode) {
 	case XFRM_MODE_TRANSPORT:
@@ -775,6 +953,20 @@  static u32 esp4_get_mtu(struct xfrm_state *x, int mtu)
 		BUG();
 	}
 
+	props = x->props.header_len;
+
+	if (x->encap_type == TCP_ENCAP_ESPINTCP) {
+		struct sock *sk;
+
+		rcu_read_lock();
+
+		sk = esp_find_tcp_sk(x);
+		if (!IS_ERR(sk))
+			mtu = tcp_current_mss(sk) + sizeof(struct iphdr);
+
+		rcu_read_unlock();
+	}
+
 	return ((mtu - x->props.header_len - crypto_aead_authsize(aead) -
 		 net_adj) & ~(blksize - 1)) + net_adj - 2;
 }
@@ -979,6 +1171,8 @@  static int esp_init_state(struct xfrm_state *x)
 	if (x->encap) {
 		struct xfrm_encap_tmpl *encap = x->encap;
 
+		x->encap_type = encap->encap_type;
+
 		switch (encap->encap_type) {
 		default:
 			err = -EINVAL;
@@ -989,6 +1183,8 @@  static int esp_init_state(struct xfrm_state *x)
 		case UDP_ENCAP_ESPINUDP_NON_IKE:
 			x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
 			break;
+		case TCP_ENCAP_ESPINTCP:
+			x->props.header_len += 2;
 		}
 	}
 
diff --git a/net/xfrm/xfrm_input.c b/net/xfrm/xfrm_input.c
index 444fa37..1eb0bba 100644
--- a/net/xfrm/xfrm_input.c
+++ b/net/xfrm/xfrm_input.c
@@ -27,6 +27,7 @@  struct xfrm_trans_tasklet {
 
 struct xfrm_trans_cb {
 	int (*finish)(struct net *net, struct sock *sk, struct sk_buff *skb);
+	struct net *net;
 };
 
 #define XFRM_TRANS_SKB_CB(__skb) ((struct xfrm_trans_cb *)&((__skb)->cb[0]))
@@ -493,12 +494,13 @@  static void xfrm_trans_reinject(unsigned long data)
 	skb_queue_splice_init(&trans->queue, &queue);
 
 	while ((skb = __skb_dequeue(&queue)))
-		XFRM_TRANS_SKB_CB(skb)->finish(dev_net(skb->dev), NULL, skb);
+		XFRM_TRANS_SKB_CB(skb)->finish(XFRM_TRANS_SKB_CB(skb)->net,
+					       NULL, skb);
 }
 
-int xfrm_trans_queue(struct sk_buff *skb,
-		     int (*finish)(struct net *, struct sock *,
-				   struct sk_buff *))
+int xfrm_trans_queue_net(struct net *net, struct sk_buff *skb,
+			 int (*finish)(struct net *, struct sock *,
+				       struct sk_buff *))
 {
 	struct xfrm_trans_tasklet *trans;
 
@@ -508,10 +510,19 @@  int xfrm_trans_queue(struct sk_buff *skb,
 		return -ENOBUFS;
 
 	XFRM_TRANS_SKB_CB(skb)->finish = finish;
-	skb_queue_tail(&trans->queue, skb);
+	XFRM_TRANS_SKB_CB(skb)->net = net;
+	__skb_queue_tail(&trans->queue, skb);
 	tasklet_schedule(&trans->tasklet);
 	return 0;
 }
+EXPORT_SYMBOL(xfrm_trans_queue_net);
+
+int xfrm_trans_queue(struct sk_buff *skb,
+		     int (*finish)(struct net *, struct sock *,
+				   struct sk_buff *))
+{
+	return xfrm_trans_queue_net(dev_net(skb->dev), skb, finish);
+}
 EXPORT_SYMBOL(xfrm_trans_queue);
 
 void __init xfrm_input_init(void)
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 065d896..7b01d24 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -617,6 +617,9 @@  int __xfrm_state_delete(struct xfrm_state *x)
 		net->xfrm.state_num--;
 		spin_unlock(&net->xfrm.xfrm_state_lock);
 
+		if (x->encap_sk)
+			sock_put(rcu_dereference_raw(x->encap_sk));
+
 		xfrm_dev_state_delete(x);
 
 		/* All xfrm_state objects are created by xfrm_state_alloc.