diff mbox series

[2/3] Improve example client

Message ID 20211026141445.85452-2-johannes.schrimpf@blueye.no
State Accepted
Headers show
Series [1/3] Rename swupdate-client.py to swupdate_client.py | expand

Commit Message

Johannes Schrimpf Oct. 26, 2021, 2:14 p.m. UTC
- Start websocket client at the beginning of the update, so all messages are shown
- Call requests.post in executor to not block the asyncio event loop
- Use argparse to handle command line arguments
- Handle timeout in both the upload task and the websocket task
- Use flake8/black to lint and format code
- Add support for logging, so external programs can register their own logger
- Handle json parse error in version info message

Signed-off-by: Johannes Schrimpf <johannes.schrimpf@blueye.no>
---
 examples/client/swupdate_client.py | 158 +++++++++++++++++++++--------
 1 file changed, 113 insertions(+), 45 deletions(-)
diff mbox series

Patch

diff --git a/examples/client/swupdate_client.py b/examples/client/swupdate_client.py
index ae31ec9..fc0c1d4 100755
--- a/examples/client/swupdate_client.py
+++ b/examples/client/swupdate_client.py
@@ -6,68 +6,136 @@ 
 
 import asyncio
 import json
-import os
 import requests
 import websockets
+import logging
+import string
+import argparse
 import sys
 
 
 class SWUpdater:
-    "" " Python helper class for SWUpdate " ""
-
-    url_upload = 'http://{}:{}/upload'
-    url_status = 'ws://{}:{}/ws'
-
-    def __init__ (self, path_image, host_name, port):
-        self.__image = path_image
-        self.__host_name = host_name
-        self.__port = port
-
-
-    async def wait_update_finished(self, timeout = 300):
-        print ("Wait update finished")
-        async def get_finish_messages ():
-            async with websockets.connect(self.url_status.format(self.__host_name, self.__port)) as websocket:
+    """Python helper class for SWUpdate"""
+
+    url_upload = "http://{}:{}/upload"
+    url_status = "ws://{}:{}/ws"
+
+    def __init__(self, path_image, host_name, port=8080, logger=None):
+        self._image = path_image
+        self._host_name = host_name
+        self._port = port
+        if logger is not None:
+            self._logger = logger
+        else:
+            logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+            self._logger = logging.getLogger("swupdate")
+
+    async def wait_update_finished(self):
+        self._logger.info("Waiting for messages on websocket connection")
+        try:
+            async with websockets.connect(
+                self.url_status.format(self._host_name, self._port)
+            ) as websocket:
                 while True:
-                    message = await websocket.recv()
-                    data = json.loads(message)
-
-                    if data ["type"] != "message":
+                    try:
+                        message = await websocket.recv()
+                        message = "".join(
+                            filter(lambda x: x in set(string.printable), message)
+                        )
+
+                    except Exception as err:
+                        self._logger.warning(err)
                         continue
 
-                    print (data["text"])
-                    if data ["text"] == "SWUPDATE successful !":
-                        return
+                    try:
+                        data = json.loads(message)
+                    except json.decoder.JSONDecodeError:
+                        # As of 2021.04, the version info message contains invalid json
+                        self._logger.warning(f"json parse error: {message}")
+                        continue
 
-        await asyncio.wait_for(get_finish_messages(), timeout = timeout)
+                    if data["type"] != "message":
+                        continue
 
-    def update (self, timeout = 300):
-        print ("Start uploading image...")
-        print (self.url_upload.format(self.__host_name, self.__port))
+                    self._logger.info(data["text"])
+                    if "SWUPDATE successful" in data["text"]:
+                        return True
+                    if "Installation failed" in data["text"]:
+                        return False
+
+        except Exception as err:
+            self._logger.error(err)
+            return False
+
+    def sync_upload(self, swu_file, timeout):
+        return requests.post(
+            self.url_upload.format(self._host_name, self._port),
+            files={"file": swu_file},
+            timeout=timeout,
+        )
+
+    async def upload(self, timeout):
+        self._logger.info("Start uploading image...")
         try:
-            response = requests.post(self.url_upload.format(self.__host_name, self.__port), files = { 'file':open (self.__image, 'rb') })
+            with open(self._image, "rb") as swu_file:
+                loop = asyncio.get_event_loop()
+                response = await loop.run_in_executor(
+                    None, self.sync_upload, swu_file, timeout
+                )
 
             if response.status_code != 200:
-                raise Exception ("Cannot upload software image: {}".  format (response.status_code))
-
-            print ("Software image uploaded successfully. Wait for installation to be finished...\n")
-            asyncio.sleep(10)
-            asyncio.get_event_loop().run_until_complete(self.wait_update_finished(timeout = timeout))
-
+                self._logger.error(
+                    "Cannot upload software image: {}".format(response.status_code)
+                )
+                return False
+
+            self._logger.info(
+                "Software image uploaded successfully."
+                "Wait for installation to be finished..."
+            )
+            return True
         except ValueError:
-            print("No connection to host, exit")
+            self._logger.info("No connection to host, exit")
+        except FileNotFoundError:
+            self._logger.info("swu file not found")
+        except requests.exceptions.ConnectionError as e:
+            self._logger.info("Connection Error:\n%s" % str(e))
+        return False
+
+    async def start_tasks(self, timeout):
+        ws_task = asyncio.create_task(self.wait_update_finished())
+        upload_task = asyncio.create_task(self.upload(timeout))
+
+        if not await upload_task:
+            self._logger.info("Cancelling websocket task")
+            ws_task.cancel()
+            return False
 
+        try:
+            result = await asyncio.wait_for(ws_task, timeout=timeout)
+        except asyncio.TimeoutError:
+            self._logger.info("timeout!")
+            return False
 
-if __name__ == "__main__":
-    sys.path.append (os.getcwd ())
+        return result
 
-    if len (sys.argv) == 3:
-        port = "8080"
-    elif len (sys.argv) == 4:
-        port = sys.argv[3]
-    else:
-       print ("Usage: swupdate.py <path to image> <hostname> [port]")
-       exit (1)
+    def update(self, timeout=300):
+        return asyncio.run(self.start_tasks(timeout))
 
 
-    SWUpdater (sys.argv[1], sys.argv[2], port).update ()
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("swu_file", help="Path to swu image")
+    parser.add_argument("host_name", help="Host name")
+    parser.add_argument("port", help="Port", type=int, default=8080, nargs="?")
+    parser.add_argument(
+        "--timeout",
+        help="Timeout for the whole swupdate process",
+        type=int,
+        default=300,
+        nargs="?",
+    )
+
+    args = parser.parse_args()
+    updater = SWUpdater(args.swu_file, args.host_name, args.port)
+    updater.update(timeout=args.timeout)