diff mbox series

[v3,2/3] python/qmp/legacy: make QEMUMonitorProtocol accept a socket

Message ID 20230111080101.969151-3-marcandre.lureau@redhat.com
State New
Headers show
Series python/qemu/machine: fix potential hang in QMP accept | expand

Commit Message

Marc-André Lureau Jan. 11, 2023, 8:01 a.m. UTC
From: Marc-André Lureau <marcandre.lureau@redhat.com>

Teach QEMUMonitorProtocol to accept an exisiting socket.

Signed-off-by: Marc-André Lureau <marcandre.lureau@redhat.com>
Reviewed-by: Daniel P. Berrangé <berrange@redhat.com>
---
 python/qemu/qmp/legacy.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)
diff mbox series

Patch

diff --git a/python/qemu/qmp/legacy.py b/python/qemu/qmp/legacy.py
index 1951754455..8b09ee7dbb 100644
--- a/python/qemu/qmp/legacy.py
+++ b/python/qemu/qmp/legacy.py
@@ -22,6 +22,7 @@ 
 #
 
 import asyncio
+import socket
 from types import TracebackType
 from typing import (
     Any,
@@ -69,22 +70,32 @@  class QEMUMonitorProtocol:
 
     :param address:  QEMU address, can be either a unix socket path (string)
                      or a tuple in the form ( address, port ) for a TCP
-                     connection
+                     connection or None
+    :param sock:     a socket or None
     :param server:   Act as the socket server. (See 'accept')
     :param nickname: Optional nickname used for logging.
     """
 
-    def __init__(self, address: SocketAddrT,
+    def __init__(self,
+                 address: Optional[SocketAddrT] = None,
+                 sock: Optional[socket.socket] = None,
                  server: bool = False,
                  nickname: Optional[str] = None):
 
+        assert address or sock
         self._qmp = QMPClient(nickname)
         self._aloop = asyncio.get_event_loop()
         self._address = address
+        self._sock = sock
         self._timeout: Optional[float] = None
 
         if server:
-            self._sync(self._qmp.start_server(self._address))
+            if sock:
+                assert self._sock is not None
+                self._sync(self._qmp.open_with_socket(self._sock))
+            else:
+                assert self._address is not None
+                self._sync(self._qmp.start_server(self._address))
 
     _T = TypeVar('_T')
 
@@ -139,6 +150,7 @@  def connect(self, negotiate: bool = True) -> Optional[QMPMessage]:
         :return: QMP greeting dict, or None if negotiate is false
         :raise ConnectError: on connection errors
         """
+        assert self._address is not None
         self._qmp.await_greeting = negotiate
         self._qmp.negotiate = negotiate