diff mbox series

[PULL,1/5] python/qmp: allow sockets to be passed to connect()

Message ID 20230531204338.1656158-2-jsnow@redhat.com
State New
Headers show
Series [PULL,1/5] python/qmp: allow sockets to be passed to connect() | expand

Commit Message

John Snow May 31, 2023, 8:43 p.m. UTC
Allow existing sockets to be passed to connect(). The changes are pretty
minimal, and this allows for far greater flexibility in setting up
communications with an endpoint.

Signed-off-by: John Snow <jsnow@redhat.com>
Message-id: 20230517163406.2593480-2-jsnow@redhat.com
Signed-off-by: John Snow <jsnow@redhat.com>
---
 python/qemu/qmp/protocol.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)
diff mbox series

Patch

diff --git a/python/qemu/qmp/protocol.py b/python/qemu/qmp/protocol.py
index 22e60298d2..d534db4631 100644
--- a/python/qemu/qmp/protocol.py
+++ b/python/qemu/qmp/protocol.py
@@ -370,7 +370,7 @@  async def accept(self) -> None:
 
     @upper_half
     @require(Runstate.IDLE)
-    async def connect(self, address: SocketAddrT,
+    async def connect(self, address: Union[SocketAddrT, socket.socket],
                       ssl: Optional[SSLContext] = None) -> None:
         """
         Connect to the server and begin processing message queues.
@@ -615,7 +615,7 @@  async def _do_accept(self) -> None:
         self.logger.debug("Connection accepted.")
 
     @upper_half
-    async def _do_connect(self, address: SocketAddrT,
+    async def _do_connect(self, address: Union[SocketAddrT, socket.socket],
                           ssl: Optional[SSLContext] = None) -> None:
         """
         Acting as the transport client, initiate a connection to a server.
@@ -634,9 +634,17 @@  async def _do_connect(self, address: SocketAddrT,
         # otherwise yield.
         await asyncio.sleep(0)
 
-        self.logger.debug("Connecting to %s ...", address)
-
-        if isinstance(address, tuple):
+        if isinstance(address, socket.socket):
+            self.logger.debug("Connecting with existing socket: "
+                              "fd=%d, family=%r, type=%r",
+                              address.fileno(), address.family, address.type)
+            connect = asyncio.open_connection(
+                limit=self._limit,
+                ssl=ssl,
+                sock=address,
+            )
+        elif isinstance(address, tuple):
+            self.logger.debug("Connecting to %s ...", address)
             connect = asyncio.open_connection(
                 address[0],
                 address[1],
@@ -644,13 +652,14 @@  async def _do_connect(self, address: SocketAddrT,
                 limit=self._limit,
             )
         else:
+            self.logger.debug("Connecting to file://%s ...", address)
             connect = asyncio.open_unix_connection(
                 path=address,
                 ssl=ssl,
                 limit=self._limit,
             )
+
         self._reader, self._writer = await connect
-
         self.logger.debug("Connected.")
 
     @upper_half