diff mbox series

[swugenerator,2/3] Refactor main.py and add tests

Message ID 20230127202639.505959-3-colin.mcallister@garmin.com
State Accepted
Headers show
Series Refactor main and add initial tests | expand

Commit Message

Colin McAllister Jan. 27, 2023, 8:26 p.m. UTC
From: Colin McAllister <colinmca242@gmail.com>

Refactored main to make the main function call into a parse_args func.
This fuction moves all of the argument parsing into parser.parse_args()
to make the funtion much clearer. Code to validate the arguments were
either converted to lambdas or moved into functions. If arguments are
invalid, instead of exiting with error, either argparse or the helper
functions generate an exception, this lines up better with Python best
practices, where the stack trace can be used to debug the program.

In order to properly guarantee functionality for main.py, tests have
been added for most functions in main.py.

Signed-off-by: Colin McAllister <colinmca242@gmail.com>
---
 swugenerator/__about__.py |   2 +-
 swugenerator/main.py      | 317 +++++++++++++++++++++++---------------
 tests/test_main.py        | 246 +++++++++++++++++++++++++++++
 3 files changed, 444 insertions(+), 121 deletions(-)
 create mode 100644 tests/test_main.py
diff mbox series

Patch

diff --git a/swugenerator/__about__.py b/swugenerator/__about__.py
index 9bbfc9f..84037ad 100644
--- a/swugenerator/__about__.py
+++ b/swugenerator/__about__.py
@@ -14,7 +14,7 @@  __all__ = [
 ]
 
 __title__ = "SWUGenerator"
-__summary__ = "Generator SWU Packages for SWUpdate"
+__summary__ = "SWU Package Generator for SWUpdate"
 __uri__ = "https://github.com/sbabic/swugenerator"
 
 __version__ = "0.2"
diff --git a/swugenerator/main.py b/swugenerator/main.py
index 102e589..4531865 100644
--- a/swugenerator/main.py
+++ b/swugenerator/main.py
@@ -3,14 +3,15 @@ 
 # Copyright (C) 2022 Stefano Babic
 #
 # SPDX-License-Identifier: GPLv3
+# pylint: disable=C0114
 
 import argparse
-import codecs
 import logging
 import os
 import sys
 import textwrap
 from pathlib import Path
+from typing import List, Optional, Tuple, Union
 
 import libconf
 
@@ -18,27 +19,179 @@  from swugenerator import __about__, generator
 from swugenerator.swu_sign import SWUSignCMS, SWUSignCustom, SWUSignPKCS11, SWUSignRSA
 
 
-def extract_keys(keyfile):
+class InvalidKeyFile(ValueError):
+    """Raised when a key file is invalid"""
+
+
+class InvalidSigningOption(ValueError):
+    """Raised when an invalid signing option is passed via command line"""
+
+
+def extract_keys(keyfile: str) -> Tuple[Optional[str], Optional[str]]:
+    """Extracts encryption key and initialization vector (IV)
+
+    TODO Ensure alignment of key file with description in docs:
+    (https://sbabic.github.io/swupdate/encrypted_images.html#building-an-encrypted-swu-image)
+
+    Args:
+        keyfile (str): Path to file containing key and IV
+
+    Raises:
+        InvalidKeyFile: If key file cannot be read
+
+    Returns:
+        Tuple[Optional[str], Optional[str]]: Key and IV if found.
+            Tuple will contain None for either value if missing
+    """
     try:
-        with open(keyfile, "r") as f:
-            lines = f.readlines()
-    except IOError:
-        logging.fatal("Failed to open file with keys %s", keyfile)
-        sys.exit(1)
+        with open(keyfile, "r", encoding="utf-8") as keyfile_fd:
+            lines = keyfile_fd.readlines()
+    except IOError as error:
+        raise InvalidKeyFile(f"Failed to open key file {keyfile}") from error
 
-    key, iv = None, None
+    enc_key, init_vec = None, None
     for line in lines:
