diff mbox series

[v2,09/24] python/aqmp: add AsyncProtocol.accept() method

Message ID 20210717003253.457418-10-jsnow@redhat.com
State New
Headers show
Series python: introduce Asynchronous QMP package | expand

Commit Message

John Snow July 17, 2021, 12:32 a.m. UTC
It's a little messier than connect, because it wasn't designed to accept
*precisely one* connection. Such is life.

Signed-off-by: John Snow <jsnow@redhat.com>
---
 python/qemu/aqmp/protocol.py | 89 ++++++++++++++++++++++++++++++++++--
 1 file changed, 85 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py
index f9295546cda..99b9614ba94 100644
--- a/python/qemu/aqmp/protocol.py
+++ b/python/qemu/aqmp/protocol.py
@@ -245,6 +245,24 @@  async def runstate_changed(self) -> Runstate:
         await self._runstate_event.wait()
         return self.runstate
 
+    @upper_half
+    @require(Runstate.IDLE)
+    async def accept(self, address: Union[str, Tuple[str, int]],
+                     ssl: Optional[SSLContext] = None) -> None:
+        """
+        Accept a connection and begin processing message queues.
+
+        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
+
+        :param address:
+            Address to listen to; UNIX socket path or TCP address/port.
+        :param ssl: SSL context to use, if any.
+
+        :raise StateError: When the `Runstate` is not `IDLE`.
+        :raise ConnectError: If a connection could not be accepted.
+        """
+        await self._new_session(address, ssl, accept=True)
+
     @upper_half
     @require(Runstate.IDLE)
     async def connect(self, address: Union[str, Tuple[str, int]],
@@ -309,7 +327,8 @@  def _set_state(self, state: Runstate) -> None:
     @upper_half
     async def _new_session(self,
                            address: Union[str, Tuple[str, int]],
-                           ssl: Optional[SSLContext] = None) -> None:
+                           ssl: Optional[SSLContext] = None,
+                           accept: bool = False) -> None:
         """
         Establish a new connection and initialize the session.
 
@@ -318,9 +337,10 @@  async def _new_session(self,
         to be set back to `IDLE`.
 
         :param address:
-            Address to connect to;
+            Address to connect to/listen on;
             UNIX socket path or TCP address/port.
         :param ssl: SSL context to use, if any.
+        :param accept: Accept a connection instead of connecting when `True`.
 
         :raise ConnectError:
             When a connection or session cannot be established.
@@ -334,7 +354,7 @@  async def _new_session(self,
 
         try:
             phase = "connection"
-            await self._establish_connection(address, ssl)
+            await self._establish_connection(address, ssl, accept)
 
             phase = "session"
             await self._establish_session()
@@ -368,6 +388,7 @@  async def _establish_connection(
             self,
             address: Union[str, Tuple[str, int]],
             ssl: Optional[SSLContext] = None,
+            accept: bool = False
     ) -> None:
         """
         Establish a new connection.
@@ -376,6 +397,7 @@  async def _establish_connection(
             Address to connect to/listen on;
             UNIX socket path or TCP address/port.
         :param ssl: SSL context to use, if any.
+        :param accept: Accept a connection instead of connecting when `True`.
         """
         assert self.runstate == Runstate.IDLE
         self._set_state(Runstate.CONNECTING)
@@ -385,7 +407,66 @@  async def _establish_connection(
         # otherwise yield.
         await asyncio.sleep(0)
 
-        await self._do_connect(address, ssl)
+        if accept:
+            await self._do_accept(address, ssl)
+        else:
+            await self._do_connect(address, ssl)
+
+    @upper_half
+    async def _do_accept(self, address: Union[str, Tuple[str, int]],
+                         ssl: Optional[SSLContext] = None) -> None:
+        """
+        Acting as the transport server, accept a single connection.
+
+        :param address:
+            Address to listen on; UNIX socket path or TCP address/port.
+        :param ssl: SSL context to use, if any.
+
+        :raise OSError: For stream-related errors.
+        """
+        self.logger.debug("Awaiting connection on %s ...", address)
+        connected = asyncio.Event()
+        server: Optional[asyncio.AbstractServer] = None
+
+        async def _client_connected_cb(reader: asyncio.StreamReader,
+                                       writer: asyncio.StreamWriter) -> None:
+            """Used to accept a single incoming connection, see below."""
+            nonlocal server
+            nonlocal connected
+
+            # A connection has been accepted; stop listening for new ones.
+            assert server is not None
+            server.close()
+            await server.wait_closed()
+            server = None
+
+            # Register this client as being connected
+            self._reader, self._writer = (reader, writer)
+
+            # Signal back: We've accepted a client!
+            connected.set()
+
+        if isinstance(address, tuple):
+            coro = asyncio.start_server(
+                _client_connected_cb,
+                host=address[0],
+                port=address[1],
+                ssl=ssl,
+                backlog=1,
+            )
+        else:
+            coro = asyncio.start_unix_server(
+                _client_connected_cb,
+                path=address,
+                ssl=ssl,
+                backlog=1,
+            )
+
+        server = await coro     # Starts listening
+        await connected.wait()  # Waits for the callback to fire (and finish)
+        assert server is None
+
+        self.logger.debug("Connection accepted.")
 
     @upper_half
     async def _do_connect(self, address: Union[str, Tuple[str, int]],