diff mbox series

[RFC,12/15] SUNRPC: Add FSM machinery to handle RPC_AUTH_TLS on reconnect

Message ID 165030074924.5246.5399913437403260546.stgit@oracle-102.nfsv4.dev
State New
Headers show
Series Prototype implementation of RPC-with-TLS | expand

Commit Message

Chuck Lever April 18, 2022, 4:52 p.m. UTC
Try STARTTLS with the RPC server peer as soon as a transport
connection is established.

Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
---
 include/linux/sunrpc/clnt.h  |    1 -
 include/linux/sunrpc/sched.h |    1 +
 net/sunrpc/clnt.c            |   59 +++++++++++++++++++++++++++++++++++++++---
 3 files changed, 56 insertions(+), 5 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/sunrpc/clnt.h b/include/linux/sunrpc/clnt.h
index 15fd84e4c321..e10a19d136ca 100644
--- a/include/linux/sunrpc/clnt.h
+++ b/include/linux/sunrpc/clnt.h
@@ -209,7 +209,6 @@  int		rpc_call_sync(struct rpc_clnt *clnt,
 			      unsigned int flags);
 struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred,
 			       int flags);
-void		rpc_starttls_async(struct rpc_task *task);
 int		rpc_restart_call_prepare(struct rpc_task *);
 int		rpc_restart_call(struct rpc_task *);
 void		rpc_setbufsize(struct rpc_clnt *, unsigned int, unsigned int);
diff --git a/include/linux/sunrpc/sched.h b/include/linux/sunrpc/sched.h
index f8c09638fa69..0d1ae89a2339 100644
--- a/include/linux/sunrpc/sched.h
+++ b/include/linux/sunrpc/sched.h
@@ -139,6 +139,7 @@  struct rpc_task_setup {
 #define RPC_IS_ASYNC(t)		((t)->tk_flags & RPC_TASK_ASYNC)
 #define RPC_IS_SWAPPER(t)	((t)->tk_flags & RPC_TASK_SWAPPER)
 #define RPC_IS_CORK(t)		((t)->tk_flags & RPC_TASK_CORK)
+#define RPC_IS_TLSPROBE(t)	((t)->tk_flags & RPC_TASK_TLSCRED)
 #define RPC_IS_SOFT(t)		((t)->tk_flags & (RPC_TASK_SOFT|RPC_TASK_TIMEOUT))
 #define RPC_IS_SOFTCONN(t)	((t)->tk_flags & RPC_TASK_SOFTCONN)
 #define RPC_WAS_SENT(t)		((t)->tk_flags & RPC_TASK_SENT)
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
index e9a6622dba68..0506971410f7 100644
--- a/net/sunrpc/clnt.c
+++ b/net/sunrpc/clnt.c
@@ -70,6 +70,8 @@  static void	call_refresh(struct rpc_task *task);
 static void	call_refreshresult(struct rpc_task *task);
 static void	call_connect(struct rpc_task *task);
 static void	call_connect_status(struct rpc_task *task);
+static void	call_start_tls(struct rpc_task *task);
+static void	call_tls_status(struct rpc_task *task);
 
 static int	rpc_encode_header(struct rpc_task *task,
 				  struct xdr_stream *xdr);
@@ -77,6 +79,7 @@  static int	rpc_decode_header(struct rpc_task *task,
 				  struct xdr_stream *xdr);
 static int	rpc_ping(struct rpc_clnt *clnt);
 static int	rpc_starttls_sync(struct rpc_clnt *clnt);
+static void	rpc_starttls_async(struct rpc_task *task);
 static void	rpc_check_timeout(struct rpc_task *task);
 
 static void rpc_register_client(struct rpc_clnt *clnt)
@@ -2163,7 +2166,7 @@  call_connect_status(struct rpc_task *task)
 	rpc_call_rpcerror(task, status);
 	return;
 out_next:
-	task->tk_action = call_transmit;
+	task->tk_action = call_start_tls;
 	return;
 out_retry:
 	/* Check for timeouts before looping back to call_bind */
@@ -2171,6 +2174,53 @@  call_connect_status(struct rpc_task *task)
 	rpc_check_timeout(task);
 }
 
+static void
+call_start_tls(struct rpc_task *task)
+{
+	struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+	struct rpc_clnt *clnt = task->tk_client;
+
+	task->tk_action = call_transmit;
+	if (RPC_IS_TLSPROBE(task))
+		return;
+
+	switch (clnt->cl_xprtsec_policy) {
+	case RPC_XPRTSEC_TLS:
+	case RPC_XPRTSEC_MTLS:
+		if (xprt->ops->tls_handshake_async) {
+			task->tk_action = call_tls_status;
+			rpc_starttls_async(task);
+		}
+		break;
+	default:
+		break;
+	}
+}
+
+static void
+call_tls_status(struct rpc_task *task)
+{
+	struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt;
+	struct rpc_clnt *clnt = task->tk_client;
+
+	task->tk_action = call_transmit;
+	if (!task->tk_status)
+		return;
+
+	xprt_force_disconnect(xprt);
+
+	switch (clnt->cl_xprtsec_policy) {
+	case RPC_XPRTSEC_TLS:
+	case RPC_XPRTSEC_MTLS:
+		rpc_delay(task, 5*HZ /* arbitrary */);
+		break;
+	default:
+		task->tk_action = call_bind;
+	}
+
+	rpc_check_timeout(task);
+}
+
 /*
  * 5.	Transmit the RPC request, and wait for reply
  */
@@ -2355,7 +2405,7 @@  call_status(struct rpc_task *task)
 	struct rpc_clnt	*clnt = task->tk_client;
 	int		status;
 
-	if (!task->tk_msg.rpc_proc->p_proc)
+	if (!task->tk_msg.rpc_proc->p_proc && !RPC_IS_TLSPROBE(task))
 		trace_xprt_ping(task->tk_xprt, task->tk_status);
 
 	status = task->tk_status;
@@ -2663,6 +2713,8 @@  rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr)
 
 out_msg_denied:
 	error = -EACCES;
+	if (RPC_IS_TLSPROBE(task))
+		goto out_err;
 	p = xdr_inline_decode(xdr, sizeof(*p));
 	if (!p)
 		goto out_unparsable;
@@ -2865,7 +2917,7 @@  static const struct rpc_call_ops rpc_ops_probe_tls = {
  * @task: an RPC task waiting for a TLS session
  *
  */
-void rpc_starttls_async(struct rpc_task *task)
+static void rpc_starttls_async(struct rpc_task *task)
 {
 	struct rpc_xprt *xprt = xprt_get(task->tk_xprt);
 
@@ -2885,7 +2937,6 @@  void rpc_starttls_async(struct rpc_task *task)
 		     RPC_TASK_TLSCRED | RPC_TASK_SWAPPER | RPC_TASK_CORK,
 		     &rpc_ops_probe_tls, xprt));
 }
-EXPORT_SYMBOL_GPL(rpc_starttls_async);
 
 struct rpc_cb_add_xprt_calldata {
 	struct rpc_xprt_switch *xps;