-        if "key" in line:
-            key = line.split("=")[1].rstrip("\n")
-        if "iv" in line:
-            iv = line.split("=")[1].rstrip("\n")
-    return key, iv
+        key, value = (
+            line.rstrip("\n").split("=") if len(line.split("=")) == 2 else (None, None)
+        )
+        if key == "key":
+            enc_key = value
+        if key == "iv":
+            init_vec = value
+    return enc_key, init_vec
+
+
+def parse_config_file(config_file_arg: str) -> dict:
+    """Parses configuration file to store key value pairs
+
+    Args:
+        config_file_arg (str): Path to config file
+
+    Returns:
+        dict: Key value pairs parsed from configuration file
+    """
+    config_vars = {}
+    with open(config_file_arg, "r", encoding="utf-8") as config_fd:
+        config = libconf.load(config_fd)
+        for key, keydict in config.items():
+            if key == "variables":
+                for varname, varvalue in keydict.items():
+                    logging.debug("VAR = %s VAL = %s", varname, varvalue)
+                    config_vars[varname] = varvalue
+    return config_vars
+
+
+def parse_signing_option(
+    sign_arg: str,
+) -> Union[SWUSignCMS, SWUSignRSA, SWUSignPKCS11, SWUSignCustom]:
+    """Parses signgning option passed by user. Valid options can be found below.
+
+    CMS,<private key>,<certificate used to sign>,<file with password>
+    CMS,<private key>,<certificate used to sign>
+    RSA,<private key>,<file with password>
+    RSA,<private key>
+    PKCS11,<pin>
+    CUSTOM,<custom command>
+
+    Args:
+        sign_arg (str): argument passed by user
+
+    Raises:
+        InvalidSigningOption: If option passed by user is invalid
+
+    Returns:
+        Union[SWUSignCMS, SWUSignRSA, SWUSignPKCS11, SWUSignCustom]: Signing option to use
+    """
+    sign_parms = sign_arg.split(",")
+    cmd = sign_parms[0]
+    if cmd == "CMS":
+        if len(sign_parms) not in (3, 4) or not all(sign_parms):
+            raise InvalidSigningOption(
+                "CMS requires private key, certificate, and an optional password file"
+            )
+        # Format : CMS,<private key>,<certificate used to sign>,<file with password>
+        if len(sign_parms) == 4:
+            return SWUSignCMS(sign_parms[1], sign_parms[2], sign_parms[3])
+        # Format : CMS,<private key>,<certificate used to sign>
+        return SWUSignCMS(sign_parms[1], sign_parms[2], None)
+    if cmd == "RSA":
+        if len(sign_parms) not in (2, 3) or not all(sign_parms):
+            raise InvalidSigningOption(
+                "RSA requires private key and an optional password file"
+            )
+        # Format : RSA,<private key>,<file with password>
+        if len(sign_parms) == 3:
+            return SWUSignRSA(sign_parms[1], sign_parms[2])
+        # Format : RSA,<private key>
+        return SWUSignRSA(sign_parms[1], None)
+    if cmd == "PKCS11":
+        # Format : PKCS11,<pin>
+        if len(sign_parms) != 2 or not all(sign_parms):
+            raise InvalidSigningOption("PKCS11 requires URI")
+        return SWUSignPKCS11(sign_parms[1])
+    if cmd == "CUSTOM":
+        # Format : CUSTOM,<custom command>
+        if len(sign_parms) != 2 or not all(sign_parms):
+            raise InvalidSigningOption("CUSTOM requires custom command")
+        return SWUSignCustom(sign_parms[1])
+    raise InvalidSigningOption("Unknown signing command")
+
+
+def set_log_level(arg: str) -> str:
+    """Sets log level
 
+    This is meant to be used with Argparse's type param for
+    add_argument, but this allows log level to be set when
+    argparse parses commands.
 
-def main() -> None:
-    """swugenerator main entry point."""
+    Args:
+        arg (str): Log level in string form (All caps)
 
-    # ArgumentParser {{{
+    Returns:
+        str: Returns arg to be parsed by argparse.
+            Argparse will then make sure arg is in choices
+            and return an error to the user otherwise.
+    """
+    if arg in ("DEBUG", "INFO", "ERROR", "CRITICAL"):
+        logging.basicConfig(level=logging.getLevelName(arg))
+    return arg
+
+
+def create_swu(args: argparse.Namespace) -> None:
+    """Creates SWU archive from arguments passed to SWUGenerate
+
+    Args:
+        args (argparse.Namespace): Parsed arguments to generate SWU file with
+    """
+    # Extract key and iv from encryption_key_file (Will default to '(None, None)')
+    key, init_vec = args.encryption_key_file
+
+    # Add current working directory to search path
+    args.artifactory.append(Path(os.getcwd()))
+
+    swu = generator.SWUGenerator(
+        args.sw_description,
+        args.swu_file,
+        args.config,
+        args.artifactory,
+        args.sign,
+        key,
+        init_vec,
+        args.encrypt_swdesc,
+        args.no_compress,
+        args.no_encrypt,
+        args.no_ivt,
+    )
+    swu.process()
+    swu.close()
+
+
+def parse_args(args: List[str]) -> None:
+    """Sets up arguments for swugenerator and parses commandline args
+
+    Args:
+        args (List[str]): Command line arguments
+    """
     parser = argparse.ArgumentParser(
         prog=__about__.__title__,
         description=__about__.__summary__ + " " + __about__.__version__,
@@ -49,39 +202,36 @@  def main() -> None:
     parser.add_argument(
         "-K",
         "--encryption-key-file",
+        default=(None, None),
+        type=extract_keys,
         help="<key,iv> : AES Key to encrypt artifacts",
     )
 
     parser.add_argument(
         "-n",
         "--no-compress",
-        action="store_const",
-        const=True,
-        default=False,
+        action="store_true",
         help="Do not compress files",
     )
 
     parser.add_argument(
         "-e",
         "--no-encrypt",
-        action="store_const",
-        const=True,
-        default=False,
+        action="store_true",
         help="Do not encrypt files",
     )
 
     parser.add_argument(
         "-x",
         "--no-ivt",
-        action="store_const",
-        const=True,
-        default=False,
+        action="store_true",
         help="Do not generate IV when encrypting",
     )
 
     parser.add_argument(
         "-k",
         "--sign",
+        type=parse_signing_option,
         help=textwrap.dedent(
             """\
             RSA key or certificate to sign the SWU
@@ -97,6 +247,7 @@  def main() -> None:
         "-s",
         "--sw-description",
         required=True,
+        type=lambda p: Path(p).resolve(),
         help="sw-description template",
     )
 
@@ -112,6 +263,8 @@  def main() -> None:
     parser.add_argument(
         "-a",
         "--artifactory",
+        default=[],
+        type=lambda paths: [Path(p).resolve() for p in paths.split(",")],
         help="list of directories where artifacts are searched",
     )
 
@@ -119,119 +272,43 @@  def main() -> None:
         "-o",
         "--swu-file",
         required=True,
+        type=Path,
         help="SWU output file",
     )
 
     parser.add_argument(
         "-c",
         "--config",
+        default={},
+        type=parse_config_file,
         help="configuration file",
     )
 
     parser.add_argument(
         "-l",
         "--loglevel",
+        choices=["DEBUG", "INFO", "ERROR", "CRITICAL"],
+        default="WARNING",
+        type=set_log_level,
         help="set log level, default is WARNING",
     )
-    parser.add_argument(
-        "command", metavar="command", default=[], help="command to be executed (create)"
+
+    subparsers = parser.add_subparsers(
+        title="command", help="command to be executed", required=True
     )
+    create_subparser = subparsers.add_parser("create", help="creates a SWU file")
+    create_subparser.set_defaults(func=create_swu)
 
-    args = parser.parse_args()
+    args = parser.parse_args(args)
+    args.func(args)
 
-    if args.loglevel:
-        if args.loglevel == "DEBUG":
-            logging.basicConfig(level=logging.DEBUG)
-        if args.loglevel == "INFO":
-            logging.basicConfig(level=logging.INFO)
-        if args.loglevel == "ERROR":
-            logging.basicConfig(level=logging.ERROR)
-        if args.loglevel == "CRITICAL":
-            logging.basicConfig(level=logging.CRITICAL)
 
-    # Read configuration file if any
-    config_vars = {}
-    if args.config and args.config != "":
-        logging.info("Reading configuration file %s", args.config)
-
-        with codecs.open(args.config, "r", "utf-8") as f:
-            config = libconf.load(f)
-            for key, keydict in config.items():
-                if key == "variables":
-                    for varname, varvalue in keydict.items():
-                        logging.debug("VAR = %s VAL = %s", varname, varvalue)
-                        config_vars[varname] = varvalue
-            f.close()
-
-    # Signing
-    sign_option = None
-    if args.sign:
-        sign_parms = args.sign.split(",")
-        cmd = sign_parms[0]
-        if cmd == "CMS":
-            if len(sign_parms) < 3:
-                logging.critical("CMS requires private key and certificate")
-                sys.exit(1)
-            # Format : CMS,<private key>,<certificate used to sign>,<file with password>
-            if len(sign_parms) == 4:
-                sign_option = SWUSignCMS(sign_parms[1], sign_parms[2], sign_parms[3])
-            # Format : CMS,<private key>,<certificate used to sign>
-            else:
-                sign_option = SWUSignCMS(sign_parms[1], sign_parms[2], None)
-        if cmd == "RSA":
-            if len(sign_parms) < 2:
-                logging.critical("RSA requires private key")
-                sys.exit(1)
-            # Format : RSA,<private key>,<file with password>
-            if len(sign_parms) == 3:
-                sign_option = SWUSignRSA(sign_parms[1], sign_parms[2])
-            # Format : RSA,<private key>
-            else:
-                sign_option = SWUSignRSA(sign_parms[1], None)
-        if cmd == "PKCS11":
-            # Format : PKCS11,<pin>>
-            if len(sign_parms) < 2:
-                logging.critical("PKCS11 requires URI")
-                sys.exit(1)
-            sign_option = SWUSignPKCS11(sign_parms[1])
-        if cmd == "CUSTOM":
-            # Format : PKCS11,<custom command>>
-            if len(sign_parms) < 2:
-                logging.critical("PKCS11 requires URI")
-                sys.exit(1)
-            sign_option = SWUSignCustom(sign_parms[1])
-
-    key = None
-    iv = None
-    if args.encryption_key_file:
-        key, iv = extract_keys(args.encryption_key_file)
-
-    artidirs = []
-    artidirs.append(os.getcwd())
-    if args.artifactory:
-        dirs = args.artifactory.split(",")
-        for directory in dirs:
-            deploy = Path(directory).resolve()
-            artidirs.append(deploy)
-
-    if args.command == "create":
-        swu = generator.SWUGenerator(
-            args.sw_description,
-            args.swu_file,
-            config_vars,
-            artidirs,
-            sign_option,
-            key,
-            iv,
-            args.encrypt_swdesc,
-            args.no_compress,
-            args.no_encrypt,
-            args.no_ivt,
-        )
-        swu.process()
-        swu.close()
-    else:
-        parser.error("no suitable command found: (create)")
+def main():
+    """Main entry point for SWUGenerator"""
+    # Arg parsing is in separate function
+    # to allow argument parsing to be easily
+    # tested with pytest
+    parse_args(sys.argv[1:])
 
 
 if __name__ == "__main__":
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..4fd97d0
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,246 @@ 
+# pylint: disable=C0114,C0116,W0621
+import argparse
+import libconf
+import pytest
+
+from swugenerator import main
+from swugenerator.swu_sign import SWUSignCMS, SWUSignCustom, SWUSignPKCS11, SWUSignRSA
+
+VALID_KEY = "390ad54490a4a5f53722291023c19e08ffb5c4677a59e958c96ffa6e641df040"
+VALID_IV = "d5d601bacfe13100b149177318ebc7a4"
+VALID_KEY_FILE = "valid_key.txt"
+INVALID_KEY_FILE = "invalid_key.txt"
+
+
+@pytest.fixture(scope="session")
+def test_dir(tmp_path_factory):
+    """Creates a directory to test"""
+    test_space = tmp_path_factory.mktemp("archive")
+    return test_space
+
+
+#### Key file parsing tests ####
+@pytest.fixture(scope="session")
+def valid_key_file(test_dir):
+    key_file = test_dir / VALID_KEY_FILE
+    with key_file.open("w") as key_file_fd:
+        key_file_fd.write(f"key={VALID_KEY}\niv={VALID_IV}")
+    return key_file
+
+
+@pytest.fixture(scope="session")
+def invalid_key_file(test_dir):
+    # Create invalid key file where only the key can be parsed
+    key_file = test_dir / INVALID_KEY_FILE
+    with key_file.open("w") as key_file_fd:
+        key_file_fd.write(f"key foo\nkey={VALID_KEY}\nkey bar\niv\n{VALID_IV}")
+    return key_file
+
+
+def test_extract_keys_returns_valid_tuple_from_valid_file(valid_key_file):
+    assert main.extract_keys(str(valid_key_file)) == (VALID_KEY, VALID_IV)
+
+
+def test_extract_keys_returns_none_from_key_file_thats_invalid(invalid_key_file):
+    assert main.extract_keys(str(invalid_key_file)) == (VALID_KEY, None)
+
+
+def test_extract_keys_returns_exception_from_key_file_that_dne():
+    with pytest.raises(main.InvalidKeyFile):
+        main.extract_keys("foo/bar/baz.txt")
+
+
+#### Config file parsing tests ####
+VALID_CONFIG = {"foo": 1, "bar": "test"}
+VALID_CONFIG_FILE = "valid.cfg"
+INVALID_CONFIG_FILE = "invalid.cfg"
+
+
+@pytest.fixture(scope="session")
+def valid_config_file(test_dir):
+    config_file = test_dir / VALID_CONFIG_FILE
+    with config_file.open("w") as config_file_fd:
+        config_file_fd.write(libconf.dumps({"variables": VALID_CONFIG}))
+    return config_file
+
+
+@pytest.fixture(scope="session")
+def invalid_config_file(test_dir):
+    config_file = test_dir / VALID_CONFIG_FILE
+    with config_file.open("w") as config_file_fd:
+        config_file_fd.write("{" + libconf.dumps({"variables": VALID_CONFIG}))
+    return config_file
+
+
+def test_valid_config_file_is_properly_parsed(valid_config_file):
+    assert main.parse_config_file(str(valid_config_file)) == VALID_CONFIG
+
+
+def test_invalid_config_file_throws_exception(invalid_config_file):
+    with pytest.raises(libconf.ConfigParseError):
+        main.parse_config_file(str(invalid_config_file))
+
+
+def test_missing_config_file_throws_exception():
+    with pytest.raises(FileNotFoundError):
+        main.parse_config_file("foo/bar/baz.txt")
+
+
+#### Signing option parsing tests ####
+SIGNING_TEST_PARAMETERS = [
+    ("CMS,foo,bar,baz", SWUSignCMS("foo", "bar", "baz")),
+    ("CMS,foo,bar", SWUSignCMS("foo", "bar", None)),
+    ("RSA,foo,bar", SWUSignRSA("foo", "bar")),
+    ("RSA,foo", SWUSignRSA("foo", None)),
+    ("PKCS11,foo", SWUSignPKCS11("foo")),
+    ("CUSTOM,foo", SWUSignCustom("foo")),
+]
+
+
+@pytest.mark.parametrize("arg,expected", SIGNING_TEST_PARAMETERS)
+def test_valid_siging_params_parsed_to_correct_signing_obj(arg, expected):
+    signing_option = main.parse_signing_option(arg)
+    assert type(signing_option) == type(expected)
+    assert signing_option.type == expected.type
+    assert signing_option.key == expected.key
+    assert signing_option.cert == expected.cert
+    assert signing_option.passin == expected.passin
+
+
+INVALID_SIGNING_TEST_PARAMETERS = [
+    ("CMS", "CMS requires private key, certificate, and an optional password file"),
+    ("CMS,", "CMS requires private key, certificate, and an optional password file"),
+    ("CMS,,", "CMS requires private key, certificate, and an optional password file"),
+    ("CMS,,,", "CMS requires private key, certificate, and an optional password file"),
+    ("CMS,,,,", "CMS requires private key, certificate, and an optional password file"),
+    (
+        "CMS,,foo,",
+        "CMS requires private key, certificate, and an optional password file",
+    ),
+    ("CMS,foo", "CMS requires private key, certificate, and an optional password file"),
+    (
+        "CMS,foo,bar,baz,jaz",
+        "CMS requires private key, certificate, and an optional password file",
+    ),
+    ("RSA,foo,bar,baz", "RSA requires private key and an optional password file"),
+    ("PKCS11", "PKCS11 requires URI"),
+    ("PKCS11,", "PKCS11 requires URI"),
+    ("PKCS11,,", "PKCS11 requires URI"),
+    ("PKCS11,foo,", "PKCS11 requires URI"),
+    ("CUSTOM", "CUSTOM requires custom command"),
+    ("CUSTOM,", "CUSTOM requires custom command"),
+    ("CUSTOM,,", "CUSTOM requires custom command"),
+    ("CUSTOM,foo,", "CUSTOM requires custom command"),
+    ("FOO", "Unknown signing command"),
+    ("FOO,bar,baz", "Unknown signing command"),
+]
+
+
+@pytest.mark.parametrize("arg,exception_msg", INVALID_SIGNING_TEST_PARAMETERS)
+def test_invalid_signing_params_throws_exception_with_correct_msg(arg, exception_msg):
+    with pytest.raises(main.InvalidSigningOption) as error:
+        main.parse_signing_option(arg)
+    assert exception_msg in error.value.args
+
+
+#### Parse args tests ####
+@pytest.fixture
+def mock_main_funcs(monkeypatch):
+    def mock_create_swu(*_):
+        return True
+
+    def mock_extract_keys(*_):
+        return "foo", "bar"
+
+    def mock_parse_signing_option(*_):
+        return SWUSignCMS("foo", "bar", "baz")
+
+    def mock_parse_config_file(*_):
+        return {}
+
+    def mock_argparse_error(*_):
+        raise Exception
+
+    monkeypatch.setattr(main, "create_swu", mock_create_swu)
+    monkeypatch.setattr(main, "extract_keys", mock_extract_keys)
+    monkeypatch.setattr(main, "parse_signing_option", mock_parse_signing_option)
+    monkeypatch.setattr(main, "parse_config_file", mock_parse_config_file)
+    monkeypatch.setattr(argparse.ArgumentParser, "exit", mock_argparse_error)
+
+
+VALID_COMMANDS = [
+    (["-s", "sw-description", "-o", "test.swu", "create"]),
+    (["--sw-description", "sw-description", "--swu-file", "test.swu", "create"]),
+    (
+        [
+            "-K",
+            "key.txt",
+            "-n",
+            "-e",
+            "-x",
+            "-k",
+            "CUSTOM,foo",
+            "-s",
+            "sw-description",
+            "-t",
+            "-a",
+            ".,..",
+            "-o",
+            "test.swu",
+            "-c",
+            "test.cfg",
+            "-l",
+            "DEBUG",
+            "create",
+        ]
+    ),
+    (
+        [
+            "--encryption-key-file",
+            "key.txt",
+            "--no-compress",
+            "--no-encrypt",
+            "--no-ivt",
+            "--sign",
+            "CUSTOM,foo",
+            "--sw-description",
+            "sw-description",
+            "--encrypt-swdesc",
+            "--artifactory",
+            ".,..",
+            "--swu-file",
+            "test.swu",
+            "--config",
+            "test.cfg",
+            "--loglevel",
+            "DEBUG",
+            "create",
+        ]
+    ),
+]
+
+
+@pytest.mark.parametrize("args", VALID_COMMANDS)
+def test_parsing_valid_args_doesnt_throw(args, mock_main_funcs):
+    main.parse_args(args)
+
+
+INVALID_COMMANDS = [
+    (["-s", "sw-description", "-o", "test.swu"]),
+    (["-s", "-o", "test.swu", "create"]),
+    (["-s", "sw-description", "-o", "create"]),
+    (["-s", "-o", "create"]),
+    (["-s", "-o"]),
+    (["-K", "-s", "sw-description", "-o", "test.swu", "create"]),
+    (["-k", "-s", "sw-description", "-o", "test.swu", "create"]),
+    (["-a", "sw-description", "-o", "test.swu", "create"]),
+    (["-c", "sw-description", "-o", "test.swu", "create"]),
+    (["-l", "sw-description", "-o", "test.swu", "create"]),
+    (["-l", "baz", "sw-description", "-o", "test.swu", "create"]),
+]
+
+
+@pytest.mark.parametrize("args", INVALID_COMMANDS)
+def test_parsing_invalid_args_does_throw_argparse_exception(args, mock_main_funcs):
+    with pytest.raises(Exception):
+        main.parse_args(args)