From f758d94fba753e1543b0eff4f2814b5d5d432cf0 Mon Sep 17 00:00:00 2001 From: Felix Fontein Date: Sun, 11 May 2025 18:00:11 +0200 Subject: [PATCH] Add type hints and type checking (#885) * Enable basic type checking. * Fix first errors. * Add changelog fragment. * Add types to module_utils and plugin_utils (without module backends). * Add typing hints for acme_* modules. * Add typing to X.509 certificate modules, and add more helpers. * Add typing to remaining module backends. * Add typing for action, filter, and lookup plugins. * Bump ansible-core 2.19 beta requirement for typing. * Add more typing definitions. * Add typing to some unit tests. --- antsibull-nox.toml | 9 +- changelogs/fragments/refactoring.yml | 1 + changelogs/fragments/relative-timestamps.yml | 5 + plugins/action/openssl_privatekey_pipe.py | 37 +- plugins/filter/gpg_fingerprint.py | 6 +- plugins/filter/openssl_csr_info.py | 8 +- plugins/filter/openssl_privatekey_info.py | 14 +- plugins/filter/openssl_publickey_info.py | 6 +- plugins/filter/parse_serial.py | 6 +- plugins/filter/split_pem.py | 9 +- plugins/filter/to_serial.py | 6 +- plugins/filter/x509_certificate_info.py | 8 +- plugins/filter/x509_crl_info.py | 19 +- plugins/lookup/gpg_fingerprint.py | 16 +- plugins/module_utils/acme/account.py | 76 ++-- plugins/module_utils/acme/acme.py | 205 +++++++--- .../module_utils/acme/backend_cryptography.py | 142 ++++--- .../module_utils/acme/backend_openssl_cli.py | 138 +++++-- plugins/module_utils/acme/backends.py | 94 +++-- plugins/module_utils/acme/certificate.py | 132 ++++-- plugins/module_utils/acme/certificates.py | 48 ++- plugins/module_utils/acme/challenges.py | 129 ++++-- plugins/module_utils/acme/errors.py | 38 +- plugins/module_utils/acme/io.py | 11 +- plugins/module_utils/acme/orders.py | 83 ++-- plugins/module_utils/acme/utils.py | 36 +- plugins/module_utils/argspec.py | 68 +++- plugins/module_utils/crypto/_asn1.py | 26 +- plugins/module_utils/crypto/_obj2txt.py | 2 +- plugins/module_utils/crypto/_objects.py | 6 +- .../module_utils/crypto/cryptography_crl.py | 59 ++- .../crypto/cryptography_support.py | 311 ++++++++++++--- plugins/module_utils/crypto/math.py | 34 +- .../crypto/module_backends/certificate.py | 146 ++++--- .../module_backends/certificate_acme.py | 61 +-- .../module_backends/certificate_entrust.py | 59 ++- .../module_backends/certificate_info.py | 142 ++++--- .../module_backends/certificate_ownca.py | 128 ++++-- .../module_backends/certificate_selfsigned.py | 90 +++-- .../crypto/module_backends/crl_info.py | 35 +- .../crypto/module_backends/csr.py | 252 +++++++----- .../crypto/module_backends/csr_info.py | 118 ++++-- .../crypto/module_backends/privatekey.py | 205 ++++++---- .../module_backends/privatekey_convert.py | 70 ++-- .../crypto/module_backends/privatekey_info.py | 146 ++++--- .../crypto/module_backends/publickey_info.py | 88 ++-- plugins/module_utils/crypto/openssh.py | 3 + plugins/module_utils/crypto/pem.py | 21 +- plugins/module_utils/crypto/support.py | 205 +++++++--- plugins/module_utils/cryptography_dep.py | 31 +- plugins/module_utils/ecs/api.py | 9 +- plugins/module_utils/gnupg/cli.py | 10 +- plugins/module_utils/io.py | 20 +- .../module_utils/openssh/backends/common.py | 187 ++++++--- .../openssh/backends/keypair_backend.py | 142 ++++--- plugins/module_utils/openssh/certificate.py | 375 +++++++++++------- plugins/module_utils/openssh/cryptography.py | 215 ++++++---- plugins/module_utils/openssh/utils.py | 81 ++-- plugins/module_utils/serial.py | 18 +- plugins/module_utils/time.py | 61 ++- plugins/modules/acme_account.py | 48 ++- plugins/modules/acme_account_info.py | 48 ++- plugins/modules/acme_ari_info.py | 8 +- plugins/modules/acme_certificate.py | 67 +++- .../acme_certificate_deactivate_authz.py | 4 +- .../modules/acme_certificate_order_create.py | 4 +- .../acme_certificate_order_finalize.py | 11 +- .../modules/acme_certificate_order_info.py | 11 +- .../acme_certificate_order_validate.py | 28 +- .../modules/acme_certificate_renewal_info.py | 7 +- plugins/modules/acme_certificate_revoke.py | 24 +- plugins/modules/acme_challenge_cert_helper.py | 32 +- plugins/modules/acme_inspect.py | 18 +- plugins/modules/certificate_complete_chain.py | 60 ++- plugins/modules/crypto_info.py | 21 +- plugins/modules/ecs_certificate.py | 15 +- plugins/modules/ecs_domain.py | 11 +- plugins/modules/get_certificate.py | 49 +-- plugins/modules/luks_device.py | 319 ++++++++------- plugins/modules/openssh_cert.py | 111 ++++-- plugins/modules/openssh_keypair.py | 4 +- plugins/modules/openssl_csr.py | 22 +- plugins/modules/openssl_csr_info.py | 13 +- plugins/modules/openssl_csr_pipe.py | 19 +- plugins/modules/openssl_dhparam.py | 58 +-- plugins/modules/openssl_pkcs12.py | 172 +++++--- plugins/modules/openssl_privatekey.py | 26 +- plugins/modules/openssl_privatekey_convert.py | 24 +- plugins/modules/openssl_privatekey_info.py | 9 +- plugins/modules/openssl_publickey.py | 64 +-- plugins/modules/openssl_publickey_info.py | 7 +- plugins/modules/openssl_signature.py | 30 +- plugins/modules/openssl_signature_info.py | 30 +- plugins/modules/x509_certificate.py | 43 +- plugins/modules/x509_certificate_convert.py | 45 ++- plugins/modules/x509_certificate_info.py | 22 +- plugins/modules/x509_certificate_pipe.py | 29 +- plugins/modules/x509_crl.py | 227 +++++++---- plugins/modules/x509_crl_info.py | 22 +- plugins/plugin_utils/action_module.py | 93 ++--- plugins/plugin_utils/filter_module.py | 13 +- plugins/plugin_utils/gnupg.py | 15 +- tests/nox-config-flake8.ini | 2 +- tests/nox-config-mypy.ini | 19 + tests/sanity/ignore-2.17.txt | 25 ++ tests/sanity/ignore-2.18.txt | 17 + tests/sanity/ignore-2.19.txt | 10 + .../plugins/module_utils/acme/backend_data.py | 177 +++++---- .../acme/test_backend_cryptography.py | 40 +- .../acme/test_backend_openssl_cli.py | 42 +- .../module_utils/acme/test_challenges.py | 23 +- .../plugins/module_utils/acme/test_errors.py | 28 +- .../unit/plugins/module_utils/acme/test_io.py | 7 +- .../plugins/module_utils/acme/test_orders.py | 4 +- .../plugins/module_utils/acme/test_utils.py | 27 +- .../plugins/module_utils/crypto/test_asn1.py | 16 +- .../crypto/test_cryptography_support.py | 42 +- .../plugins/module_utils/crypto/test_math.py | 12 +- .../plugins/module_utils/crypto/test_pem.py | 13 +- .../module_utils/openssh/test_certificate.py | 62 +-- .../module_utils/openssh/test_cryptography.py | 57 ++- .../module_utils/openssh/test_utils.py | 50 +-- tests/unit/plugins/module_utils/test_time.py | 170 +++++--- .../unit/plugins/modules/test_luks_device.py | 216 ++++++---- 124 files changed, 4986 insertions(+), 2662 deletions(-) create mode 100644 changelogs/fragments/relative-timestamps.yml create mode 100644 tests/nox-config-mypy.ini diff --git a/antsibull-nox.toml b/antsibull-nox.toml index df87fdc6..6fc33245 100644 --- a/antsibull-nox.toml +++ b/antsibull-nox.toml @@ -18,7 +18,14 @@ run_yamllint = true yamllint_config = ".yamllint" yamllint_config_plugins = ".yamllint-docs" yamllint_config_plugins_examples = ".yamllint-examples" -run_mypy = false +run_mypy = true +mypy_ansible_core_package = "ansible-core>=2.19.0b3" +mypy_config = "tests/nox-config-mypy.ini" +mypy_extra_deps = [ + "cryptography", + "types-mock", + "types-PyYAML", +] [sessions.docs_check] validate_collection_refs="all" diff --git a/changelogs/fragments/refactoring.yml b/changelogs/fragments/refactoring.yml index a0651118..ef11dc2c 100644 --- a/changelogs/fragments/refactoring.yml +++ b/changelogs/fragments/refactoring.yml @@ -5,3 +5,4 @@ minor_changes: - "Python code modernization: remove Python 3 specific code (https://github.com/ansible-collections/community.crypto/pull/877)." - "Python code modernization: avoid unnecessary string conversion (https://github.com/ansible-collections/community.crypto/pull/880)." - "Python code modernization: avoid using ``six`` (https://github.com/ansible-collections/community.crypto/pull/884)." + - "Python code modernization: add type hints and type checking (https://github.com/ansible-collections/community.crypto/pull/885)." diff --git a/changelogs/fragments/relative-timestamps.yml b/changelogs/fragments/relative-timestamps.yml new file mode 100644 index 00000000..b5e6a716 --- /dev/null +++ b/changelogs/fragments/relative-timestamps.yml @@ -0,0 +1,5 @@ +breaking_changes: + - "The validation for relative timestamps is now more strict. A string starting with ``+`` or ``-`` must be valid, + otherwise validation will fail. In the past such strings were often silently ignored, and in many cases the code + which triggered the validation was not able to handle no result + (https://github.com/ansible-collections/community.crypto/pull/885)." diff --git a/plugins/action/openssl_privatekey_pipe.py b/plugins/action/openssl_privatekey_pipe.py index 1d84d50e..eee36e3d 100644 --- a/plugins/action/openssl_privatekey_pipe.py +++ b/plugins/action/openssl_privatekey_pipe.py @@ -5,6 +5,7 @@ from __future__ import annotations import base64 +import typing as t from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -19,25 +20,41 @@ from ansible_collections.community.crypto.plugins.plugin_utils.action_module imp ) +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.argspec import ( + ArgumentSpec, + ) + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import ( + PrivateKeyBackend, + ) + from ansible_collections.community.crypto.plugins.plugin_utils.action_module import ( + AnsibleActionModule, + ) + + class PrivateKeyModule: - def __init__(self, module, module_backend): + def __init__( + self, module: AnsibleActionModule, module_backend: PrivateKeyBackend + ) -> None: self.module = module self.module_backend = module_backend self.check_mode = module.check_mode self.changed = False - self.return_current_key = module.params["return_current_key"] + self.return_current_key: bool = module.params["return_current_key"] - if module.params["content"] is not None: - if module.params["content_base64"]: + content: str | None = module.params["content"] + content_base64: bool = module.params["content_base64"] + if content is not None: + if content_base64: try: - data = base64.b64decode(module.params["content"]) + data = base64.b64decode(content) except Exception as e: module.fail_json(msg=f"Cannot decode Base64 encoded data: {e}") else: - data = to_bytes(module.params["content"]) + data = to_bytes(content) module_backend.set_existing(data) - def generate(self, module): + def generate(self, module: AnsibleActionModule) -> None: """Generate a keypair.""" if self.module_backend.needs_regeneration(): @@ -53,7 +70,7 @@ class PrivateKeyModule: self.privatekey_bytes = privatekey_data self.changed = True - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" result = self.module_backend.dump( include_key=self.changed or self.return_current_key @@ -64,7 +81,7 @@ class PrivateKeyModule: class ActionModule(ActionModuleBase): @staticmethod - def setup_module(): + def setup_module() -> tuple[ArgumentSpec, dict[str, t.Any]]: argument_spec = get_privatekey_argument_spec() argument_spec.argument_spec.update( dict( @@ -78,7 +95,7 @@ class ActionModule(ActionModuleBase): ) @staticmethod - def run_module(module): + def run_module(module: AnsibleActionModule) -> None: module_backend = select_backend(module=module) try: diff --git a/plugins/filter/gpg_fingerprint.py b/plugins/filter/gpg_fingerprint.py index 2bc02d4b..9d7aefcb 100644 --- a/plugins/filter/gpg_fingerprint.py +++ b/plugins/filter/gpg_fingerprint.py @@ -39,6 +39,8 @@ _value: type: string """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import ( @@ -50,7 +52,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import ( ) -def gpg_fingerprint(input): +def gpg_fingerprint(input: str | bytes) -> str: if not isinstance(input, (str, bytes)): raise AnsibleFilterError( f"The input for the community.crypto.gpg_fingerprint filter must be a string; got {type(input)} instead" @@ -65,7 +67,7 @@ def gpg_fingerprint(input): class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "gpg_fingerprint": gpg_fingerprint, } diff --git a/plugins/filter/openssl_csr_info.py b/plugins/filter/openssl_csr_info.py index d0b4b21d..7abc0028 100644 --- a/plugins/filter/openssl_csr_info.py +++ b/plugins/filter/openssl_csr_info.py @@ -274,6 +274,8 @@ _value: sample: 12345 """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -287,7 +289,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp ) -def openssl_csr_info_filter(data, name_encoding="ignore"): +def openssl_csr_info_filter( + data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore" +) -> dict[str, t.Any]: """Extract information from X.509 PEM certificate.""" if not isinstance(data, (str, bytes)): raise AnsibleFilterError( @@ -313,7 +317,7 @@ def openssl_csr_info_filter(data, name_encoding="ignore"): class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "openssl_csr_info": openssl_csr_info_filter, } diff --git a/plugins/filter/openssl_privatekey_info.py b/plugins/filter/openssl_privatekey_info.py index da6d8ce1..f7597c44 100644 --- a/plugins/filter/openssl_privatekey_info.py +++ b/plugins/filter/openssl_privatekey_info.py @@ -146,8 +146,10 @@ _value: type: dict """ +import typing as t + from ansible.errors import AnsibleFilterError -from ansible.module_utils.common.text.converters import to_bytes +from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, ) @@ -161,8 +163,10 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp def openssl_privatekey_info_filter( - data, passphrase=None, return_private_key_data=False -): + data: str | bytes, + passphrase: str | bytes | None = None, + return_private_key_data: bool = False, +) -> dict[str, t.Any]: """Extract information from X.509 PEM certificate.""" if not isinstance(data, (str, bytes)): raise AnsibleFilterError( @@ -182,7 +186,7 @@ def openssl_privatekey_info_filter( result = get_privatekey_info( module, content=to_bytes(data), - passphrase=passphrase, + passphrase=to_text(passphrase) if passphrase is not None else None, return_private_key_data=return_private_key_data, ) result.pop("can_parse_key", None) @@ -197,7 +201,7 @@ def openssl_privatekey_info_filter( class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "openssl_privatekey_info": openssl_privatekey_info_filter, } diff --git a/plugins/filter/openssl_publickey_info.py b/plugins/filter/openssl_publickey_info.py index 037ec7b3..940d96ad 100644 --- a/plugins/filter/openssl_publickey_info.py +++ b/plugins/filter/openssl_publickey_info.py @@ -123,6 +123,8 @@ _value: returned: When RV(_value.type=DSA) or RV(_value.type=ECC) """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -137,7 +139,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp ) -def openssl_publickey_info_filter(data): +def openssl_publickey_info_filter(data: str | bytes) -> dict[str, t.Any]: """Extract information from OpenSSL PEM public key.""" if not isinstance(data, (str, bytes)): raise AnsibleFilterError( @@ -156,7 +158,7 @@ def openssl_publickey_info_filter(data): class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "openssl_publickey_info": openssl_publickey_info_filter, } diff --git a/plugins/filter/parse_serial.py b/plugins/filter/parse_serial.py index 6bfc4f72..80d5afd7 100644 --- a/plugins/filter/parse_serial.py +++ b/plugins/filter/parse_serial.py @@ -39,6 +39,8 @@ _value: type: int """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_native from ansible_collections.community.crypto.plugins.module_utils.serial import ( @@ -46,7 +48,7 @@ from ansible_collections.community.crypto.plugins.module_utils.serial import ( ) -def parse_serial_filter(input): +def parse_serial_filter(input: str | bytes) -> int: if not isinstance(input, (str, bytes)): raise AnsibleFilterError( f"The input for the community.crypto.parse_serial filter must be a string; got {type(input)} instead" @@ -60,7 +62,7 @@ def parse_serial_filter(input): class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "parse_serial": parse_serial_filter, } diff --git a/plugins/filter/split_pem.py b/plugins/filter/split_pem.py index 4112e988..446db6af 100644 --- a/plugins/filter/split_pem.py +++ b/plugins/filter/split_pem.py @@ -38,6 +38,8 @@ _value: elements: string """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_text from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ( @@ -45,21 +47,20 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ) -def split_pem_filter(data): +def split_pem_filter(data: str | bytes) -> list[str]: """Split PEM file.""" if not isinstance(data, (str, bytes)): raise AnsibleFilterError( f"The community.crypto.split_pem input must be a text type, not {type(data)}" ) - data = to_text(data) - return split_pem_list(data) + return split_pem_list(to_text(data)) class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "split_pem": split_pem_filter, } diff --git a/plugins/filter/to_serial.py b/plugins/filter/to_serial.py index 1f5b00fb..4549212d 100644 --- a/plugins/filter/to_serial.py +++ b/plugins/filter/to_serial.py @@ -39,11 +39,13 @@ _value: type: string """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible_collections.community.crypto.plugins.module_utils.serial import to_serial -def to_serial_filter(input): +def to_serial_filter(input: int) -> str: if not isinstance(input, int): raise AnsibleFilterError( f"The input for the community.crypto.to_serial filter must be an integer; got {type(input)} instead" @@ -61,7 +63,7 @@ def to_serial_filter(input): class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "to_serial": to_serial_filter, } diff --git a/plugins/filter/x509_certificate_info.py b/plugins/filter/x509_certificate_info.py index 82a3757d..dbc0a24a 100644 --- a/plugins/filter/x509_certificate_info.py +++ b/plugins/filter/x509_certificate_info.py @@ -308,6 +308,8 @@ _value: type: str """ +import typing as t + from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -321,7 +323,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp ) -def x509_certificate_info_filter(data, name_encoding="ignore"): +def x509_certificate_info_filter( + data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore" +) -> dict[str, t.Any]: """Extract information from X.509 PEM certificate.""" if not isinstance(data, (str, bytes)): raise AnsibleFilterError( @@ -347,7 +351,7 @@ def x509_certificate_info_filter(data, name_encoding="ignore"): class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "x509_certificate_info": x509_certificate_info_filter, } diff --git a/plugins/filter/x509_crl_info.py b/plugins/filter/x509_crl_info.py index 40960d16..fedeee4e 100644 --- a/plugins/filter/x509_crl_info.py +++ b/plugins/filter/x509_crl_info.py @@ -155,6 +155,7 @@ _value: import base64 import binascii +import typing as t from ansible.errors import AnsibleFilterError from ansible.module_utils.common.text.converters import to_bytes, to_native @@ -172,7 +173,11 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp ) -def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates=True): +def x509_crl_info_filter( + data: str | bytes, + name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore", + list_revoked_certificates: bool = True, +) -> dict[str, t.Any]: """Extract information from X.509 PEM certificate.""" if not isinstance(data, (str, bytes)): raise AnsibleFilterError( @@ -192,17 +197,19 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates f'The name_encoding option must be one of the values "ignore", "idna", or "unicode", not "{name_encoding}"' ) - data = to_bytes(data) - if not identify_pem_format(data): + data_bytes = to_bytes(data) + if not identify_pem_format(data_bytes): try: - data = base64.b64decode(to_native(data)) + data_bytes = base64.b64decode(to_native(data_bytes)) except (binascii.Error, TypeError, ValueError, UnicodeEncodeError): pass module = FilterModuleMock({"name_encoding": name_encoding}) try: return get_crl_info( - module, content=data, list_revoked_certificates=list_revoked_certificates + module, + content=data_bytes, + list_revoked_certificates=list_revoked_certificates, ) except OpenSSLObjectError as exc: raise AnsibleFilterError(str(exc)) @@ -211,7 +218,7 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates class FilterModule: """Ansible jinja2 filters""" - def filters(self): + def filters(self) -> dict[str, t.Callable]: return { "x509_crl_info": x509_crl_info_filter, } diff --git a/plugins/lookup/gpg_fingerprint.py b/plugins/lookup/gpg_fingerprint.py index 8393d3f9..8e439d7f 100644 --- a/plugins/lookup/gpg_fingerprint.py +++ b/plugins/lookup/gpg_fingerprint.py @@ -42,7 +42,11 @@ _value: elements: string """ +import os +import typing as t + from ansible.errors import AnsibleLookupError +from ansible.module_utils.common.text.converters import to_native from ansible.plugins.lookup import LookupBase from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import ( GPGError, @@ -54,14 +58,20 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import ( class LookupModule(LookupBase): - def run(self, terms, variables=None, **kwargs): + def run(self, terms: list[t.Any], variables=None, **kwargs) -> list[str]: self.set_options(direct=kwargs) + if self._loader is None: + raise AssertionError("Contract violation: self._loader is None") try: gpg = PluginGPGRunner(cwd=self._loader.get_basedir()) result = [] - for path in terms: - result.append(get_fingerprint_from_file(gpg, path)) + for i, path in enumerate(terms): + if not isinstance(path, (str, bytes, os.PathLike)): + raise AnsibleLookupError( + f"Lookup parameter #{i} should be string or a path object, but got {type(path)}" + ) + result.append(get_fingerprint_from_file(gpg, to_native(path))) return result except GPGError as exc: raise AnsibleLookupError(str(exc)) diff --git a/plugins/module_utils/acme/account.py b/plugins/module_utils/acme/account.py index cc8af4b1..b2830e27 100644 --- a/plugins/module_utils/acme/account.py +++ b/plugins/module_utils/acme/account.py @@ -5,6 +5,8 @@ from __future__ import annotations +import typing as t + from ansible.module_utils.common._collections_compat import Mapping from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( ACMEProtocolException, @@ -12,26 +14,29 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) +if t.TYPE_CHECKING: + from .acme import ACMEClient + + class ACMEAccount: """ ACME account object. Allows to create new accounts, check for existence of accounts, retrieve account data. """ - def __init__(self, client): + def __init__(self, client: ACMEClient) -> None: # Set to true to enable logging of all signed requests - self._debug = False + self._debug: bool = False self.client = client def _new_reg( self, - contact=None, - agreement=None, - terms_agreed=False, - allow_creation=True, - external_account_binding=None, - ): + contact: list[str] | None = None, + terms_agreed: bool = False, + allow_creation: bool = True, + external_account_binding: dict[str, t.Any] | None = None, + ) -> tuple[bool, dict[str, t.Any] | None]: """ Registers a new ACME account. Returns a pair ``(created, data)``. Here, ``created`` is ``True`` if the account was created and @@ -63,7 +68,7 @@ class ACMEAccount: return created, data # An account does not yet exist. Try to create one next. - new_reg = {"contact": contact} + new_reg: dict[str, t.Any] = {"contact": contact} if not allow_creation: # https://tools.ietf.org/html/rfc8555#section-7.3.1 new_reg["onlyReturnExisting"] = True @@ -99,7 +104,7 @@ class ACMEAccount: self.client.module, msg="Invalid account creation reply from ACME server", info=info, - content=result, + content_json=result, ) if info["status"] == 201: @@ -152,7 +157,7 @@ class ACMEAccount: content_json=result, ) - def get_account_data(self): + def get_account_data(self) -> dict[str, t.Any] | None: """ Retrieve account information. Can only be called when the account URI is already known (such as after calling setup_account). @@ -161,7 +166,7 @@ class ACMEAccount: if self.client.account_uri is None: raise ModuleFailException("Account URI unknown") # try POST-as-GET first (draft-15 or newer) - data = None + data: dict[str, t.Any] | None = None result, info = self.client.send_signed_request( self.client.account_uri, data, fail_on_error=False ) @@ -180,7 +185,7 @@ class ACMEAccount: self.client.module, msg="Invalid account data retrieved from ACME server", info=info, - content=result, + content_json=result, ) if ( info["status"] in (400, 403) @@ -203,15 +208,34 @@ class ACMEAccount: ) return result + @t.overload def setup_account( self, - contact=None, - agreement=None, - terms_agreed=False, - allow_creation=True, - remove_account_uri_if_not_exists=False, - external_account_binding=None, - ): + contact: list[str] | None = None, + terms_agreed: bool = False, + allow_creation: t.Literal[True] = True, + remove_account_uri_if_not_exists: bool = False, + external_account_binding: dict[str, t.Any] | None = None, + ) -> tuple[bool, dict[str, t.Any]]: ... + + @t.overload + def setup_account( + self, + contact: list[str] | None = None, + terms_agreed: bool = False, + allow_creation: bool = True, + remove_account_uri_if_not_exists: bool = False, + external_account_binding: dict[str, t.Any] | None = None, + ) -> tuple[bool, dict[str, t.Any] | None]: ... + + def setup_account( + self, + contact: list[str] | None = None, + terms_agreed: bool = False, + allow_creation: bool = True, + remove_account_uri_if_not_exists: bool = False, + external_account_binding: dict[str, t.Any] | None = None, + ) -> tuple[bool, dict[str, t.Any] | None]: """ Detect or create an account on the ACME server. For ACME v1, as the only way (without knowing an account URI) to test if an @@ -253,7 +277,6 @@ class ACMEAccount: else: created, account_data = self._new_reg( contact, - agreement=agreement, terms_agreed=terms_agreed, allow_creation=allow_creation and not self.client.module.check_mode, external_account_binding=external_account_binding, @@ -267,7 +290,9 @@ class ACMEAccount: account_data = {"contact": contact or []} return created, account_data - def update_account(self, account_data, contact=None): + def update_account( + self, account_data: dict[str, t.Any], contact: list[str] | None = None + ) -> tuple[bool, dict[str, t.Any]]: """ Update an account on the ACME server. Check mode is fully respected. @@ -280,8 +305,11 @@ class ACMEAccount: https://tools.ietf.org/html/rfc8555#section-7.3.2 """ + if self.client.account_uri is None: + raise ModuleFailException("Cannot update account without account URI") + # Create request - update_request = {} + update_request: dict[str, t.Any] = {} if contact is not None and account_data.get("contact", []) != contact: update_request["contact"] = list(contact) @@ -302,7 +330,7 @@ class ACMEAccount: self.client.module, msg="Invalid account updating reply from ACME server", info=info, - content=account_data, + content_json=account_data, ) return True, account_data diff --git a/plugins/module_utils/acme/acme.py b/plugins/module_utils/acme/acme.py index b5bdba25..c14cdddb 100644 --- a/plugins/module_utils/acme/acme.py +++ b/plugins/module_utils/acme/acme.py @@ -10,6 +10,7 @@ import datetime import json import locale import time +import typing as t from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.common.text.converters import to_bytes @@ -41,13 +42,24 @@ from ansible_collections.community.crypto.plugins.module_utils.argspec import ( ) +if t.TYPE_CHECKING: + import os + + from ansible.module_utils.basic import AnsibleModule + + from .account import ACMEAccount + from .backends import CertificateInformation, CryptoBackend + + # -1 usually means connection problems RETRY_STATUS_CODES = (-1, 408, 429, 503) RETRY_COUNT = 10 -def _decode_retry(module, response, info, retry_count): +def _decode_retry( + module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int +) -> bool: if info["status"] not in RETRY_STATUS_CODES: return False @@ -61,7 +73,8 @@ def _decode_retry(module, response, info, retry_count): # 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After) try: - retry_after = min(max(1, int(info.get("retry-after"))), 60) + # TODO: use utils.parse_retry_after() + retry_after = min(max(1, int(info.get("retry-after", "10"))), 60) except (TypeError, ValueError): retry_after = 10 module.log( @@ -73,13 +86,13 @@ def _decode_retry(module, response, info, retry_count): def _assert_fetch_url_success( - module, - response, - info, - allow_redirect=False, - allow_client_error=True, - allow_server_error=True, -): + module: AnsibleModule, + response: t.Any, + info: dict[str, t.Any], + allow_redirect: bool = False, + allow_client_error: bool = True, + allow_server_error: bool = True, +) -> None: if info["status"] < 0: raise NetworkException(msg=f"Failure downloading {info['url']}, {info['msg']}") @@ -91,7 +104,9 @@ def _assert_fetch_url_success( raise ACMEProtocolException(module, info=info, response=response) -def _is_failed(info, expected_status_codes=None): +def _is_failed( + info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None +) -> bool: if info["status"] < 200 or info["status"] >= 400: return True if ( @@ -111,12 +126,12 @@ class ACMEDirectory: https://tools.ietf.org/html/rfc8555#section-7.1.1 """ - def __init__(self, module, account): + def __init__(self, module: AnsibleModule, client: ACMEClient) -> None: self.module = module self.directory_root = module.params["acme_directory"] self.version = module.params["acme_version"] - self.directory, dummy = account.get_request(self.directory_root, get_only=True) + self.directory, dummy = client.get_request(self.directory_root, get_only=True) self.request_timeout = module.params["request_timeout"] @@ -131,16 +146,16 @@ class ACMEDirectory: if "meta" not in self.directory: self.directory["meta"] = {} - def __getitem__(self, key): + def __getitem__(self, key: str) -> t.Any: return self.directory[key] - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self.directory - def get(self, key, default_value=None): + def get(self, key: str, default_value: t.Any = None) -> t.Any: return self.directory.get(key, default_value) - def get_nonce(self, resource=None): + def get_nonce(self, resource: str | None = None) -> str: url = self.directory["newNonce"] if resource is not None: url = resource @@ -170,7 +185,7 @@ class ACMEDirectory: ) retry_count += 1 - def has_renewal_info_endpoint(self): + def has_renewal_info_endpoint(self) -> bool: return "renewalInfo" in self.directory @@ -180,7 +195,7 @@ class ACMEClient: ACME server. """ - def __init__(self, module, backend): + def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None: # Set to true to enable logging of all signed requests self._debug = False @@ -221,16 +236,22 @@ class ACMEClient: self.directory = ACMEDirectory(module, self) - def set_account_uri(self, uri): + def set_account_uri(self, uri: str) -> None: """ Set account URI. For ACME v2, it needs to be used to sending signed requests. """ self.account_uri = uri - self.account_jws_header.pop("jwk") - self.account_jws_header["kid"] = self.account_uri + if self.account_jws_header: + self.account_jws_header.pop("jwk", None) + self.account_jws_header["kid"] = self.account_uri - def parse_key(self, key_file=None, key_content=None, passphrase=None): + def parse_key( + self, + key_file: str | os.PathLike | None = None, + key_content: str | None = None, + passphrase: str | None = None, + ) -> dict[str, t.Any]: """ Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. In case of an error, raises KeyParsingError. @@ -239,7 +260,13 @@ class ACMEClient: raise AssertionError("One of key_file and key_content must be specified!") return self.backend.parse_key(key_file, key_content, passphrase=passphrase) - def sign_request(self, protected, payload, key_data, encode_payload=True): + def sign_request( + self, + protected: dict[str, t.Any], + payload: str | dict[str, t.Any] | None, + key_data: dict[str, t.Any], + encode_payload: bool = True, + ) -> dict[str, t.Any]: """ Signs an ACME request. """ @@ -260,7 +287,7 @@ class ACMEClient: return self.backend.sign(payload64, protected64, key_data) - def _log(self, msg, data=None): + def _log(self, msg: str, data: t.Any = None) -> None: """ Write arguments to acme.log when logging is enabled. """ @@ -275,18 +302,49 @@ class ACMEClient: ) ) + @t.overload def send_signed_request( self, - url, - payload, - key_data=None, - jws_header=None, - parse_json_result=True, - encode_payload=True, - fail_on_error=True, - error_msg=None, - expected_status_codes=None, - ): + url: str, + payload: str | dict[str, t.Any] | None, + *, + key_data: dict[str, t.Any] | None = None, + jws_header: dict[str, t.Any] | None = None, + parse_json_result: t.Literal[True] = True, + encode_payload: bool = True, + fail_on_error: bool = True, + error_msg: str | None = None, + expected_status_codes: t.Iterable[int] | None = None, + ) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ... + + @t.overload + def send_signed_request( + self, + url: str, + payload: str | dict[str, t.Any] | None, + *, + key_data: dict[str, t.Any] | None = None, + jws_header: dict[str, t.Any] | None = None, + parse_json_result: t.Literal[False], + encode_payload: bool = True, + fail_on_error: bool = True, + error_msg: str | None = None, + expected_status_codes: t.Iterable[int] | None = None, + ) -> tuple[bytes, dict[str, t.Any]]: ... + + def send_signed_request( + self, + url: str, + payload: str | dict[str, t.Any] | None, + *, + key_data: dict[str, t.Any] | None = None, + jws_header: dict[str, t.Any] | None = None, + parse_json_result: bool = True, + encode_payload: bool = True, + fail_on_error: bool = True, + error_msg: str | None = None, + expected_status_codes: t.Iterable[int] | None = None, + ) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]: """ Sends a JWS signed HTTP POST request to the ACME server and returns the response as dictionary (if parse_json_result is True) or in raw form @@ -297,7 +355,11 @@ class ACMEClient: (https://tools.ietf.org/html/rfc8555#section-6.3) """ key_data = key_data or self.account_key_data + if key_data is None: + raise ModuleFailException("Missing key data") jws_header = jws_header or self.account_jws_header + if jws_header is None: + raise ModuleFailException("Missing JWS header") failed_tries = 0 while True: protected = copy.deepcopy(jws_header) @@ -382,16 +444,43 @@ class ACMEClient: ) return result, info + @t.overload def get_request( self, - uri, - parse_json_result=True, - headers=None, - get_only=False, - fail_on_error=True, - error_msg=None, - expected_status_codes=None, - ): + uri: str, + *, + parse_json_result: t.Literal[True] = True, + headers: dict[str, str] | None = None, + get_only: bool = False, + fail_on_error: bool = True, + error_msg: str | None = None, + expected_status_codes: t.Iterable[int] | None = None, + ) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ... + + @t.overload + def get_request( + self, + uri: str, + *, + parse_json_result: t.Literal[False], + headers: dict[str, str] | None = None, + get_only: bool = False, + fail_on_error: bool = True, + error_msg: str | None = None, + expected_status_codes: t.Iterable[int] | None = None, + ) -> tuple[bytes, dict[str, t.Any]]: ... + + def get_request( + self, + uri: str, + *, + parse_json_result: bool = True, + headers: dict[str, str] | None = None, + get_only: bool = False, + fail_on_error: bool = True, + error_msg: str | None = None, + expected_status_codes: t.Iterable[int] | None = None, + ) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]: """ Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback to GET if server replies with a status code of 405. @@ -436,6 +525,7 @@ class ACMEClient: # Process result parsed_json_result = False + result: dict[str, t.Any] | bytes if parse_json_result: result = {} if content: @@ -445,7 +535,7 @@ class ACMEClient: parsed_json_result = True except ValueError: raise NetworkException( - f"Failed to parse the ACME response: {uri} {content}" + f"Failed to parse the ACME response: {uri} {content!r}" ) else: result = content @@ -460,19 +550,21 @@ class ACMEClient: msg=error_msg, info=info, content=content, - content_json=result if parsed_json_result else None, + content_json=( + t.cast(dict[str, t.Any], result) if parsed_json_result else None + ), ) return result, info def get_renewal_info( self, - cert_id=None, - cert_info=None, - cert_filename=None, - cert_content=None, - include_retry_after=False, - retry_after_relative_with_timezone=True, - ): + cert_id: str | None = None, + cert_info: CertificateInformation | None = None, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + include_retry_after: bool = False, + retry_after_relative_with_timezone: bool = True, + ) -> dict[str, t.Any]: if not self.directory.has_renewal_info_endpoint(): raise ModuleFailException( "The ACME endpoint does not support ACME Renewal Information retrieval" @@ -504,10 +596,10 @@ class ACMEClient: def create_default_argspec( - with_account=True, - require_account_key=True, - with_certificate=False, -): + with_account: bool = True, + require_account_key: bool = True, + with_certificate: bool = False, +) -> ArgumentSpec: """ Provides default argument spec for the options documented in the acme doc fragment. """ @@ -544,7 +636,7 @@ def create_default_argspec( return result -def create_backend(module, needs_acme_v2=True): +def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend: backend = module.params["select_crypto_backend"] # Backend autodetect @@ -552,6 +644,7 @@ def create_backend(module, needs_acme_v2=True): backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl" # Create backend object + module_backend: CryptoBackend if backend == "cryptography": if CRYPTOGRAPHY_ERROR is not None: # Either we could not import cryptography at all, or there was an unexpected error diff --git a/plugins/module_utils/acme/backend_cryptography.py b/plugins/module_utils/acme/backend_cryptography.py index 426cdd78..f9127af9 100644 --- a/plugins/module_utils/acme/backend_cryptography.py +++ b/plugins/module_utils/acme/backend_cryptography.py @@ -9,6 +9,7 @@ import base64 import binascii import os import traceback +import typing as t from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( @@ -75,10 +76,19 @@ else: CRYPTOGRAPHY_MINIMAL_VERSION ) +if t.TYPE_CHECKING: + import datetime + + from ansible.module_utils.basic import AnsibleModule + + from .certificates import CertificateChain, Criterium + class CryptographyChainMatcher(ChainMatcher): @staticmethod - def _parse_key_identifier(key_identifier, name, criterium_idx, module): + def _parse_key_identifier( + key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule + ) -> bytes | None: if key_identifier: try: return binascii.unhexlify(key_identifier.replace(":", "")) @@ -94,11 +104,11 @@ class CryptographyChainMatcher(ChainMatcher): ) return None - def __init__(self, criterium, module): + def __init__(self, criterium: Criterium, module: AnsibleModule) -> None: self.criterium = criterium self.test_certificates = criterium.test_certificates - self.subject = [] - self.issuer = [] + self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = [] + self.issuer: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = [] if criterium.subject: self.subject = [ (cryptography_name_to_oid(k), to_native(v)) @@ -121,8 +131,13 @@ class CryptographyChainMatcher(ChainMatcher): criterium.index, module, ) + self.module = module - def _match_subject(self, x509_subject, match_subject): + def _match_subject( + self, + x509_subject: cryptography.x509.Name, + match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]], + ) -> bool: for oid, value in match_subject: found = False for attribute in x509_subject: @@ -133,7 +148,7 @@ class CryptographyChainMatcher(ChainMatcher): return False return True - def match(self, certificate): + def match(self, certificate: CertificateChain) -> bool: """ Check whether an alternate chain matches the specified criterium. """ @@ -152,19 +167,22 @@ class CryptographyChainMatcher(ChainMatcher): matches = False if self.subject_key_identifier: try: - ext = x509.extensions.get_extension_for_class( + ext_ski = x509.extensions.get_extension_for_class( cryptography.x509.SubjectKeyIdentifier ) - if self.subject_key_identifier != ext.value.digest: + if self.subject_key_identifier != ext_ski.value.digest: matches = False except cryptography.x509.ExtensionNotFound: matches = False if self.authority_key_identifier: try: - ext = x509.extensions.get_extension_for_class( + ext_aki = x509.extensions.get_extension_for_class( cryptography.x509.AuthorityKeyIdentifier ) - if self.authority_key_identifier != ext.value.key_identifier: + if ( + self.authority_key_identifier + != ext_aki.value.key_identifier + ): matches = False except cryptography.x509.ExtensionNotFound: matches = False @@ -176,59 +194,68 @@ class CryptographyChainMatcher(ChainMatcher): class CryptographyBackend(CryptoBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(CryptographyBackend, self).__init__( module, with_timezone=CRYPTOGRAPHY_TIMEZONE ) - def parse_key(self, key_file=None, key_content=None, passphrase=None): + def parse_key( + self, + key_file: str | os.PathLike | None = None, + key_content: str | None = None, + passphrase: str | None = None, + ) -> dict[str, t.Any]: """ Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Raises KeyParsingError in case of errors. """ # If key_content is not given, read key_file if key_content is None: - key_content = read_file(key_file) + if key_file is None: + raise KeyParsingError( + "one of key_file and key_content must be specified" + ) + b_key_content = read_file(key_file) else: - key_content = to_bytes(key_content) + b_key_content = to_bytes(key_content) # Parse key try: key = cryptography.hazmat.primitives.serialization.load_pem_private_key( - key_content, + b_key_content, password=to_bytes(passphrase) if passphrase is not None else None, ) except Exception as e: raise KeyParsingError(f"error while loading key: {e}") if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey): - pk = key.public_key().public_numbers() + rsa_pk = key.public_key().public_numbers() return { "key_obj": key, "type": "rsa", "alg": "RS256", "jwk": { "kty": "RSA", - "e": nopad_b64(convert_int_to_bytes(pk.e)), - "n": nopad_b64(convert_int_to_bytes(pk.n)), + "e": nopad_b64(convert_int_to_bytes(rsa_pk.e)), + "n": nopad_b64(convert_int_to_bytes(rsa_pk.n)), }, "hash": "sha256", } elif isinstance( key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey ): - pk = key.public_key().public_numbers() - if pk.curve.name == "secp256r1": + ec_pk = key.public_key().public_numbers() + if ec_pk.curve.name == "secp256r1": bits = 256 alg = "ES256" hashalg = "sha256" point_size = 32 curve = "P-256" - elif pk.curve.name == "secp384r1": + elif ec_pk.curve.name == "secp384r1": bits = 384 alg = "ES384" hashalg = "sha384" point_size = 48 curve = "P-384" - elif pk.curve.name == "secp521r1": + elif ec_pk.curve.name == "secp521r1": # Not yet supported on Let's Encrypt side, see # https://github.com/letsencrypt/boulder/issues/2217 bits = 521 @@ -237,7 +264,7 @@ class CryptographyBackend(CryptoBackend): point_size = 66 curve = "P-521" else: - raise KeyParsingError(f"unknown elliptic curve: {pk.curve.name}") + raise KeyParsingError(f"unknown elliptic curve: {ec_pk.curve.name}") num_bytes = (bits + 7) // 8 return { "key_obj": key, @@ -246,8 +273,8 @@ class CryptographyBackend(CryptoBackend): "jwk": { "kty": "EC", "crv": curve, - "x": nopad_b64(convert_int_to_bytes(pk.x, count=num_bytes)), - "y": nopad_b64(convert_int_to_bytes(pk.y, count=num_bytes)), + "x": nopad_b64(convert_int_to_bytes(ec_pk.x, count=num_bytes)), + "y": nopad_b64(convert_int_to_bytes(ec_pk.y, count=num_bytes)), }, "hash": hashalg, "point_size": point_size, @@ -255,8 +282,11 @@ class CryptographyBackend(CryptoBackend): else: raise KeyParsingError(f'unknown key type "{type(key)}"') - def sign(self, payload64, protected64, key_data): + def sign( + self, payload64: str, protected64: str, key_data: dict[str, t.Any] + ) -> dict[str, t.Any]: sign_payload = f"{protected64}.{payload64}".encode("utf8") + hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm] if "mac_obj" in key_data: mac = key_data["mac_obj"]() mac.update(sign_payload) @@ -292,8 +322,9 @@ class CryptographyBackend(CryptoBackend): "signature": nopad_b64(signature), } - def create_mac_key(self, alg, key): + def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]: """Create a MAC key.""" + hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm] if alg == "HS256": hashalg = cryptography.hazmat.primitives.hashes.SHA256 hashbytes = 32 @@ -324,7 +355,11 @@ class CryptographyBackend(CryptoBackend): }, } - def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_ordered_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> list[tuple[str, str]]: """ Return a list of requested identifiers (CN and SANs) for the CSR. Each identifier is a pair (type, identifier), where type is either @@ -334,15 +369,19 @@ class CryptographyBackend(CryptoBackend): as the first element in the result. """ if csr_content is None: - csr_content = read_file(csr_filename) + if csr_filename is None: + raise BackendException( + "One of csr_content and csr_filename has to be provided" + ) + b_csr_content = read_file(csr_filename) else: - csr_content = to_bytes(csr_content) - csr = cryptography.x509.load_pem_x509_csr(csr_content) + b_csr_content = to_bytes(csr_content) + csr = cryptography.x509.load_pem_x509_csr(b_csr_content) identifiers = set() result = [] - def add_identifier(identifier): + def add_identifier(identifier: tuple[str, str]) -> None: if identifier in identifiers: return identifiers.add(identifier) @@ -350,7 +389,7 @@ class CryptographyBackend(CryptoBackend): for sub in csr.subject: if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME: - add_identifier(("dns", sub.value)) + add_identifier(("dns", t.cast(str, sub.value))) for extension in csr.extensions: if ( extension.oid @@ -367,7 +406,11 @@ class CryptographyBackend(CryptoBackend): ) return result - def get_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | bytes | None = None, + ) -> set[tuple[str, str]]: """ Return a set of requested identifiers (CN and SANs) for the CSR. Each identifier is a pair (type, identifier), where type is either @@ -379,7 +422,12 @@ class CryptographyBackend(CryptoBackend): ) ) - def get_cert_days(self, cert_filename=None, cert_content=None, now=None): + def get_cert_days( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + now: datetime.datetime | None = None, + ) -> int: """ Return the days the certificate in cert_filename remains valid and -1 if the file was not found. If cert_filename contains more than one @@ -398,10 +446,10 @@ class CryptographyBackend(CryptoBackend): return -1 # Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf. - cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "") + b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "") try: - cert = cryptography.x509.load_pem_x509_certificate(cert_content) + cert = cryptography.x509.load_pem_x509_certificate(b_cert_content) except Exception as e: if cert_filename is None: raise BackendException(f"Cannot parse certificate: {e}") @@ -413,13 +461,17 @@ class CryptographyBackend(CryptoBackend): now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE) return (get_not_valid_after(cert) - now).days - def create_chain_matcher(self, criterium): + def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher: """ Given a Criterium object, creates a ChainMatcher object. """ return CryptographyChainMatcher(criterium, self.module) - def get_cert_information(self, cert_filename=None, cert_content=None): + def get_cert_information( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + ) -> CertificateInformation: """ Return some information on a X.509 certificate as a CertificateInformation object. """ @@ -429,10 +481,10 @@ class CryptographyBackend(CryptoBackend): cert_content = to_bytes(cert_content) # Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf. - cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "") + b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "") try: - cert = cryptography.x509.load_pem_x509_certificate(cert_content) + cert = cryptography.x509.load_pem_x509_certificate(b_cert_content) except Exception as e: if cert_filename is None: raise BackendException(f"Cannot parse certificate: {e}") @@ -440,19 +492,19 @@ class CryptographyBackend(CryptoBackend): ski = None try: - ext = cert.extensions.get_extension_for_class( + ext_ski = cert.extensions.get_extension_for_class( cryptography.x509.SubjectKeyIdentifier ) - ski = ext.value.digest + ski = ext_ski.value.digest except cryptography.x509.ExtensionNotFound: pass aki = None try: - ext = cert.extensions.get_extension_for_class( + ext_aki = cert.extensions.get_extension_for_class( cryptography.x509.AuthorityKeyIdentifier ) - aki = ext.value.key_identifier + aki = ext_aki.value.key_identifier except cryptography.x509.ExtensionNotFound: pass diff --git a/plugins/module_utils/acme/backend_openssl_cli.py b/plugins/module_utils/acme/backend_openssl_cli.py index 3efcffee..139f8afd 100644 --- a/plugins/module_utils/acme/backend_openssl_cli.py +++ b/plugins/module_utils/acme/backend_openssl_cli.py @@ -13,6 +13,7 @@ import os import re import tempfile import traceback +import typing as t from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( @@ -34,12 +35,23 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + from .certificates import Criterium + + _OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C") -def _extract_date(out_text, name, cert_filename_suffix=""): +def _extract_date( + out_text: str, name: str, cert_filename_suffix: str = "" +) -> datetime.datetime: + matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text) + if matcher is None: + raise BackendException(f"No '{name}' date found{cert_filename_suffix}") + date_str = matcher.group(1) try: - date_str = re.search(rf"\s+{name}\s*:\s+(.*)", out_text).group(1) # For some reason Python's strptime() does not return any timezone information, # even though the information is there and a supported timezone for all supported # Python implementations (GMT). So we have to modify the datetime object by @@ -47,19 +59,40 @@ def _extract_date(out_text, name, cert_filename_suffix=""): return ensure_utc_timezone( datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z") ) - except AttributeError: - raise BackendException(f"No '{name}' date found{cert_filename_suffix}") except ValueError as exc: raise BackendException( f"Failed to parse '{name}' date{cert_filename_suffix}: {exc}" ) -def _decode_octets(octets_text): +def _decode_octets(octets_text: str) -> bytes: return binascii.unhexlify(re.sub(r"(\s|:)", "", octets_text).encode("utf-8")) -def _extract_octets(out_text, name, required=True, potential_prefixes=None): +@t.overload +def _extract_octets( + out_text: str, + name: str, + required: t.Literal[False], + potential_prefixes: t.Iterable[str] | None = None, +) -> bytes | None: ... + + +@t.overload +def _extract_octets( + out_text: str, + name: str, + required: t.Literal[True], + potential_prefixes: t.Iterable[str] | None = None, +) -> bytes: ... + + +def _extract_octets( + out_text: str, + name: str, + required: bool = True, + potential_prefixes: t.Iterable[str] | None = None, +) -> bytes | None: part = ( f"(?:{'|'.join(re.escape(pp) for pp in potential_prefixes)})" if potential_prefixes @@ -75,13 +108,20 @@ def _extract_octets(out_text, name, required=True, potential_prefixes=None): class OpenSSLCLIBackend(CryptoBackend): - def __init__(self, module, openssl_binary=None): + def __init__( + self, module: AnsibleModule, openssl_binary: str | None = None + ) -> None: super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True) if openssl_binary is None: openssl_binary = module.get_bin_path("openssl", True) self.openssl_binary = openssl_binary - def parse_key(self, key_file=None, key_content=None, passphrase=None): + def parse_key( + self, + key_file: str | os.PathLike | None = None, + key_content: str | None = None, + passphrase: str | None = None, + ) -> dict[str, t.Any]: """ Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Raises KeyParsingError in case of errors. @@ -90,6 +130,10 @@ class OpenSSLCLIBackend(CryptoBackend): raise KeyParsingError("openssl backend does not support key passphrases") # If key_file is not given, but key_content, write that to a temporary file if key_file is None: + if key_content is None: + raise KeyParsingError( + "one of key_file and key_content must be specified" + ) fd, tmpsrc = tempfile.mkstemp() self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit f = os.fdopen(fd, "wb") @@ -108,8 +152,8 @@ class OpenSSLCLIBackend(CryptoBackend): f.close() # Parse key account_key_type = None - with open(key_file, "rt") as f: - for line in f: + with open(key_file, "rt") as fi: + for line in fi: m = re.match( r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line ) @@ -129,38 +173,44 @@ class OpenSSLCLIBackend(CryptoBackend): self.openssl_binary, account_key_type, "-in", - key_file, + str(key_file), "-noout", "-text", ] - rc, out, err = self.module.run_command( + rc, out, stderr = self.module.run_command( openssl_keydump_cmd, check_rc=False, environ_update=_OPENSSL_ENVIRONMENT_UPDATE, ) if rc != 0: raise BackendException( - f"Error while running {' '.join(openssl_keydump_cmd)}: {err}" + f"Error while running {' '.join(openssl_keydump_cmd)}: {stderr}" ) out_text = to_text(out, errors="surrogate_or_strict") if account_key_type == "rsa": - pub_hex = re.search( + matcher = re.search( r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent", out_text, re.MULTILINE | re.DOTALL, - ).group(1) + ) + if matcher is None: + raise KeyParsingError("cannot parse RSA key: modulus not found") + pub_hex = matcher.group(1) - pub_exp = re.search( + matcher = re.search( r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL - ).group(1) + ) + if matcher is None: + raise KeyParsingError("cannot parse RSA key: public exponent not found") + pub_exp = matcher.group(1) pub_exp = f"{int(pub_exp):x}" if len(pub_exp) % 2: pub_exp = f"0{pub_exp}" return { - "key_file": key_file, + "key_file": str(key_file), "type": "rsa", "alg": "RS256", "jwk": { @@ -223,8 +273,13 @@ class OpenSSLCLIBackend(CryptoBackend): "hash": hashalg, "point_size": point_size, } + raise KeyParsingError( + f"Internal error: unexpected account_key_type = {account_key_type!r}" + ) - def sign(self, payload64, protected64, key_data): + def sign( + self, payload64: str, protected64: str, key_data: dict[str, t.Any] + ) -> dict[str, t.Any]: sign_payload = f"{protected64}.{payload64}".encode("utf8") if key_data["type"] == "hmac": hex_key = ( @@ -284,7 +339,7 @@ class OpenSSLCLIBackend(CryptoBackend): "signature": nopad_b64(to_bytes(out)), } - def create_mac_key(self, alg, key): + def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]: """Create a MAC key.""" if alg == "HS256": hashalg = "sha256" @@ -315,14 +370,18 @@ class OpenSSLCLIBackend(CryptoBackend): } @staticmethod - def _normalize_ip(ip): + def _normalize_ip(ip: str) -> str: try: - return ipaddress.ip_address(to_text(ip)).compressed + return ipaddress.ip_address(ip).compressed except ValueError: # We do not want to error out on something IPAddress() cannot parse return ip - def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_ordered_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> list[tuple[str, str]]: """ Return a list of requested identifiers (CN and SANs) for the CSR. Each identifier is a pair (type, identifier), where type is either @@ -335,13 +394,13 @@ class OpenSSLCLIBackend(CryptoBackend): data = None if csr_content is not None: filename = "/dev/stdin" - data = csr_content.encode("utf-8") + data = to_bytes(csr_content) openssl_csr_cmd = [ self.openssl_binary, "req", "-in", - filename, + str(filename), "-noout", "-text", ] @@ -360,7 +419,7 @@ class OpenSSLCLIBackend(CryptoBackend): identifiers = set() result = [] - def add_identifier(identifier): + def add_identifier(identifier: tuple[str, str]) -> None: if identifier in identifiers: return identifiers.add(identifier) @@ -389,7 +448,11 @@ class OpenSSLCLIBackend(CryptoBackend): raise BackendException(f'Found unsupported SAN identifier "{san}"') return result - def get_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> set[tuple[str, str]]: """ Return a set of requested identifiers (CN and SANs) for the CSR. Each identifier is a pair (type, identifier), where type is either @@ -401,7 +464,12 @@ class OpenSSLCLIBackend(CryptoBackend): ) ) - def get_cert_days(self, cert_filename=None, cert_content=None, now=None): + def get_cert_days( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + now: datetime.datetime | None = None, + ) -> int: """ Return the days the certificate in cert_filename remains valid and -1 if the file was not found. If cert_filename contains more than one @@ -413,7 +481,7 @@ class OpenSSLCLIBackend(CryptoBackend): data = None if cert_content is not None: filename = "/dev/stdin" - data = cert_content.encode("utf-8") + data = to_bytes(cert_content) cert_filename_suffix = "" elif cert_filename is not None: if not os.path.exists(cert_filename): @@ -426,7 +494,7 @@ class OpenSSLCLIBackend(CryptoBackend): self.openssl_binary, "x509", "-in", - filename, + str(filename), "-noout", "-text", ] @@ -452,7 +520,7 @@ class OpenSSLCLIBackend(CryptoBackend): now = ensure_utc_timezone(now) return (not_after - now).days - def create_chain_matcher(self, criterium): + def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn: """ Given a Criterium object, creates a ChainMatcher object. """ @@ -460,7 +528,11 @@ class OpenSSLCLIBackend(CryptoBackend): 'Alternate chain matching can only be used with the "cryptography" backend.' ) - def get_cert_information(self, cert_filename=None, cert_content=None): + def get_cert_information( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + ) -> CertificateInformation: """ Return some information on a X.509 certificate as a CertificateInformation object. """ @@ -477,7 +549,7 @@ class OpenSSLCLIBackend(CryptoBackend): self.openssl_binary, "x509", "-in", - filename, + str(filename), "-noout", "-text", ] diff --git a/plugins/module_utils/acme/backends.py b/plugins/module_utils/acme/backends.py index 395662a5..2876c74f 100644 --- a/plugins/module_utils/acme/backends.py +++ b/plugins/module_utils/acme/backends.py @@ -8,7 +8,7 @@ from __future__ import annotations import abc import datetime import re -from collections import namedtuple +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( BackendException, @@ -27,16 +27,20 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) -CertificateInformation = namedtuple( - "CertificateInformation", - ( - "not_valid_after", - "not_valid_before", - "serial_number", - "subject_key_identifier", - "authority_key_identifier", - ), -) +if t.TYPE_CHECKING: + import os + + from ansible.module_utils.basic import AnsibleModule + + from .certificates import ChainMatcher, Criterium + + +class CertificateInformation(t.NamedTuple): + not_valid_after: datetime.datetime + not_valid_before: datetime.datetime + serial_number: int + subject_key_identifier: bytes | None + authority_key_identifier: bytes | None _FRACTIONAL_MATCHER = re.compile( @@ -44,7 +48,7 @@ _FRACTIONAL_MATCHER = re.compile( ) -def _reduce_fractional_digits(timestamp_str): +def _reduce_fractional_digits(timestamp_str: str) -> str: """ Given a RFC 3339 timestamp that includes too many digits for the fractional seconds part, reduces these to at most 6. """ @@ -60,7 +64,7 @@ def _reduce_fractional_digits(timestamp_str): return f"{timestamp}{fractional}{timezone}" -def _parse_acme_timestamp(timestamp_str, with_timezone): +def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime: """ Parses a RFC 3339 timestamp. """ @@ -86,34 +90,42 @@ def _parse_acme_timestamp(timestamp_str, with_timezone): class CryptoBackend(metaclass=abc.ABCMeta): - def __init__(self, module, with_timezone=False): + def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None: self.module = module self._with_timezone = with_timezone - def get_now(self): + def get_now(self) -> datetime.datetime: return get_now_datetime(with_timezone=self._with_timezone) - def parse_acme_timestamp(self, timestamp_str): + def parse_acme_timestamp(self, timestamp_str: str) -> datetime.datetime: # RFC 3339 (https://www.rfc-editor.org/info/rfc3339) return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone) - def parse_module_parameter(self, value, name): + def parse_module_parameter(self, value: str, name: str) -> datetime.datetime: try: - return get_relative_time_option( + result = get_relative_time_option( value, name, with_timezone=self._with_timezone ) + if result is None: + raise BackendException(f"Invalid value for {name}: {value!r}") + return result except OpenSSLObjectError as exc: raise BackendException(str(exc)) - def interpolate_timestamp(self, timestamp_start, timestamp_end, percentage): + def interpolate_timestamp( + self, + timestamp_start: datetime.datetime, + timestamp_end: datetime.datetime, + percentage: float, + ) -> datetime.datetime: start = get_epoch_seconds(timestamp_start) end = get_epoch_seconds(timestamp_end) return from_epoch_seconds( start + percentage * (end - start), with_timezone=self._with_timezone ) - def get_utc_datetime(self, *args, **kwargs): - kwargs_ext = dict(kwargs) + def get_utc_datetime(self, *args, **kwargs) -> datetime.datetime: + kwargs_ext: dict[str, t.Any] = dict(kwargs) if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8): kwargs_ext["tzinfo"] = UTC result = datetime.datetime(*args, **kwargs_ext) @@ -122,22 +134,33 @@ class CryptoBackend(metaclass=abc.ABCMeta): return result @abc.abstractmethod - def parse_key(self, key_file=None, key_content=None, passphrase=None): + def parse_key( + self, + key_file: str | os.PathLike | None = None, + key_content: str | None = None, + passphrase: str | None = None, + ) -> dict[str, t.Any]: """ Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Raises KeyParsingError in case of errors. """ @abc.abstractmethod - def sign(self, payload64, protected64, key_data): + def sign( + self, payload64: str, protected64: str, key_data: dict[str, t.Any] + ) -> dict[str, t.Any]: pass @abc.abstractmethod - def create_mac_key(self, alg, key): + def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]: """Create a MAC key.""" @abc.abstractmethod - def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_ordered_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> list[tuple[str, str]]: """ Return a list of requested identifiers (CN and SANs) for the CSR. Each identifier is a pair (type, identifier), where type is either @@ -148,7 +171,11 @@ class CryptoBackend(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> set[tuple[str, str]]: """ Return a set of requested identifiers (CN and SANs) for the CSR. Each identifier is a pair (type, identifier), where type is either @@ -156,7 +183,12 @@ class CryptoBackend(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_cert_days(self, cert_filename=None, cert_content=None, now=None): + def get_cert_days( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + now: datetime.datetime | None = None, + ) -> int: """ Return the days the certificate in cert_filename remains valid and -1 if the file was not found. If cert_filename contains more than one @@ -166,13 +198,17 @@ class CryptoBackend(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def create_chain_matcher(self, criterium): + def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher: """ Given a Criterium object, creates a ChainMatcher object. """ @abc.abstractmethod - def get_cert_information(self, cert_filename=None, cert_content=None): + def get_cert_information( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + ) -> CertificateInformation: """ Return some information on a X.509 certificate as a CertificateInformation object. """ diff --git a/plugins/module_utils/acme/certificate.py b/plugins/module_utils/acme/certificate.py index 894b53b6..0e0075b8 100644 --- a/plugins/module_utils/acme/certificate.py +++ b/plugins/module_utils/acme/certificate.py @@ -5,6 +5,7 @@ from __future__ import annotations import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( ACMEAccount, @@ -30,6 +31,14 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + from .backends import CryptoBackend + from .certificates import ChainMatcher + from .challenges import Challenge + + class ACMECertificateClient: """ ACME v2 client class. Uses an ACME account object and a CSR to @@ -37,7 +46,13 @@ class ACMECertificateClient: certificates. """ - def __init__(self, module, backend, client=None, account=None): + def __init__( + self, + module: AnsibleModule, + backend: CryptoBackend, + client: ACMEClient | None = None, + account: ACMEAccount | None = None, + ) -> None: self.module = module self.version = module.params["acme_version"] self.csr = module.params.get("csr") @@ -66,13 +81,17 @@ class ACMECertificateClient: # Extract list of identifiers from CSR if self.csr is not None or self.csr_content is not None: - self.identifiers = self.client.backend.get_ordered_csr_identifiers( - csr_filename=self.csr, csr_content=self.csr_content + self.identifiers: list[tuple[str, str]] | None = ( + self.client.backend.get_ordered_csr_identifiers( + csr_filename=self.csr, csr_content=self.csr_content + ) ) else: self.identifiers = None - def parse_select_chain(self, select_chain): + def parse_select_chain( + self, select_chain: list[dict[str, t.Any]] | None + ) -> list[ChainMatcher]: select_chain_matcher = [] if select_chain: for criterium_idx, criterium in enumerate(select_chain): @@ -88,14 +107,16 @@ class ACMECertificateClient: ) return select_chain_matcher - def load_order(self): + def load_order(self) -> Order: if not self.order_uri: raise ModuleFailException("The order URI has not been provided") order = Order.from_url(self.client, self.order_uri) order.load_authorizations(self.client) return order - def create_order(self, replaces_cert_id=None, profile=None): + def create_order( + self, replaces_cert_id: str | None = None, profile: str | None = None + ) -> Order: """ Create a new order. """ @@ -114,31 +135,31 @@ class ACMECertificateClient: order.load_authorizations(self.client) return order - def get_challenges_data(self, order): + def get_challenges_data( + self, order: Order + ) -> tuple[list[dict[str, t.Any]], dict[str, list[str]]]: """ Get challenge details. Return a tuple of generic challenge details, and specialized DNS challenge details. """ - # Get general challenge data - data = [] + data: list[dict[str, t.Any]] = [] + data_dns: dict[str, list[str]] = {} + dns_challenge_type = "dns-01" for authz in order.authorizations.values(): # Skip valid authentications: their challenges are already valid # and do not need to be returned if authz.status == "valid": continue + challenge_data = authz.get_challenge_data(self.client) data.append( dict( identifier=authz.identifier, identifier_type=authz.identifier_type, - challenges=authz.get_challenge_data(self.client), + challenges=challenge_data, ) ) - # Get DNS challenge data - data_dns = {} - dns_challenge_type = "dns-01" - for entry in data: - dns_challenge = entry["challenges"].get(dns_challenge_type) + dns_challenge = challenge_data.get(dns_challenge_type) if dns_challenge: values = data_dns.get(dns_challenge["record"]) if values is None: @@ -147,7 +168,7 @@ class ACMECertificateClient: values.append(dns_challenge["resource_value"]) return data, data_dns - def check_that_authorizations_can_be_used(self, order): + def check_that_authorizations_can_be_used(self, order: Order) -> None: bad_authzs = [] for authz in order.authorizations.values(): if authz.status not in ("valid", "pending"): @@ -155,27 +176,32 @@ class ACMECertificateClient: f"{authz.combined_identifier} (status={authz.status!r})" ) if bad_authzs: - bad_authzs = ", ".join(sorted(bad_authzs)) + bad_authzs_str = ", ".join(sorted(bad_authzs)) raise ModuleFailException( "Some of the authorizations for the order are in a bad state, so the order" - f" can no longer be satisfied: {bad_authzs}", + f" can no longer be satisfied: {bad_authzs_str}", ) - def collect_invalid_authzs(self, order): + def collect_invalid_authzs(self, order: Order) -> list[Authorization]: return [ authz for authz in order.authorizations.values() if authz.status == "invalid" ] - def collect_pending_authzs(self, order): + def collect_pending_authzs(self, order: Order) -> list[Authorization]: return [ authz for authz in order.authorizations.values() if authz.status == "pending" ] - def call_validate(self, pending_authzs, get_challenge, wait=True): + def call_validate( + self, + pending_authzs: list[Authorization], + get_challenge: t.Callable[[Authorization], str], + wait: bool = True, + ) -> list[tuple[Authorization, str, Challenge | None]]: authzs_with_challenges_to_wait_for = [] for authz in pending_authzs: challenge_type = get_challenge(authz) @@ -185,10 +211,12 @@ class ACMECertificateClient: ) return authzs_with_challenges_to_wait_for - def wait_for_validation(self, authzs_to_wait_for): + def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None: wait_for_validation(authzs_to_wait_for, self.client) - def _download_alternate_chains(self, cert): + def _download_alternate_chains( + self, cert: CertificateChain + ) -> list[CertificateChain]: alternate_chains = [] for alternate in cert.alternates: try: @@ -206,13 +234,30 @@ class ACMECertificateClient: ) return alternate_chains - def download_certificate(self, order, download_all_chains=True): + @t.overload + def download_certificate( + self, order: Order, *, download_all_chains: t.Literal[True] = True + ) -> tuple[CertificateChain, list[CertificateChain]]: ... + + @t.overload + def download_certificate( + self, order: Order, *, download_all_chains: t.Literal[False] + ) -> tuple[CertificateChain, None]: ... + + @t.overload + def download_certificate( + self, order: Order, *, download_all_chains: bool = True + ) -> tuple[CertificateChain, list[CertificateChain] | None]: ... + + def download_certificate( + self, order: Order, *, download_all_chains: bool = True + ) -> tuple[CertificateChain, list[CertificateChain] | None]: """ Download certificate from a valid oder. """ if order.status != "valid": raise ModuleFailException( - f"The order must be valid, but has state {order.state!r}!" + f"The order must be valid, but has state {order.status!r}!" ) if not order.certificate_uri: @@ -232,7 +277,24 @@ class ACMECertificateClient: return cert, alternate_chains - def get_certificate(self, order, download_all_chains=True): + @t.overload + def get_certificate( + self, order: Order, *, download_all_chains: t.Literal[True] = True + ) -> tuple[CertificateChain, list[CertificateChain] | None]: ... + + @t.overload + def get_certificate( + self, order: Order, *, download_all_chains: t.Literal[False] + ) -> tuple[CertificateChain, list[CertificateChain] | None]: ... + + @t.overload + def get_certificate( + self, order: Order, *, download_all_chains: bool = True + ) -> tuple[CertificateChain, list[CertificateChain] | None]: ... + + def get_certificate( + self, order: Order, *, download_all_chains: bool = True + ) -> tuple[CertificateChain, list[CertificateChain] | None]: """ Request a new certificate and downloads it, and optionally all certificate chains. First verifies whether all authorizations are valid; if not, aborts with an error. @@ -250,7 +312,11 @@ class ACMECertificateClient: return self.download_certificate(order, download_all_chains=download_all_chains) - def find_matching_chain(self, chains, select_chain_matcher): + def find_matching_chain( + self, + chains: list[CertificateChain], + select_chain_matcher: t.Iterable[ChainMatcher], + ) -> CertificateChain | None: for criterium_idx, matcher in enumerate(select_chain_matcher): for chain in chains: if matcher.match(chain): @@ -261,9 +327,15 @@ class ACMECertificateClient: return None def write_cert_chain( - self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None - ): + self, + cert: CertificateChain, + cert_dest: str | os.PathLike | None = None, + fullchain_dest: str | os.PathLike | None = None, + chain_dest: str | os.PathLike | None = None, + ) -> bool: changed = False + if cert.cert is None: + raise ValueError("Certificate is not present") if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")): changed = True @@ -282,7 +354,7 @@ class ACMECertificateClient: return changed - def deactivate_authzs(self, order): + def deactivate_authzs(self, order: Order) -> None: """ Deactivates all valid authz's. Does not raise exceptions. https://community.letsencrypt.org/t/authorization-deactivation/19860/2 diff --git a/plugins/module_utils/acme/certificates.py b/plugins/module_utils/acme/certificates.py index 600db93c..204569e8 100644 --- a/plugins/module_utils/acme/certificates.py +++ b/plugins/module_utils/acme/certificates.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( ModuleFailException, @@ -19,20 +20,29 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ) +if t.TYPE_CHECKING: + from .acme import ACMEClient + + +_CertificateChain = t.TypeVar("_CertificateChain", bound="CertificateChain") + + class CertificateChain: """ Download and parse the certificate chain. https://tools.ietf.org/html/rfc8555#section-7.4.2 """ - def __init__(self, url): + def __init__(self, url: str): self.url = url - self.cert = None - self.chain = [] - self.alternates = [] + self.cert: str | None = None + self.chain: list[str] = [] + self.alternates: list[str] = [] @classmethod - def download(cls, client, url): + def download( + cls: t.Type[_CertificateChain], client: ACMEClient, url: str + ) -> _CertificateChain: content, info = client.get_request( url, parse_json_result=False, @@ -43,7 +53,7 @@ class CertificateChain: "application/pem-certificate-chain" ): raise ModuleFailException( - f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content} (headers: {info})" + f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content!r} (headers: {info})" ) result = cls(url) @@ -60,12 +70,12 @@ class CertificateChain: if result.cert is None: raise ModuleFailException( - f"Failed to parse certificate chain download from {url}: {content} (headers: {info})" + f"Failed to parse certificate chain download from {url}: {content!r} (headers: {info})" ) return result - def _process_links(self, client, link, relation): + def _process_links(self, client: ACMEClient, link: str, relation: str) -> None: if relation == "up": # Process link-up headers if there was no chain in reply if not self.chain: @@ -77,7 +87,9 @@ class CertificateChain: elif relation == "alternate": self.alternates.append(link) - def to_json(self): + def to_json(self) -> dict[str, bytes]: + if self.cert is None: + raise ValueError("Has no certificate") cert = self.cert.encode("utf8") chain = ("\n".join(self.chain)).encode("utf8") return { @@ -88,18 +100,22 @@ class CertificateChain: class Criterium: - def __init__(self, criterium, index=None): + def __init__(self, criterium: dict[str, t.Any], index: int): self.index = index - self.test_certificates = criterium["test_certificates"] - self.subject = criterium["subject"] - self.issuer = criterium["issuer"] - self.subject_key_identifier = criterium["subject_key_identifier"] - self.authority_key_identifier = criterium["authority_key_identifier"] + self.test_certificates: t.Literal["first", "last", "all"] = criterium[ + "test_certificates" + ] + self.subject: dict[str, t.Any] | None = criterium["subject"] + self.issuer: dict[str, t.Any] | None = criterium["issuer"] + self.subject_key_identifier: str | None = criterium["subject_key_identifier"] + self.authority_key_identifier: str | None = criterium[ + "authority_key_identifier" + ] class ChainMatcher(metaclass=abc.ABCMeta): @abc.abstractmethod - def match(self, certificate): + def match(self, certificate: CertificateChain) -> bool: """ Check whether a certificate chain (CertificateChain instance) matches. """ diff --git a/plugins/module_utils/acme/challenges.py b/plugins/module_utils/acme/challenges.py index 35ff5d84..114f4681 100644 --- a/plugins/module_utils/acme/challenges.py +++ b/plugins/module_utils/acme/challenges.py @@ -11,6 +11,7 @@ import ipaddress import json import re import time +import typing as t from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( @@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ) -def create_key_authorization(client, token): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + from .acme import ACMEClient + + +def create_key_authorization(client: ACMEClient, token: str) -> str: """ Returns the key authorization for the given token https://tools.ietf.org/html/rfc8555#section-8.1 @@ -35,41 +42,49 @@ def create_key_authorization(client, token): return f"{token}.{thumbprint}" -def combine_identifier(identifier_type, identifier): +def combine_identifier(identifier_type: str, identifier: str) -> str: return f"{identifier_type}:{identifier}" -def normalize_combined_identifier(identifier): +def normalize_combined_identifier(identifier: str) -> str: identifier_type, identifier = split_identifier(identifier) # Normalize DNS names and IPs identifier = identifier.lower() return combine_identifier(identifier_type, identifier) -def split_identifier(identifier): +def split_identifier(identifier: str) -> tuple[str, str]: parts = identifier.split(":", 1) if len(parts) != 2: raise ModuleFailException( f'Identifier "{identifier}" is not of the form :' ) - return parts + return parts[0], parts[1] + + +_Challenge = t.TypeVar("_Challenge", bound="Challenge") class Challenge: - def __init__(self, data, url): + def __init__(self, data: dict[str, t.Any], url: str) -> None: self.data = data - self.type = data["type"] + self.type: str = data["type"] self.url = url - self.status = data["status"] - self.token = data.get("token") + self.status: str = data["status"] + self.token: str | None = data.get("token") @classmethod - def from_json(cls, client, data, url=None): + def from_json( + cls: t.Type[_Challenge], + client: ACMEClient, + data: dict[str, t.Any], + url: str | None = None, + ) -> _Challenge: return cls(data, url or data["url"]) - def call_validate(self, client): - challenge_response = {} + def call_validate(self, client: ACMEClient) -> None: + challenge_response: dict[str, t.Any] = {} client.send_signed_request( self.url, challenge_response, @@ -77,10 +92,15 @@ class Challenge: expected_status_codes=[200, 202], ) - def to_json(self): + def to_json(self) -> dict[str, t.Any]: return self.data.copy() - def get_validation_data(self, client, identifier_type, identifier): + def get_validation_data( + self, client: ACMEClient, identifier_type: str, identifier: str + ) -> dict[str, t.Any] | None: + if self.token is None: + return None + token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token) key_authorization = create_key_authorization(client, token) @@ -113,21 +133,33 @@ class Challenge: resource += "." else: resource = identifier - value = base64.b64encode( + b_value = base64.b64encode( hashlib.sha256(to_bytes(key_authorization)).digest() ) return { "resource": resource, "resource_original": combine_identifier(identifier_type, identifier), - "resource_value": value, + "resource_value": b_value, } # Unknown challenge type: ignore return None +_Authorization = t.TypeVar("_Authorization", bound="Authorization") + + class Authorization: - def _setup(self, client, data): + def __init__(self, url: str) -> None: + self.url = url + + self.data: dict[str, t.Any] | None = None + self.challenges: list[Challenge] = [] + self.status: str | None = None + self.identifier_type: str | None = None + self.identifier: str | None = None + + def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None: data["uri"] = self.url self.data = data # While 'challenges' is a required field, apparently not every CA cares @@ -145,29 +177,32 @@ class Authorization: if data.get("wildcard", False): self.identifier = f"*.{self.identifier}" - def __init__(self, url): - self.url = url - - self.data = None - self.challenges = [] - self.status = None - self.identifier_type = None - self.identifier = None - @classmethod - def from_json(cls, client, data, url): + def from_json( + cls: t.Type[_Authorization], + client: ACMEClient, + data: dict[str, t.Any], + url: str, + ) -> _Authorization: result = cls(url) result._setup(client, data) return result @classmethod - def from_url(cls, client, url): + def from_url( + cls: t.Type[_Authorization], client: ACMEClient, url: str + ) -> _Authorization: result = cls(url) result.refresh(client) return result @classmethod - def create(cls, client, identifier_type, identifier): + def create( + cls: t.Type[_Authorization], + client: ACMEClient, + identifier_type: str, + identifier: str, + ) -> _Authorization: """ Create a new authorization for the given identifier. Return the authorization object of the new authorization @@ -194,23 +229,29 @@ class Authorization: return cls.from_json(client, result, info["location"]) @property - def combined_identifier(self): + def combined_identifier(self) -> str: + if self.identifier_type is None or self.identifier is None: + raise ValueError("Data not present") return combine_identifier(self.identifier_type, self.identifier) - def to_json(self): + def to_json(self) -> dict[str, t.Any]: + if self.data is None: + raise ValueError("Data not present") return self.data.copy() - def refresh(self, client): + def refresh(self, client: ACMEClient) -> bool: result, dummy = client.get_request(self.url) changed = self.data != result self._setup(client, result) return changed - def get_challenge_data(self, client): + def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]: """ Returns a dict with the data for all proposed (and supported) challenges of the given authorization. """ + if self.identifier_type is None or self.identifier is None: + raise ValueError("Data not present") data = {} for challenge in self.challenges: validation_data = challenge.get_validation_data( @@ -220,7 +261,7 @@ class Authorization: data[challenge.type] = validation_data return data - def raise_error(self, error_msg, module=None): + def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn: """ Aborts with a specific error for a challenge. """ @@ -246,13 +287,13 @@ class Authorization: ), ) - def find_challenge(self, challenge_type): + def find_challenge(self, challenge_type: str) -> Challenge | None: for challenge in self.challenges: if challenge_type == challenge.type: return challenge return None - def wait_for_validation(self, client, callenge_type): + def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool: while True: self.refresh(client) if self.status in ["valid", "invalid", "revoked"]: @@ -264,7 +305,9 @@ class Authorization: return self.status == "valid" - def call_validate(self, client, challenge_type, wait=True): + def call_validate( + self, client: ACMEClient, challenge_type: str, wait: bool = True + ) -> bool: """ Validate the authorization provided in the auth dict. Returns True when the validation was successful and False when it was not. @@ -281,7 +324,7 @@ class Authorization: return self.status == "valid" return self.wait_for_validation(client, challenge_type) - def can_deactivate(self): + def can_deactivate(self) -> bool: """ Deactivates this authorization. https://community.letsencrypt.org/t/authorization-deactivation/19860/2 @@ -289,14 +332,14 @@ class Authorization: """ return self.status in ("valid", "pending") - def deactivate(self, client): + def deactivate(self, client: ACMEClient) -> bool | None: """ Deactivates this authorization. https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://tools.ietf.org/html/rfc8555#section-7.5.2 """ if not self.can_deactivate(): - return + return None authz_deactivate = {"status": "deactivated"} result, info = client.send_signed_request( self.url, authz_deactivate, fail_on_error=False @@ -307,7 +350,9 @@ class Authorization: return False @classmethod - def deactivate_url(cls, client, url): + def deactivate_url( + cls: t.Type[_Authorization], client: ACMEClient, url: str + ) -> _Authorization: """ Deactivates this authorization. https://community.letsencrypt.org/t/authorization-deactivation/19860/2 @@ -322,7 +367,7 @@ class Authorization: return authz -def wait_for_validation(authzs, client): +def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None: """ Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked. """ diff --git a/plugins/module_utils/acme/errors.py b/plugins/module_utils/acme/errors.py index dfc7cfac..5899e537 100644 --- a/plugins/module_utils/acme/errors.py +++ b/plugins/module_utils/acme/errors.py @@ -5,19 +5,24 @@ from __future__ import annotations +import typing as t from http.client import responses as http_responses from ansible.module_utils.common.text.converters import to_text -def format_http_status(status_code): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + +def format_http_status(status_code: int) -> str: expl = http_responses.get(status_code) if not expl: return str(status_code) return f"{status_code} {expl}" -def format_error_problem(problem, subproblem_prefix=""): +def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str: error_type = problem.get( "type", "about:blank" ) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1 @@ -32,8 +37,10 @@ def format_error_problem(problem, subproblem_prefix=""): msg = f"{msg} Subproblems:" for index, problem in enumerate(subproblems): index_str = f"{subproblem_prefix}{index}" - problem = format_error_problem(problem, subproblem_prefix=f"{index_str}.") - msg = f"{msg}\n({index_str}) {problem}" + problem_str = format_error_problem( + problem, subproblem_prefix=f"{index_str}." + ) + msg = f"{msg}\n({index_str}) {problem_str}" return msg @@ -42,25 +49,25 @@ class ModuleFailException(Exception): If raised, module.fail_json() will be called with the given parameters after cleanup. """ - def __init__(self, msg, **args): + def __init__(self, msg: str, **args: t.Any) -> None: super(ModuleFailException, self).__init__(self, msg) self.msg = msg self.module_fail_args = args - def do_fail(self, module, **arguments): + def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn: module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments) class ACMEProtocolException(ModuleFailException): def __init__( self, - module, - msg=None, - info=None, + module: AnsibleModule, + msg: str | None = None, + info: dict[str, t.Any] | None = None, response=None, - content=None, - content_json=None, - extras=None, + content: bytes | None = None, + content_json: dict[str, t.Any] | None = None, + extras: dict[str, t.Any] | None = None, ): # Try to get hold of content, if response is given and content is not provided if content is None and content_json is None and response is not None: @@ -71,7 +78,8 @@ class ACMEProtocolException(ModuleFailException): raise TypeError content = response.read() except (AttributeError, TypeError): - content = info.pop("body", None) + if info is not None: + content = info.pop("body", None) # Make sure that content_json is None or a dictionary if content_json is not None and not isinstance(content_json, dict): @@ -139,8 +147,8 @@ class ACMEProtocolException(ModuleFailException): add_msg = f" The raw result: {to_text(content)}" super(ACMEProtocolException, self).__init__(f"{msg}.{add_msg}", **extras) - self.problem = {} - self.subproblems = [] + self.problem: dict[str, t.Any] = {} + self.subproblems: list[dict[str, t.Any]] = [] self.error_code = error_code self.error_type = error_type for k, v in extras.items(): diff --git a/plugins/module_utils/acme/io.py b/plugins/module_utils/acme/io.py index 687315df..081056fd 100644 --- a/plugins/module_utils/acme/io.py +++ b/plugins/module_utils/acme/io.py @@ -10,22 +10,27 @@ import os import shutil import tempfile import traceback +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( ModuleFailException, ) -def read_file(fn, mode="b"): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + +def read_file(fn: str | os.PathLike) -> bytes: try: - with open(fn, "r" + mode) as f: + with open(fn, "rb") as f: return f.read() except Exception as e: raise ModuleFailException(f'Error while reading file "{fn}": {e}') # This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py -def write_file(module, dest, content): +def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool: """ Write content to destination file dest, only if the content has changed. diff --git a/plugins/module_utils/acme/orders.py b/plugins/module_utils/acme/orders.py index 6f21170c..904031f2 100644 --- a/plugins/module_utils/acme/orders.py +++ b/plugins/module_utils/acme/orders.py @@ -6,6 +6,7 @@ from __future__ import annotations import time +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import ( Authorization, @@ -13,14 +14,35 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.challenges i ) from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( ACMEProtocolException, + ModuleFailException, ) from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ( nopad_b64, ) +if t.TYPE_CHECKING: + from .acme import ACMEClient + + +_Order = t.TypeVar("_Order", bound="Order") + + class Order: - def _setup(self, client, data): + def __init__(self, url: str) -> None: + self.url = url + + self.data: dict[str, t.Any] | None = None + + self.status = None + self.identifiers: list[tuple[str, str]] = [] + self.replaces_cert_id = None + self.finalize_uri = None + self.certificate_uri = None + self.authorization_uris: list[str] = [] + self.authorizations: dict[str, Authorization] = {} + + def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None: self.data = data self.status = data["status"] @@ -33,33 +55,28 @@ class Order: self.authorization_uris = data["authorizations"] self.authorizations = {} - def __init__(self, url): - self.url = url - - self.data = None - - self.status = None - self.identifiers = [] - self.replaces_cert_id = None - self.finalize_uri = None - self.certificate_uri = None - self.authorization_uris = [] - self.authorizations = {} - @classmethod - def from_json(cls, client, data, url): + def from_json( + cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str + ) -> _Order: result = cls(url) result._setup(client, data) return result @classmethod - def from_url(cls, client, url): + def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order: result = cls(url) result.refresh(client) return result @classmethod - def create(cls, client, identifiers, replaces_cert_id=None, profile=None): + def create( + cls: t.Type[_Order], + client: ACMEClient, + identifiers: list[tuple[str, str]], + replaces_cert_id: str | None = None, + profile: str | None = None, + ) -> _Order: """ Start a new certificate order (ACME v2 protocol). https://tools.ietf.org/html/rfc8555#section-7.4 @@ -72,7 +89,7 @@ class Order: "value": identifier, } ) - new_order = {"identifiers": acme_identifiers} + new_order: dict[str, t.Any] = {"identifiers": acme_identifiers} if replaces_cert_id is not None: new_order["replaces"] = replaces_cert_id if profile is not None: @@ -87,15 +104,17 @@ class Order: @classmethod def create_with_error_handling( - cls, - client, - identifiers, - error_strategy="auto", - error_max_retries=3, - replaces_cert_id=None, - profile=None, - message_callback=None, - ): + cls: t.Type[_Order], + client: ACMEClient, + identifiers: list[tuple[str, str]], + error_strategy: t.Literal[ + "auto", "fail", "always", "retry_without_replaces_cert_id" + ] = "auto", + error_max_retries: int = 3, + replaces_cert_id: str | None = None, + profile: str | None = None, + message_callback: t.Callable[[str], None] | None = None, + ) -> _Order: """ error_strategy can be one of the following strings: @@ -140,20 +159,20 @@ class Order: raise - def refresh(self, client): + def refresh(self, client: ACMEClient) -> bool: result, dummy = client.get_request(self.url) changed = self.data != result self._setup(client, result) return changed - def load_authorizations(self, client): + def load_authorizations(self, client: ACMEClient) -> None: for auth_uri in self.authorization_uris: authz = Authorization.from_url(client, auth_uri) self.authorizations[ normalize_combined_identifier(authz.combined_identifier) ] = authz - def wait_for_finalization(self, client): + def wait_for_finalization(self, client: ACMEClient) -> None: while True: self.refresh(client) if self.status in ["valid", "invalid", "pending", "ready"]: @@ -167,12 +186,14 @@ class Order: content_json=self.data, ) - def finalize(self, client, csr_der, wait=True): + def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None: """ Create a new certificate based on the csr. Return the certificate object as dict https://tools.ietf.org/html/rfc8555#section-7.4 """ + if self.finalize_uri is None: + raise ModuleFailException("finalize_uri must be set") new_cert = { "csr": nopad_b64(csr_der), } diff --git a/plugins/module_utils/acme/utils.py b/plugins/module_utils/acme/utils.py index 5b7e39c0..9eaf49a8 100644 --- a/plugins/module_utils/acme/utils.py +++ b/plugins/module_utils/acme/utils.py @@ -7,9 +7,11 @@ from __future__ import annotations import base64 import datetime +import os import re import textwrap import traceback +import typing as t from urllib.parse import unquote from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( @@ -23,11 +25,15 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) -def nopad_b64(data): +if t.TYPE_CHECKING: + from .backends import CertificateInformation, CryptoBackend + + +def nopad_b64(data: bytes) -> str: return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "") -def der_to_pem(der_cert): +def der_to_pem(der_cert: bytes) -> str: """ Convert the DER format certificate in der_cert to a PEM format certificate and return it. """ @@ -35,7 +41,9 @@ def der_to_pem(der_cert): return f"-----BEGIN CERTIFICATE-----\n{content}\n-----END CERTIFICATE-----\n" -def pem_to_der(pem_filename=None, pem_content=None): +def pem_to_der( + pem_filename: str | os.PathLike | None = None, pem_content: str | None = None +) -> bytes: """ Load PEM file, or use PEM file's content, and convert to DER. @@ -70,7 +78,9 @@ def pem_to_der(pem_filename=None, pem_content=None): return base64.b64decode("".join(certificate_lines)) -def process_links(info, callback): +def process_links( + info: dict[str, t.Any], callback: t.Callable[[str, str], None] +) -> None: """ Process link header, calls callback for every link header with the URL and relation as options. @@ -82,7 +92,11 @@ def process_links(info, callback): callback(unquote(url), relation) -def parse_retry_after(value, relative_with_timezone=True, now=None): +def parse_retry_after( + value: str, + relative_with_timezone: bool = True, + now: datetime.datetime | None = None, +) -> datetime.datetime: """ Parse the value of a Retry-After header and return a timestamp. @@ -106,12 +120,12 @@ def parse_retry_after(value, relative_with_timezone=True, now=None): def compute_cert_id( - backend, - cert_info=None, - cert_filename=None, - cert_content=None, - none_if_required_information_is_missing=False, -): + backend: CryptoBackend, + cert_info: CertificateInformation | None = None, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + none_if_required_information_is_missing: bool = False, +) -> str | None: # Obtain certificate info if not provided if cert_info is None: cert_info = backend.get_cert_information( diff --git a/plugins/module_utils/argspec.py b/plugins/module_utils/argspec.py index ea5d8365..a0a02c8a 100644 --- a/plugins/module_utils/argspec.py +++ b/plugins/module_utils/argspec.py @@ -4,10 +4,15 @@ from __future__ import annotations +import typing as t + from ansible.module_utils.basic import AnsibleModule -def _ensure_list(value): +_T = t.TypeVar("_T") + + +def _ensure_list(value: list[_T] | tuple[_T] | None) -> list[_T]: if value is None: return [] return list(value) @@ -16,13 +21,19 @@ def _ensure_list(value): class ArgumentSpec: def __init__( self, - argument_spec=None, - mutually_exclusive=None, - required_together=None, - required_one_of=None, - required_if=None, - required_by=None, - ): + argument_spec: dict[str, t.Any] | None = None, + mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None, + required_together: list[list[str] | tuple[str, ...]] | None = None, + required_one_of: list[list[str] | tuple[str, ...]] | None = None, + required_if: ( + list[ + tuple[str, t.Any, list[str] | tuple[str, ...]] + | tuple[str, t.Any, list[str] | tuple[str, ...], bool] + ] + | None + ) = None, + required_by: dict[str, tuple[str, ...] | list[str]] | None = None, + ) -> None: self.argument_spec = argument_spec or {} self.mutually_exclusive = _ensure_list(mutually_exclusive) self.required_together = _ensure_list(required_together) @@ -30,17 +41,23 @@ class ArgumentSpec: self.required_if = _ensure_list(required_if) self.required_by = required_by or {} - def update_argspec(self, **kwargs): + def update_argspec(self, **kwargs) -> t.Self: self.argument_spec.update(kwargs) return self def update( self, - mutually_exclusive=None, - required_together=None, - required_one_of=None, - required_if=None, - required_by=None, + mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None, + required_together: list[list[str] | tuple[str, ...]] | None = None, + required_one_of: list[list[str] | tuple[str, ...]] | None = None, + required_if: ( + list[ + tuple[str, t.Any, list[str] | tuple[str, ...]] + | tuple[str, t.Any, list[str] | tuple[str, ...], bool] + ] + | None + ) = None, + required_by: dict[str, tuple[str, ...] | list[str]] | None = None, ): if mutually_exclusive: self.mutually_exclusive.extend(mutually_exclusive) @@ -57,7 +74,7 @@ class ArgumentSpec: self.required_by[k] = v return self - def merge(self, other): + def merge(self, other: t.Self) -> t.Self: self.update_argspec(**other.argument_spec) self.update( mutually_exclusive=other.mutually_exclusive, @@ -68,8 +85,22 @@ class ArgumentSpec: ) return self - def create_ansible_module_helper(self, clazz, args, **kwargs): - return clazz( + def create_ansible_module_helper( + self, clazz: type[_T], args: tuple, **kwargs: t.Any + ) -> _T: + for forbidden_name in ( + "argument_spec", + "mutually_exclusive", + "required_together", + "required_one_of", + "required_if", + "required_by", + ): + if forbidden_name in kwargs: + raise ValueError( + f"You must not provide a {forbidden_name} keyword parameter to create_ansible_module_helper()" + ) + instance = clazz( # type: ignore *args, argument_spec=self.argument_spec, mutually_exclusive=self.mutually_exclusive, @@ -79,8 +110,9 @@ class ArgumentSpec: required_by=self.required_by, **kwargs, ) + return instance - def create_ansible_module(self, **kwargs): + def create_ansible_module(self, **kwargs: t.Any) -> AnsibleModule: return self.create_ansible_module_helper(AnsibleModule, (), **kwargs) diff --git a/plugins/module_utils/crypto/_asn1.py b/plugins/module_utils/crypto/_asn1.py index c0478992..826561e9 100644 --- a/plugins/module_utils/crypto/_asn1.py +++ b/plugins/module_utils/crypto/_asn1.py @@ -4,6 +4,7 @@ from __future__ import annotations +import enum import re from ansible.module_utils.common.text.converters import to_bytes @@ -32,7 +33,7 @@ ASN1_STRING_REGEX = re.compile( ) -class TagClass: +class TagClass(enum.Enum): universal = 0 application = 1 context_specific = 2 @@ -40,11 +41,11 @@ class TagClass: # Universal tag numbers that can be encoded. -class TagNumber: +class TagNumber(enum.Enum): utf8_string = 12 -def _pack_octet_integer(value): +def _pack_octet_integer(value: int) -> bytes: """Packs an integer value into 1 or multiple octets.""" # NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value. octets = bytearray() @@ -66,7 +67,7 @@ def _pack_octet_integer(value): return bytes(octets) -def serialize_asn1_string_as_der(value): +def serialize_asn1_string_as_der(value: str) -> bytes: """Deserializes an ASN.1 string to a DER encoded byte string.""" asn1_match = ASN1_STRING_REGEX.match(value) if not asn1_match: @@ -92,7 +93,7 @@ def serialize_asn1_string_as_der(value): b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value) if tag_type: - tag_class = { + tag_class_enum = { "U": TagClass.universal, "A": TagClass.application, "P": TagClass.private, @@ -100,13 +101,15 @@ def serialize_asn1_string_as_der(value): }[tag_class] # When adding support for more types this should be looked into further. For now it works with UTF8Strings. - constructed = tag_type == "EXPLICIT" and tag_class != TagClass.universal - b_value = pack_asn1(tag_class, constructed, int(tag_number), b_value) + constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal + b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value) return b_value -def pack_asn1(tag_class, constructed, tag_number, b_data): +def pack_asn1( + tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes +) -> bytes: """Pack the value into an ASN.1 data structure. The structure for an ASN.1 element is @@ -115,16 +118,15 @@ def pack_asn1(tag_class, constructed, tag_number, b_data): """ b_asn1_data = bytearray() - if tag_class < 0 or tag_class > 3: - raise ValueError(f"tag_class must be between 0 and 3 not {tag_class}") - # Bit 8 and 7 denotes the class. - identifier_octets = tag_class << 6 + identifier_octets = tag_class.value << 6 # Bit 6 denotes whether the value is primitive or constructed. identifier_octets |= (1 if constructed else 0) << 5 # Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits # then they are set and another octet(s) is used to denote the tag number. + if isinstance(tag_number, TagNumber): + tag_number = tag_number.value if tag_number < 31: identifier_octets |= tag_number b_asn1_data.append(identifier_octets) diff --git a/plugins/module_utils/crypto/_obj2txt.py b/plugins/module_utils/crypto/_obj2txt.py index e115c606..f1dcf95c 100644 --- a/plugins/module_utils/crypto/_obj2txt.py +++ b/plugins/module_utils/crypto/_obj2txt.py @@ -34,7 +34,7 @@ from __future__ import annotations # cryptography versions! -def obj2txt(openssl_lib, openssl_ffi, obj): +def obj2txt(openssl_lib, openssl_ffi, obj) -> str: # Set to 80 on the recommendation of # https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values # diff --git a/plugins/module_utils/crypto/_objects.py b/plugins/module_utils/crypto/_objects.py index 957ec186..4b0a6a42 100644 --- a/plugins/module_utils/crypto/_objects.py +++ b/plugins/module_utils/crypto/_objects.py @@ -7,9 +7,9 @@ from __future__ import annotations from ._objects_data import OID_MAP -OID_LOOKUP = dict() -NORMALIZE_NAMES = dict() -NORMALIZE_NAMES_SHORT = dict() +OID_LOOKUP: dict[str, str] = dict() +NORMALIZE_NAMES: dict[str, str] = dict() +NORMALIZE_NAMES_SHORT: dict[str, str] = dict() for dotted, names in OID_MAP.items(): for name in names: diff --git a/plugins/module_utils/crypto/cryptography_crl.py b/plugins/module_utils/crypto/cryptography_crl.py index 50fed969..6edfaa71 100644 --- a/plugins/module_utils/crypto/cryptography_crl.py +++ b/plugins/module_utils/crypto/cryptography_crl.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.version import ( LooseVersion as _LooseVersion, ) @@ -21,6 +23,10 @@ from .basic import HAS_CRYPTOGRAPHY from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_name +if t.TYPE_CHECKING: + import datetime + + # TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this # to True and adjust get_invalidity_date() accordingly. # (https://github.com/pyca/cryptography/issues/10818) @@ -55,7 +61,9 @@ else: REVOCATION_REASON_MAP_INVERSE = dict() -def cryptography_decode_revoked_certificate(cert): +def cryptography_decode_revoked_certificate( + cert: x509.RevokedCertificate, +) -> dict[str, t.Any]: result = { "serial_number": cert.serial_number, "revocation_date": get_revocation_date(cert), @@ -67,27 +75,30 @@ def cryptography_decode_revoked_certificate(cert): "invalidity_date_critical": False, } try: - ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer) - result["issuer"] = list(ext.value) - result["issuer_critical"] = ext.critical + ext_ci = cert.extensions.get_extension_for_class(x509.CertificateIssuer) + result["issuer"] = list(ext_ci.value) + result["issuer_critical"] = ext_ci.critical except x509.ExtensionNotFound: pass try: - ext = cert.extensions.get_extension_for_class(x509.CRLReason) - result["reason"] = ext.value.reason - result["reason_critical"] = ext.critical + ext_cr = cert.extensions.get_extension_for_class(x509.CRLReason) + result["reason"] = ext_cr.value.reason + result["reason_critical"] = ext_cr.critical except x509.ExtensionNotFound: pass try: - ext = cert.extensions.get_extension_for_class(x509.InvalidityDate) - result["invalidity_date"] = get_invalidity_date(ext.value) - result["invalidity_date_critical"] = ext.critical + ext_id = cert.extensions.get_extension_for_class(x509.InvalidityDate) + result["invalidity_date"] = get_invalidity_date(ext_id.value) + result["invalidity_date_critical"] = ext_id.critical except x509.ExtensionNotFound: pass return result -def cryptography_dump_revoked(entry, idn_rewrite="ignore"): +def cryptography_dump_revoked( + entry: dict[str, t.Any], + idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore", +) -> dict[str, t.Any]: return { "serial_number": entry["serial_number"], "revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT), @@ -115,48 +126,56 @@ def cryptography_dump_revoked(entry, idn_rewrite="ignore"): } -def cryptography_get_signature_algorithm_oid_from_crl(crl): +def cryptography_get_signature_algorithm_oid_from_crl( + crl: x509.CertificateRevocationList, +) -> x509.oid.ObjectIdentifier: try: return crl.signature_algorithm_oid except AttributeError: # Older cryptography versions do not have signature_algorithm_oid yet dotted = obj2txt( - crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm + crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm # type: ignore ) return x509.oid.ObjectIdentifier(dotted) -def get_next_update(obj): +def get_next_update(obj: x509.CertificateRevocationList) -> datetime.datetime | None: if CRYPTOGRAPHY_TIMEZONE: return obj.next_update_utc return obj.next_update -def get_last_update(obj): +def get_last_update(obj: x509.CertificateRevocationList) -> datetime.datetime: if CRYPTOGRAPHY_TIMEZONE: return obj.last_update_utc return obj.last_update -def get_revocation_date(obj): +def get_revocation_date(obj: x509.RevokedCertificate) -> datetime.datetime: if CRYPTOGRAPHY_TIMEZONE: return obj.revocation_date_utc return obj.revocation_date -def get_invalidity_date(obj): +def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime: if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE: return obj.invalidity_date_utc return obj.invalidity_date -def set_next_update(builder, value): +def set_next_update( + builder: x509.CertificateRevocationListBuilder, value: datetime.datetime +) -> x509.CertificateRevocationListBuilder: return builder.next_update(value) -def set_last_update(builder, value): +def set_last_update( + builder: x509.CertificateRevocationListBuilder, value: datetime.datetime +) -> x509.CertificateRevocationListBuilder: return builder.last_update(value) -def set_revocation_date(builder, value): +def set_revocation_date( + builder: x509.RevokedCertificateBuilder, value: datetime.datetime +) -> x509.RevokedCertificateBuilder: return builder.revocation_date(value) diff --git a/plugins/module_utils/crypto/cryptography_support.py b/plugins/module_utils/crypto/cryptography_support.py index 95981dbb..bfbfd92f 100644 --- a/plugins/module_utils/crypto/cryptography_support.py +++ b/plugins/module_utils/crypto/cryptography_support.py @@ -9,6 +9,7 @@ import binascii import ipaddress import re import traceback +import typing as t from urllib.parse import ( ParseResult, urlparse, @@ -40,6 +41,7 @@ except ImportError: pass try: + import cryptography.hazmat.primitives.asymmetric.dh import cryptography.hazmat.primitives.asymmetric.ed448 import cryptography.hazmat.primitives.asymmetric.ed25519 import cryptography.hazmat.primitives.asymmetric.rsa @@ -55,7 +57,7 @@ try: ) except ImportError: # Error handled in the calling module. - _load_pkcs12 = None + _load_pkcs12 = None # type: ignore try: import idna @@ -74,6 +76,50 @@ from .basic import ( ) +if t.TYPE_CHECKING: + import datetime + + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateKey, DHPublicKey + from cryptography.hazmat.primitives.asymmetric.dsa import ( + DSAPrivateKey, + DSAPublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePrivateKey, + EllipticCurvePublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + CertificateIssuerPublicKeyTypes, + CertificatePublicKeyTypes, + PrivateKeyTypes, + PublicKeyTypes, + ) + from cryptography.hazmat.primitives.serialization.pkcs12 import ( + PKCS12KeyAndCertificates, + ) + + CertificatePrivateKeyTypes = ( + CertificateIssuerPrivateKeyTypes + | cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey + | cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey + ) + PublicKeyTypesWOEdwards = ( + DHPublicKey | DSAPublicKey | EllipticCurvePublicKey | RSAPublicKey + ) + PrivateKeyTypesWOEdwards = ( + DHPrivateKey | DSAPrivateKey | EllipticCurvePrivateKey | RSAPrivateKey + ) +else: + PublicKeyTypesWOEdwards = None + PrivateKeyTypesWOEdwards = None + + CRYPTOGRAPHY_TIMEZONE = False _CRYPTOGRAPHY_36_0_OR_NEWER = False if _HAS_CRYPTOGRAPHY: @@ -88,7 +134,9 @@ if _HAS_CRYPTOGRAPHY: DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$") -def cryptography_get_extensions_from_cert(cert): +def cryptography_get_extensions_from_cert( + cert: x509.Certificate, +) -> dict[str, dict[str, bool | str]]: result = dict() if _CRYPTOGRAPHY_36_0_OR_NEWER: @@ -105,7 +153,7 @@ def cryptography_get_extensions_from_cert(cert): backend = default_backend() - x509_obj = cert._x509 + x509_obj = cert._x509 # type: ignore # With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does # not allow to get the raw value of an extension, so we have to use this ugly hack: exts = list(cert.extensions) @@ -135,7 +183,9 @@ def cryptography_get_extensions_from_cert(cert): return result -def cryptography_get_extensions_from_csr(csr): +def cryptography_get_extensions_from_csr( + csr: x509.CertificateSigningRequest, +) -> dict[str, dict[str, bool | str]]: result = dict() if _CRYPTOGRAPHY_36_0_OR_NEWER: @@ -153,7 +203,7 @@ def cryptography_get_extensions_from_csr(csr): backend = default_backend() - extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) + extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) # type: ignore extensions = backend._ffi.gc( extensions, lambda ext: backend._lib.sk_X509_EXTENSION_pop_free( @@ -175,7 +225,7 @@ def cryptography_get_extensions_from_csr(csr): crit = backend._lib.X509_EXTENSION_get_critical(ext) data = backend._lib.X509_EXTENSION_get_data(ext) backend.openssl_assert(data != backend._ffi.NULL) - der = backend._ffi.buffer(data.data, data.length)[:] + der: bytes = backend._ffi.buffer(data.data, data.length)[:] # type: ignore entry = dict( critical=(crit == 1), value=base64.b64encode(der).decode("ascii"), @@ -193,7 +243,7 @@ def cryptography_get_extensions_from_csr(csr): return result -def cryptography_name_to_oid(name): +def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier: dotted = OID_LOOKUP.get(name) if dotted is None: if DOTTED_OID.match(name): @@ -202,7 +252,9 @@ def cryptography_name_to_oid(name): return x509.oid.ObjectIdentifier(dotted) -def cryptography_oid_to_name(oid, short=False): +def cryptography_oid_to_name( + oid: x509.oid.ObjectIdentifier, short: bool = False +) -> str: dotted_string = oid.dotted_string names = OID_MAP.get(dotted_string) if names: @@ -217,15 +269,22 @@ def cryptography_oid_to_name(oid, short=False): return NORMALIZE_NAMES.get(name, name) -def _get_hex(bytesstr): +def _get_hex(bytesstr: bytes) -> str: if bytesstr is None: return bytesstr data = binascii.hexlify(bytesstr) - data = to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2))) - return data + return to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2))) -def _parse_hex(bytesstr): +@t.overload +def _parse_hex(bytesstr: bytes | str) -> bytes: ... + + +@t.overload +def _parse_hex(bytesstr: bytes | str | None) -> bytes | None: ... + + +def _parse_hex(bytesstr: bytes | str | None) -> bytes | None: if bytesstr is None: return bytesstr data = "".join( @@ -234,19 +293,20 @@ def _parse_hex(bytesstr): for p in to_text(bytesstr).split(":") ] ) - data = binascii.unhexlify(data) - return data + return binascii.unhexlify(data) DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *") DN_HEX_LETTER = b"0123456789abcdef" -def _int_to_byte(value): +def _int_to_byte(value: int) -> bytes: return bytes((value,)) -def _parse_dn_component(name, sep=b",", decode_remainder=True): +def _parse_dn_component( + name: bytes, sep: bytes = b",", decode_remainder: bool = True +) -> tuple[x509.NameAttribute, bytes]: m = DN_COMPONENT_START_RE.match(name) if not m: raise OpenSSLObjectError(f'cannot start part in "{to_text(name)}"') @@ -305,7 +365,7 @@ def _parse_dn_component(name, sep=b",", decode_remainder=True): return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:] -def _parse_dn(name): +def _parse_dn(name: bytes) -> list[x509.NameAttribute]: """ Parse a Distinguished Name. @@ -323,31 +383,33 @@ def _parse_dn(name): attribute, name = _parse_dn_component(name, sep=sep) except OpenSSLObjectError as e: raise OpenSSLObjectError( - f'Error while parsing distinguished name "{to_text(original_name)}": {e}' + f"Error while parsing distinguished name {to_text(original_name)!r}: {e}" ) result.append(attribute) if name: if name[0:1] != sep or len(name) < 2: raise OpenSSLObjectError( - f'Error while parsing distinguished name "{to_text(original_name)}": unexpected end of string' + f"Error while parsing distinguished name {to_text(original_name)!r}: unexpected end of string" ) name = name[1:] return result -def cryptography_parse_relative_distinguished_name(rdn): +def cryptography_parse_relative_distinguished_name( + rdn: list[str | bytes], +) -> cryptography.x509.RelativeDistinguishedName: names = [] for part in rdn: try: names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0]) except OpenSSLObjectError as e: raise OpenSSLObjectError( - f'Error while parsing relative distinguished name "{part}": {e}' + f"Error while parsing relative distinguished name {to_text(part)!r}: {e}" ) return cryptography.x509.RelativeDistinguishedName(names) -def _is_ascii(value): +def _is_ascii(value: str) -> bool: """Check whether the Unicode string `value` contains only ASCII characters.""" try: value.encode("ascii") @@ -356,7 +418,7 @@ def _is_ascii(value): return False -def _adjust_idn(value, idn_rewrite): +def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str: if idn_rewrite == "ignore" or not value: return value if idn_rewrite == "idna" and _is_ascii(value): @@ -399,16 +461,20 @@ def _adjust_idn(value, idn_rewrite): return ".".join(parts) -def _adjust_idn_email(value, idn_rewrite): +def _adjust_idn_email( + value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"] +) -> str: idx = value.find("@") if idx < 0: return value return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}" -def _adjust_idn_url(value, idn_rewrite): +def _adjust_idn_url( + value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"] +) -> str: url = urlparse(value) - host = _adjust_idn(url.hostname, idn_rewrite) + host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None if url.username is not None and url.password is not None: host = f"{url.username}:{url.password}@{host}" elif url.username is not None: @@ -418,7 +484,7 @@ def _adjust_idn_url(value, idn_rewrite): return urlunparse( ParseResult( scheme=url.scheme, - netloc=host, + netloc=host or "", path=url.path, params=url.params, query=url.query, @@ -427,7 +493,9 @@ def _adjust_idn_url(value, idn_rewrite): ) -def cryptography_get_name(name, what="Subject Alternative Name"): +def cryptography_get_name( + name: str, what: str = "Subject Alternative Name" +) -> x509.GeneralName: """ Given a name string, returns a cryptography x509.GeneralName object. Raises an OpenSSLObjectError if the name is unknown or cannot be parsed. @@ -490,7 +558,7 @@ def cryptography_get_name(name, what="Subject Alternative Name"): ) -def _dn_escape_value(value): +def _dn_escape_value(value: str) -> str: """ Escape Distinguished Name's attribute value. """ @@ -505,7 +573,10 @@ def _dn_escape_value(value): return value -def cryptography_decode_name(name, idn_rewrite="ignore"): +def cryptography_decode_name( + name: x509.GeneralName, + idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore", +) -> str: """ Given a cryptography x509.GeneralName object, returns a string. Raises an OpenSSLObjectError if the name is not supported. @@ -529,7 +600,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"): # list needs to be reversed, and joined by commas return "dirName:" + ",".join( [ - f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(attribute.value)}" + f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(to_text(attribute.value))}" for attribute in reversed(list(name.value)) ] ) @@ -540,7 +611,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"): raise OpenSSLObjectError(f'Cannot decode name "{name}"') -def _cryptography_get_keyusage(usage): +def _cryptography_get_keyusage(usage: str) -> str: """ Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage(). Raises an OpenSSLObjectError if the identifier is unknown. @@ -566,7 +637,7 @@ def _cryptography_get_keyusage(usage): raise OpenSSLObjectError(f'Unknown key usage "{usage}"') -def cryptography_parse_key_usage_params(usages): +def cryptography_parse_key_usage_params(usages: t.Iterable[str]) -> dict[str, bool]: """ Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage(). Raises an OpenSSLObjectError if an identifier is unknown. @@ -587,13 +658,15 @@ def cryptography_parse_key_usage_params(usages): return params -def cryptography_get_basic_constraints(constraints): +def cryptography_get_basic_constraints( + constraints: t.Iterable[str] | None, +) -> tuple[bool, int | None]: """ Given a list of constraints, returns a tuple (ca, path_length). Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed. """ ca = False - path_length = None + path_length: int | None = None if constraints: for constraint in constraints: if constraint.startswith("CA:"): @@ -618,7 +691,9 @@ def cryptography_get_basic_constraints(constraints): return ca, path_length -def cryptography_key_needs_digest_for_signing(key): +def cryptography_key_needs_digest_for_signing( + key: CertificateIssuerPrivateKeyTypes, +) -> bool: """Tests whether the given private key requires a digest algorithm for signing. Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm. @@ -632,19 +707,27 @@ def cryptography_key_needs_digest_for_signing(key): return True -def _compare_public_keys(key1, key2, clazz): +def _compare_public_keys( + key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes] +) -> bool | None: a = isinstance(key1, clazz) b = isinstance(key2, clazz) if not (a or b): return None if not a or not b: return False - a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) - b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) - return a == b + a_bytes = key1.public_bytes( + serialization.Encoding.Raw, serialization.PublicFormat.Raw + ) + b_bytes = key2.public_bytes( + serialization.Encoding.Raw, serialization.PublicFormat.Raw + ) + return a_bytes == b_bytes -def cryptography_compare_public_keys(key1, key2): +def cryptography_compare_public_keys( + key1: PublicKeyTypes, key2: PublicKeyTypes +) -> bool: """Tests whether two public keys are the same. Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers(). @@ -654,6 +737,13 @@ def cryptography_compare_public_keys(key1, key2): key2, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey, ) + if res is not None: + return res + res = _compare_public_keys( + key1, + key2, + cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey, + ) if res is not None: return res res = _compare_public_keys( @@ -661,10 +751,20 @@ def cryptography_compare_public_keys(key1, key2): ) if res is not None: return res - return key1.public_numbers() == key2.public_numbers() + res = _compare_public_keys( + key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey + ) + if res is not None: + return res + return ( + t.cast(PublicKeyTypesWOEdwards, key1).public_numbers() + == t.cast(PublicKeyTypesWOEdwards, key2).public_numbers() + ) -def _compare_private_keys(key1, key2, clazz): +def _compare_private_keys( + key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes] +) -> bool | None: a = isinstance(key1, clazz) b = isinstance(key2, clazz) if not (a or b): @@ -672,20 +772,22 @@ def _compare_private_keys(key1, key2, clazz): if not a or not b: return False encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption() - a = key1.private_bytes( + a_bytes = key1.private_bytes( serialization.Encoding.Raw, serialization.PrivateFormat.Raw, encryption_algorithm=encryption_algorithm, ) - b = key2.private_bytes( + b_bytes = key2.private_bytes( serialization.Encoding.Raw, serialization.PrivateFormat.Raw, encryption_algorithm=encryption_algorithm, ) - return a == b + return a_bytes == b_bytes -def cryptography_compare_private_keys(key1, key2): +def cryptography_compare_private_keys( + key1: PrivateKeyTypes, key2: PrivateKeyTypes +) -> bool: """Tests whether two private keys are the same. Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers(). @@ -714,25 +816,39 @@ def cryptography_compare_private_keys(key1, key2): ) if res is not None: return res - return key1.private_numbers() == key2.private_numbers() + return ( + t.cast(PrivateKeyTypesWOEdwards, key1).private_numbers() + == t.cast(PrivateKeyTypesWOEdwards, key2).private_numbers() + ) -def parse_pkcs12(pkcs12_bytes, passphrase=None): +def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], + bytes | None, +]: """Returns a tuple (private_key, certificate, additional_certificates, friendly_name).""" + passphrase_bytes = None if passphrase is not None: - passphrase = to_bytes(passphrase) + passphrase_bytes = to_bytes(passphrase) # Main code for cryptography 36.0.0 and forward if _load_pkcs12 is not None: - return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase) + return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes) if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"): - return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase) + return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes) - return _parse_pkcs12_legacy(pkcs12_bytes, passphrase) + return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes) -def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None): +def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], + bytes | None, +]: # Requires cryptography 36.0.0 or newer pkcs12 = _load_pkcs12(pkcs12_bytes, passphrase) additional_certificates = [cert.certificate for cert in pkcs12.additional_certs] @@ -745,7 +861,12 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None): return private_key, certificate, additional_certificates, friendly_name -def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None): +def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], + bytes | None, +]: # Backwards compatibility code for cryptography 35.x private_key, certificate, additional_certificates = _load_key_and_certificates( pkcs12_bytes, passphrase @@ -787,7 +908,12 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None): return private_key, certificate, additional_certificates, friendly_name -def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None): +def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], + bytes | None, +]: # Backwards compatibility code for cryptography < 35.0.0 private_key, certificate, additional_certificates = _load_key_and_certificates( pkcs12_bytes, passphrase @@ -796,14 +922,19 @@ def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None): friendly_name = None if certificate: # See https://github.com/pyca/cryptography/issues/5760#issuecomment-842687238 - backend = certificate._backend - maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) + backend = certificate._backend # type: ignore + maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) # type: ignore if maybe_name != backend._ffi.NULL: friendly_name = backend._ffi.string(maybe_name) return private_key, certificate, additional_certificates, friendly_name -def cryptography_verify_signature(signature, data, hash_algorithm, signer_public_key): +def cryptography_verify_signature( + signature: bytes, + data: bytes, + hash_algorithm: hashes.HashAlgorithm | None, + signer_public_key: PublicKeyTypes, +) -> bool: """ Check whether the given signature of the given data was signed by the given public key object. """ @@ -812,6 +943,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public signer_public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey, ): + if hash_algorithm is None: + raise OpenSSLObjectError("Need hash_algorithm for RSA keys") signer_public_key.verify( signature, data, padding.PKCS1v15(), hash_algorithm ) @@ -820,6 +953,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public signer_public_key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, ): + if hash_algorithm is None: + raise OpenSSLObjectError("Need hash_algorithm for ECC keys") signer_public_key.verify( signature, data, @@ -830,6 +965,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public signer_public_key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey, ): + if hash_algorithm is None: + raise OpenSSLObjectError("Need hash_algorithm for DSA keys") signer_public_key.verify(signature, data, hash_algorithm) return True if isinstance( @@ -851,7 +988,9 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public return False -def cryptography_verify_certificate_signature(certificate, signer_public_key): +def cryptography_verify_certificate_signature( + certificate: x509.Certificate, signer_public_key: PublicKeyTypes +) -> bool: """ Check whether the given X509 certificate object was signed by the given public key object. """ @@ -863,21 +1002,65 @@ def cryptography_verify_certificate_signature(certificate, signer_public_key): ) -def get_not_valid_after(obj): +def get_not_valid_after(obj: x509.Certificate) -> datetime.datetime: if CRYPTOGRAPHY_TIMEZONE: return obj.not_valid_after_utc return obj.not_valid_after -def get_not_valid_before(obj): +def get_not_valid_before(obj: x509.Certificate) -> datetime.datetime: if CRYPTOGRAPHY_TIMEZONE: return obj.not_valid_before_utc return obj.not_valid_before -def set_not_valid_after(builder, value): +def set_not_valid_after( + builder: x509.CertificateBuilder, value: datetime.datetime +) -> x509.CertificateBuilder: return builder.not_valid_after(value) -def set_not_valid_before(builder, value): +def set_not_valid_before( + builder: x509.CertificateBuilder, value: datetime.datetime +) -> x509.CertificateBuilder: return builder.not_valid_before(value) + + +def is_potential_certificate_private_key( + key: PrivateKeyTypes, +) -> t.TypeGuard[CertificatePrivateKeyTypes]: + return not isinstance( + key, cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey + ) + + +def is_potential_certificate_issuer_private_key( + key: PrivateKeyTypes, +) -> t.TypeGuard[CertificateIssuerPrivateKeyTypes]: + return not isinstance( + key, + ( + cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey, + cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey, + cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey, + ), + ) + + +def is_potential_certificate_public_key( + key: PublicKeyTypes, +) -> t.TypeGuard[CertificatePublicKeyTypes]: + return not isinstance(key, DHPublicKey) + + +def is_potential_certificate_issuer_public_key( + key: PublicKeyTypes, +) -> t.TypeGuard[CertificateIssuerPublicKeyTypes]: + return not isinstance( + key, + ( + cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey, + cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey, + cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey, + ), + ) diff --git a/plugins/module_utils/crypto/math.py b/plugins/module_utils/crypto/math.py index 6b3e3434..453a2e0a 100644 --- a/plugins/module_utils/crypto/math.py +++ b/plugins/module_utils/crypto/math.py @@ -5,7 +5,7 @@ from __future__ import annotations -def binary_exp_mod(f, e, m): +def binary_exp_mod(f: int, e: int, m: int) -> int: """Computes f^e mod m in O(log e) multiplications modulo m.""" # Compute len_e = floor(log_2(e)) len_e = -1 @@ -22,14 +22,14 @@ def binary_exp_mod(f, e, m): return result -def simple_gcd(a, b): +def simple_gcd(a: int, b: int) -> int: """Compute GCD of its two inputs.""" while b != 0: a, b = b, a % b return a -def quick_is_not_prime(n): +def quick_is_not_prime(n: int) -> bool: """Does some quick checks to see if we can poke a hole into the primality of n. A result of `False` does **not** mean that the number is prime; it just means @@ -97,7 +97,7 @@ def quick_is_not_prime(n): return False -def count_bytes(no): +def count_bytes(no: int) -> int: """ Given an integer, compute the number of bytes necessary to store its absolute value. """ @@ -107,7 +107,7 @@ def count_bytes(no): return (no.bit_length() + 7) // 8 -def count_bits(no): +def count_bits(no: int) -> int: """ Given an integer, compute the number of bits necessary to store its absolute value. """ @@ -117,19 +117,7 @@ def count_bits(no): return no.bit_length() -def _convert_int_to_bytes(count, no): - return no.to_bytes(count, byteorder="big") - - -def _convert_bytes_to_int(data): - return int.from_bytes(data, byteorder="big", signed=False) - - -def _to_hex(no): - return f"{no:x}" - - -def convert_int_to_bytes(no, count=None): +def convert_int_to_bytes(no: int, count: int | None = None) -> bytes: """ Convert the absolute value of an integer to a byte string in network byte order. @@ -142,10 +130,10 @@ def convert_int_to_bytes(no, count=None): no = abs(no) if count is None: count = count_bytes(no) - return _convert_int_to_bytes(count, no) + return no.to_bytes(count, byteorder="big") -def convert_int_to_hex(no, digits=None): +def convert_int_to_hex(no: int, digits: int | None = None) -> str: """ Convert the absolute value of an integer to a string of hexadecimal digits. @@ -154,14 +142,14 @@ def convert_int_to_hex(no, digits=None): the string will be longer. """ no = abs(no) - value = _to_hex(no) + value = f"{no:x}" if digits is not None and len(value) < digits: value = "0" * (digits - len(value)) + value return value -def convert_bytes_to_int(data): +def convert_bytes_to_int(data: bytes) -> int: """ Convert a byte string to an unsigned integer in network byte order. """ - return _convert_bytes_to_int(data) + return int.from_bytes(data, byteorder="big", signed=False) diff --git a/plugins/module_utils/crypto/module_backends/certificate.py b/plugins/module_utils/crypto/module_backends/certificate.py index e8f6afda..42827edb 100644 --- a/plugins/module_utils/crypto/module_backends/certificate.py +++ b/plugins/module_utils/crypto/module_backends/certificate.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import typing as t from ansible_collections.community.crypto.plugins.module_utils.argspec import ( ArgumentSpec, @@ -24,8 +25,8 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( load_certificate, + load_certificate_privatekey, load_certificate_request, - load_privatekey, ) from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, @@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + import datetime + + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + ) + + from ..cryptography_support import CertificatePrivateKeyTypes + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -47,41 +59,45 @@ class CertificateError(OpenSSLObjectError): class CertificateBackend(metaclass=abc.ABCMeta): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module - self.force = module.params["force"] - self.ignore_timestamps = module.params["ignore_timestamps"] - self.privatekey_path = module.params["privatekey_path"] - self.privatekey_content = module.params["privatekey_content"] - if self.privatekey_content is not None: - self.privatekey_content = self.privatekey_content.encode("utf-8") - self.privatekey_passphrase = module.params["privatekey_passphrase"] - self.csr_path = module.params["csr_path"] - self.csr_content = module.params["csr_content"] - if self.csr_content is not None: - self.csr_content = self.csr_content.encode("utf-8") + self.force: bool = module.params["force"] + self.ignore_timestamps: bool = module.params["ignore_timestamps"] + self.privatekey_path: str | None = module.params["privatekey_path"] + privatekey_content: str | None = module.params["privatekey_content"] + if privatekey_content is not None: + self.privatekey_content: bytes | None = privatekey_content.encode("utf-8") + else: + self.privatekey_content = None + self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"] + self.csr_path: str | None = module.params["csr_path"] + csr_content = module.params["csr_content"] + if csr_content is not None: + self.csr_content: bytes | None = csr_content.encode("utf-8") + else: + self.csr_content = None # The following are default values which make sure check() works as # before if providers do not explicitly change these properties. - self.create_subject_key_identifier = "never_create" - self.create_authority_key_identifier = False + self.create_subject_key_identifier: str = "never_create" + self.create_authority_key_identifier: bool = False - self.privatekey = None - self.csr = None - self.cert = None - self.existing_certificate = None - self.existing_certificate_bytes = None + self.privatekey: CertificatePrivateKeyTypes | None = None + self.csr: x509.CertificateSigningRequest | None = None + self.cert: x509.Certificate | None = None + self.existing_certificate: x509.Certificate | None = None + self.existing_certificate_bytes: bytes | None = None - self.check_csr_subject = True - self.check_csr_extensions = True + self.check_csr_subject: bool = True + self.check_csr_extensions: bool = True self.diff_before = self._get_info(None) self.diff_after = self._get_info(None) - def _get_info(self, data): + def _get_info(self, data: bytes | None) -> dict[str, t.Any]: if data is None: - return dict() + return {} try: result = get_certificate_info( self.module, data, prefer_one_fingerprint=True @@ -92,34 +108,34 @@ class CertificateBackend(metaclass=abc.ABCMeta): return dict(can_parse_certificate=False) @abc.abstractmethod - def generate_certificate(self): + def generate_certificate(self) -> None: """(Re-)Generate certificate.""" pass @abc.abstractmethod - def get_certificate_data(self): + def get_certificate_data(self) -> bytes: """Return bytes for self.cert.""" pass - def set_existing(self, certificate_bytes): + def set_existing(self, certificate_bytes: bytes | None) -> None: """Set existing certificate bytes. None indicates that the key does not exist.""" self.existing_certificate_bytes = certificate_bytes self.diff_after = self.diff_before = self._get_info( self.existing_certificate_bytes ) - def has_existing(self): + def has_existing(self) -> bool: """Query whether an existing certificate is/has been there.""" return self.existing_certificate_bytes is not None - def _ensure_private_key_loaded(self): + def _ensure_private_key_loaded(self) -> None: """Load the provided private key into self.privatekey.""" if self.privatekey is not None: return if self.privatekey_path is None and self.privatekey_content is None: return try: - self.privatekey = load_privatekey( + self.privatekey = load_certificate_privatekey( path=self.privatekey_path, content=self.privatekey_content, passphrase=self.privatekey_passphrase, @@ -127,7 +143,7 @@ class CertificateBackend(metaclass=abc.ABCMeta): except OpenSSLBadPassphraseError as exc: raise CertificateError(exc) - def _ensure_csr_loaded(self): + def _ensure_csr_loaded(self) -> None: """Load the CSR into self.csr.""" if self.csr is not None: return @@ -138,7 +154,7 @@ class CertificateBackend(metaclass=abc.ABCMeta): content=self.csr_content, ) - def _ensure_existing_certificate_loaded(self): + def _ensure_existing_certificate_loaded(self) -> None: """Load the existing certificate into self.existing_certificate.""" if self.existing_certificate is not None: return @@ -149,14 +165,28 @@ class CertificateBackend(metaclass=abc.ABCMeta): content=self.existing_certificate_bytes, ) - def _check_privatekey(self): + def _check_privatekey(self) -> bool: """Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated.""" + if self.existing_certificate is None: + raise AssertionError( + "Contract violation: existing_certificate has not been populated" + ) + if self.privatekey is None: + raise AssertionError( + "Contract violation: privatekey has not been populated" + ) return cryptography_compare_public_keys( self.existing_certificate.public_key(), self.privatekey.public_key() ) - def _check_csr(self): + def _check_csr(self) -> bool: """Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated.""" + if self.existing_certificate is None: + raise AssertionError( + "Contract violation: existing_certificate has not been populated" + ) + if self.csr is None: + raise AssertionError("Contract violation: csr has not been populated") # Verify that CSR is signed by certificate's private key if not self.csr.is_signature_valid: return False @@ -214,8 +244,14 @@ class CertificateBackend(metaclass=abc.ABCMeta): return False return True - def _check_subject_key_identifier(self): - """Check whether Subject Key Identifier matches, assuming self.existing_certificate has been populated.""" + def _check_subject_key_identifier(self) -> bool: + """Check whether Subject Key Identifier matches, assuming self.existing_certificate and self.csr have been populated.""" + if self.existing_certificate is None: + raise AssertionError( + "Contract violation: existing_certificate has not been populated" + ) + if self.csr is None: + raise AssertionError("Contract violation: csr has not been populated") # Get hold of certificate's SKI try: ext = self.existing_certificate.extensions.get_extension_for_class( @@ -247,7 +283,11 @@ class CertificateBackend(metaclass=abc.ABCMeta): return False return True - def needs_regeneration(self, not_before=None, not_after=None): + def needs_regeneration( + self, + not_before: datetime.datetime | None = None, + not_after: datetime.datetime | None = None, + ) -> bool: """Check whether a regeneration is necessary.""" if self.force or self.existing_certificate_bytes is None: return True @@ -256,6 +296,7 @@ class CertificateBackend(metaclass=abc.ABCMeta): self._ensure_existing_certificate_loaded() except Exception: return True + assert self.existing_certificate is not None # Check whether private key matches self._ensure_private_key_loaded() @@ -285,9 +326,12 @@ class CertificateBackend(metaclass=abc.ABCMeta): return True return False - def dump(self, include_certificate): + def dump(self, include_certificate: bool) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - result = {"privatekey": self.privatekey_path, "csr": self.csr_path} + result: dict[str, t.Any] = { + "privatekey": self.privatekey_path, + "csr": self.csr_path, + } # Get hold of certificate bytes certificate_bytes = self.existing_certificate_bytes if self.cert is not None: @@ -299,35 +343,33 @@ class CertificateBackend(metaclass=abc.ABCMeta): certificate_bytes.decode("utf-8") if certificate_bytes else None ) - result["diff"] = dict( - before=self.diff_before, - after=self.diff_after, - ) + result["diff"] = { + "before": self.diff_before, + "after": self.diff_after, + } return result class CertificateProvider(metaclass=abc.ABCMeta): @abc.abstractmethod - def validate_module_args(self, module): + def validate_module_args(self, module: AnsibleModule) -> None: """Check module arguments""" @abc.abstractmethod - def needs_version_two_certs(self, module): + def needs_version_two_certs(self, module: AnsibleModule) -> bool: """Whether the provider needs to create a version 2 certificate.""" @abc.abstractmethod - def create_backend(self, module): + def create_backend(self, module: AnsibleModule) -> CertificateBackend: """Create an implementation for a backend. Return value must be instance of CertificateBackend. """ -def select_backend(module, provider): - """ - :type module: AnsibleModule - :type provider: CertificateProvider - """ +def select_backend( + module: AnsibleModule, provider: CertificateProvider +) -> CertificateBackend: provider.validate_module_args(module) assert_required_cryptography_version( @@ -343,7 +385,7 @@ def select_backend(module, provider): return provider.create_backend(module) -def get_certificate_argument_spec(): +def get_certificate_argument_spec() -> ArgumentSpec: return ArgumentSpec( argument_spec=dict( provider=dict( diff --git a/plugins/module_utils/crypto/module_backends/certificate_acme.py b/plugins/module_utils/crypto/module_backends/certificate_acme.py index 3ba150f4..dc49a644 100644 --- a/plugins/module_utils/crypto/module_backends/certificate_acme.py +++ b/plugins/module_utils/crypto/module_backends/certificate_acme.py @@ -8,6 +8,7 @@ from __future__ import annotations import os import tempfile import traceback +import typing as t from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import ( @@ -17,22 +18,30 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) -class AcmeCertificateBackend(CertificateBackend): - def __init__(self, module): - super(AcmeCertificateBackend, self).__init__(module) - self.accountkey_path = module.params["acme_accountkey_path"] - self.challenge_path = module.params["acme_challenge_path"] - self.use_chain = module.params["acme_chain"] - self.acme_directory = module.params["acme_directory"] +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule - if self.csr_content is None and self.csr_path is None: - raise CertificateError( - "csr_path or csr_content is required for ownca provider" - ) - if self.csr_content is None and not os.path.exists(self.csr_path): - raise CertificateError( - f"The certificate signing request file {self.csr_path} does not exist" - ) + from ...argspec import ArgumentSpec + + +class AcmeCertificateBackend(CertificateBackend): + def __init__(self, module: AnsibleModule) -> None: + super(AcmeCertificateBackend, self).__init__(module) + self.accountkey_path: str = module.params["acme_accountkey_path"] + self.challenge_path: str = module.params["acme_challenge_path"] + self.use_chain: bool = module.params["acme_chain"] + self.acme_directory: str = module.params["acme_directory"] + self.cert_bytes: bytes | None = None + + if self.csr_content is None: + if self.csr_path is None: + raise CertificateError( + "csr_path or csr_content is required for ownca provider" + ) + if not os.path.exists(self.csr_path): + raise CertificateError( + f"The certificate signing request file {self.csr_path} does not exist" + ) if not os.path.exists(self.accountkey_path): raise CertificateError( @@ -46,7 +55,7 @@ class AcmeCertificateBackend(CertificateBackend): self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True) - def generate_certificate(self): + def generate_certificate(self) -> None: """(Re-)Generate certificate.""" command = [self.acme_tiny_path] @@ -77,22 +86,26 @@ class AcmeCertificateBackend(CertificateBackend): command.extend(["--directory-url", self.acme_directory]) try: - self.cert = to_bytes(self.module.run_command(command, check_rc=True)[1]) + self.cert_bytes = to_bytes( + self.module.run_command(command, check_rc=True)[1] + ) except OSError as exc: raise CertificateError(exc) - def get_certificate_data(self): + def get_certificate_data(self) -> bytes: """Return bytes for self.cert.""" - return self.cert + if self.cert_bytes is None: + raise AssertionError("Contract violation: cert_bytes is None") + return self.cert_bytes - def dump(self, include_certificate): + def dump(self, include_certificate: bool) -> dict[str, t.Any]: result = super(AcmeCertificateBackend, self).dump(include_certificate) result["accountkey"] = self.accountkey_path return result class AcmeCertificateProvider(CertificateProvider): - def validate_module_args(self, module): + def validate_module_args(self, module: AnsibleModule) -> None: if module.params["acme_accountkey_path"] is None: module.fail_json( msg="The acme_accountkey_path option must be specified for the acme provider." @@ -102,14 +115,14 @@ class AcmeCertificateProvider(CertificateProvider): msg="The acme_challenge_path option must be specified for the acme provider." ) - def needs_version_two_certs(self, module): + def needs_version_two_certs(self, module: AnsibleModule) -> bool: return False - def create_backend(self, module): + def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend: return AcmeCertificateBackend(module) -def add_acme_provider_to_argument_spec(argument_spec): +def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None: argument_spec.argument_spec["provider"]["choices"].append("acme") argument_spec.argument_spec.update( dict( diff --git a/plugins/module_utils/crypto/module_backends/certificate_entrust.py b/plugins/module_utils/crypto/module_backends/certificate_entrust.py index 22779797..48d3ad5e 100644 --- a/plugins/module_utils/crypto/module_backends/certificate_entrust.py +++ b/plugins/module_utils/crypto/module_backends/certificate_entrust.py @@ -7,6 +7,7 @@ from __future__ import annotations import datetime import os +import typing as t from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( @@ -32,6 +33,12 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + from ...argspec import ArgumentSpec + + try: from cryptography.x509.oid import NameOID except ImportError: @@ -39,7 +46,7 @@ except ImportError: class EntrustCertificateBackend(CertificateBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(EntrustCertificateBackend, self).__init__(module) self.trackingId = None self.notAfter = get_relative_time_option( @@ -48,16 +55,19 @@ class EntrustCertificateBackend(CertificateBackend): with_timezone=CRYPTOGRAPHY_TIMEZONE, ) - if self.csr_content is None and self.csr_path is None: - raise CertificateError( - "csr_path or csr_content is required for entrust provider" - ) - if self.csr_content is None and not os.path.exists(self.csr_path): - raise CertificateError( - f"The certificate signing request file {self.csr_path} does not exist" - ) + if self.csr_content is None: + if self.csr_path is None: + raise CertificateError( + "csr_path or csr_content is required for entrust provider" + ) + if not os.path.exists(self.csr_path): + raise CertificateError( + f"The certificate signing request file {self.csr_path} does not exist" + ) self._ensure_csr_loaded() + if self.csr is None: + raise CertificateError("CSR not provided") # ECS API defaults to using the validated organization tied to the account. # We want to always force behavior of trying to use the organization provided in the CSR. @@ -93,9 +103,9 @@ class EntrustCertificateBackend(CertificateBackend): ], ) except SessionConfigurationException as e: - module.fail_json(msg=f"Failed to initialize Entrust Provider: {e.message}") + module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}") - def generate_certificate(self): + def generate_certificate(self) -> None: """(Re-)Generate certificate.""" body = {} @@ -104,6 +114,7 @@ class EntrustCertificateBackend(CertificateBackend): # csr_content contains bytes body["csr"] = to_native(self.csr_content) else: + assert self.csr_path is not None with open(self.csr_path, "r") as csr_file: body["csr"] = csr_file.read() @@ -138,11 +149,15 @@ class EntrustCertificateBackend(CertificateBackend): content=self.cert_bytes, ) - def get_certificate_data(self): + def get_certificate_data(self) -> bytes: """Return bytes for self.cert.""" return self.cert_bytes - def needs_regeneration(self): + def needs_regeneration( + self, + not_before: datetime.datetime | None = None, + not_after: datetime.datetime | None = None, + ) -> bool: parent_check = super(EntrustCertificateBackend, self).needs_regeneration() try: @@ -167,12 +182,12 @@ class EntrustCertificateBackend(CertificateBackend): return parent_check - def _get_cert_details(self): - cert_details = {} + def _get_cert_details(self) -> dict[str, t.Any]: + cert_details: dict[str, t.Any] = {} try: self._ensure_existing_certificate_loaded() except Exception: - return + return cert_details if self.existing_certificate: serial_number = f"{self.existing_certificate.serial_number:X}" expiry = get_not_valid_after(self.existing_certificate) @@ -203,17 +218,17 @@ class EntrustCertificateBackend(CertificateBackend): class EntrustCertificateProvider(CertificateProvider): - def validate_module_args(self, module): + def validate_module_args(self, module: AnsibleModule) -> None: pass - def needs_version_two_certs(self, module): + def needs_version_two_certs(self, module: AnsibleModule) -> t.Literal[False]: return False - def create_backend(self, module): + def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend: return EntrustCertificateBackend(module) -def add_entrust_provider_to_argument_spec(argument_spec): +def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None: argument_spec.argument_spec["provider"]["choices"].append("entrust") argument_spec.argument_spec.update( dict( @@ -248,7 +263,7 @@ def add_entrust_provider_to_argument_spec(argument_spec): ) ) argument_spec.required_if.append( - [ + ( "provider", "entrust", [ @@ -260,5 +275,5 @@ def add_entrust_provider_to_argument_spec(argument_spec): "entrust_api_client_cert_path", "entrust_api_client_cert_key_path", ], - ] + ) ) diff --git a/plugins/module_utils/crypto/module_backends/certificate_info.py b/plugins/module_utils/crypto/module_backends/certificate_info.py index d41ede93..9f745fd9 100644 --- a/plugins/module_utils/crypto/module_backends/certificate_info.py +++ b/plugins/module_utils/crypto/module_backends/certificate_info.py @@ -8,6 +8,7 @@ from __future__ import annotations import abc import binascii +import typing as t from ansible.module_utils.common.text.converters import to_native from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( @@ -34,6 +35,19 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) +if t.TYPE_CHECKING: + import datetime + + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes + + from ....plugin_utils.action_module import AnsibleActionModule + from ....plugin_utils.filter_module import FilterModuleMock + from ...argspec import ArgumentSpec + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock] + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -48,93 +62,97 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ" class CertificateInfoRetrieval(metaclass=abc.ABCMeta): - def __init__(self, module, content): + def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None: # content must be a bytes string self.module = module self.content = content @abc.abstractmethod - def _get_der_bytes(self): + def _get_der_bytes(self) -> bytes: pass @abc.abstractmethod - def _get_signature_algorithm(self): + def _get_signature_algorithm(self) -> str: pass @abc.abstractmethod - def _get_subject_ordered(self): + def _get_subject_ordered(self) -> list[list[str]]: pass @abc.abstractmethod - def _get_issuer_ordered(self): + def _get_issuer_ordered(self) -> list[list[str]]: pass @abc.abstractmethod - def _get_version(self): + def _get_version(self) -> int | str: pass @abc.abstractmethod - def _get_key_usage(self): + def _get_key_usage(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_extended_key_usage(self): + def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_basic_constraints(self): + def _get_basic_constraints(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_ocsp_must_staple(self): + def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]: pass @abc.abstractmethod - def _get_subject_alt_name(self): + def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def get_not_before(self): + def get_not_before(self) -> datetime.datetime: pass @abc.abstractmethod - def get_not_after(self): + def get_not_after(self) -> datetime.datetime: pass @abc.abstractmethod - def _get_public_key_pem(self): + def _get_public_key_pem(self) -> bytes: pass @abc.abstractmethod - def _get_public_key_object(self): + def _get_public_key_object(self) -> PublicKeyTypes: pass @abc.abstractmethod - def _get_subject_key_identifier(self): + def _get_subject_key_identifier(self) -> bytes | None: pass @abc.abstractmethod - def _get_authority_key_identifier(self): + def _get_authority_key_identifier( + self, + ) -> tuple[bytes | None, list[str] | None, int | None]: pass @abc.abstractmethod - def _get_serial_number(self): + def _get_serial_number(self) -> int: pass @abc.abstractmethod - def _get_all_extensions(self): + def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]: pass @abc.abstractmethod - def _get_ocsp_uri(self): + def _get_ocsp_uri(self) -> str | None: pass @abc.abstractmethod - def _get_issuer_uri(self): + def _get_issuer_uri(self) -> str | None: pass - def get_info(self, prefer_one_fingerprint=False, der_support_enabled=False): - result = dict() + def get_info( + self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False + ) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} self.cert = load_certificate( None, content=self.content, @@ -194,16 +212,20 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta): self._get_der_bytes(), prefer_one=prefer_one_fingerprint ) - ski = self._get_subject_key_identifier() - if ski is not None: - ski = binascii.hexlify(ski).decode("ascii") + ski_bytes = self._get_subject_key_identifier() + if ski_bytes is not None: + ski = binascii.hexlify(ski_bytes).decode("ascii") ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)]) + else: + ski = None result["subject_key_identifier"] = ski - aki, aci, acsn = self._get_authority_key_identifier() - if aki is not None: - aki = binascii.hexlify(aki).decode("ascii") + aki_bytes, aci, acsn = self._get_authority_key_identifier() + if aki_bytes is not None: + aki = binascii.hexlify(aki_bytes).decode("ascii") aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)]) + else: + aki = None result["authority_key_identifier"] = aki result["authority_cert_issuer"] = aci result["authority_cert_serial_number"] = acsn @@ -219,36 +241,40 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta): class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): """Validate the supplied cert, using the cryptography backend""" - def __init__(self, module, content): + def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None: super(CertificateInfoRetrievalCryptography, self).__init__(module, content) self.name_encoding = module.params.get("name_encoding", "ignore") - def _get_der_bytes(self): + def _get_der_bytes(self) -> bytes: return self.cert.public_bytes(serialization.Encoding.DER) - def _get_signature_algorithm(self): + def _get_signature_algorithm(self) -> str: return cryptography_oid_to_name(self.cert.signature_algorithm_oid) - def _get_subject_ordered(self): - result = [] + def _get_subject_ordered(self) -> list[list[str]]: + result: list[list[str]] = [] for attribute in self.cert.subject: - result.append([cryptography_oid_to_name(attribute.oid), attribute.value]) + result.append( + [cryptography_oid_to_name(attribute.oid), to_native(attribute.value)] + ) return result - def _get_issuer_ordered(self): + def _get_issuer_ordered(self) -> list[list[str]]: result = [] for attribute in self.cert.issuer: - result.append([cryptography_oid_to_name(attribute.oid), attribute.value]) + result.append( + [cryptography_oid_to_name(attribute.oid), to_native(attribute.value)] + ) return result - def _get_version(self): + def _get_version(self) -> int | str: if self.cert.version == x509.Version.v1: return 1 if self.cert.version == x509.Version.v3: return 3 return "unknown" - def _get_key_usage(self): + def _get_key_usage(self) -> tuple[list[str] | None, bool]: try: current_key_ext = self.cert.extensions.get_extension_for_class( x509.KeyUsage @@ -297,7 +323,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_extended_key_usage(self): + def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]: try: ext_keyusage_ext = self.cert.extensions.get_extension_for_class( x509.ExtendedKeyUsage @@ -311,7 +337,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_basic_constraints(self): + def _get_basic_constraints(self) -> tuple[list[str] | None, bool]: try: ext_keyusage_ext = self.cert.extensions.get_extension_for_class( x509.BasicConstraints @@ -324,7 +350,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_ocsp_must_staple(self): + def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]: try: tlsfeature_ext = self.cert.extensions.get_extension_for_class( x509.TLSFeature @@ -336,7 +362,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_subject_alt_name(self): + def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]: try: san_ext = self.cert.extensions.get_extension_for_class( x509.SubjectAlternativeName @@ -349,22 +375,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def get_not_before(self): + def get_not_before(self) -> datetime.datetime: return get_not_valid_before(self.cert) - def get_not_after(self): + def get_not_after(self) -> datetime.datetime: return get_not_valid_after(self.cert) - def _get_public_key_pem(self): + def _get_public_key_pem(self) -> bytes: return self.cert.public_key().public_bytes( serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo, ) - def _get_public_key_object(self): + def _get_public_key_object(self) -> PublicKeyTypes: return self.cert.public_key() - def _get_subject_key_identifier(self): + def _get_subject_key_identifier(self) -> bytes | None: try: ext = self.cert.extensions.get_extension_for_class( x509.SubjectKeyIdentifier @@ -373,7 +399,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None - def _get_authority_key_identifier(self): + def _get_authority_key_identifier( + self, + ) -> tuple[bytes | None, list[str] | None, int | None]: try: ext = self.cert.extensions.get_extension_for_class( x509.AuthorityKeyIdentifier @@ -392,13 +420,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, None, None - def _get_serial_number(self): + def _get_serial_number(self) -> int: return self.cert.serial_number - def _get_all_extensions(self): + def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]: return cryptography_get_extensions_from_cert(self.cert) - def _get_ocsp_uri(self): + def _get_ocsp_uri(self) -> str | None: try: ext = self.cert.extensions.get_extension_for_class( x509.AuthorityInformationAccess @@ -411,7 +439,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): pass return None - def _get_issuer_uri(self): + def _get_issuer_uri(self) -> str | None: try: ext = self.cert.extensions.get_extension_for_class( x509.AuthorityInformationAccess @@ -428,12 +456,16 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): return None -def get_certificate_info(module, content, prefer_one_fingerprint=False): +def get_certificate_info( + module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False +) -> dict[str, t.Any]: info = CertificateInfoRetrievalCryptography(module, content) return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint) -def select_backend(module, content): +def select_backend( + module: GeneralAnsibleModule, content: bytes +) -> CertificateInfoRetrieval: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) diff --git a/plugins/module_utils/crypto/module_backends/certificate_ownca.py b/plugins/module_utils/crypto/module_backends/certificate_ownca.py index b72fed70..a0dde491 100644 --- a/plugins/module_utils/crypto/module_backends/certificate_ownca.py +++ b/plugins/module_utils/crypto/module_backends/certificate_ownca.py @@ -6,6 +6,7 @@ from __future__ import annotations import os +import typing as t from random import randrange from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -18,6 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp cryptography_verify_certificate_signature, get_not_valid_after, get_not_valid_before, + is_potential_certificate_issuer_public_key, set_not_valid_after, set_not_valid_before, ) @@ -28,7 +30,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( load_certificate, - load_privatekey, + load_certificate_issuer_privatekey, select_message_digest, ) from ansible_collections.community.crypto.plugins.module_utils.time import ( @@ -36,6 +38,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) +if t.TYPE_CHECKING: + import datetime + + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + ) + + from ...argspec import ArgumentSpec + + try: import cryptography from cryptography import x509 @@ -45,13 +58,13 @@ except ImportError: class OwnCACertificateBackendCryptography(CertificateBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(OwnCACertificateBackendCryptography, self).__init__(module) - self.create_subject_key_identifier = module.params[ - "ownca_create_subject_key_identifier" - ] - self.create_authority_key_identifier = module.params[ + self.create_subject_key_identifier: t.Literal[ + "create_if_not_provided", "always_create", "never_create" + ] = module.params["ownca_create_subject_key_identifier"] + self.create_authority_key_identifier: bool = module.params[ "ownca_create_authority_key_identifier" ] self.notBefore = get_relative_time_option( @@ -65,31 +78,40 @@ class OwnCACertificateBackendCryptography(CertificateBackend): with_timezone=CRYPTOGRAPHY_TIMEZONE, ) self.digest = select_message_digest(module.params["ownca_digest"]) - self.version = module.params["ownca_version"] + self.version: int = module.params["ownca_version"] self.serial_number = x509.random_serial_number() - self.ca_cert_path = module.params["ownca_path"] - self.ca_cert_content = module.params["ownca_content"] - if self.ca_cert_content is not None: - self.ca_cert_content = self.ca_cert_content.encode("utf-8") - self.ca_privatekey_path = module.params["ownca_privatekey_path"] - self.ca_privatekey_content = module.params["ownca_privatekey_content"] - if self.ca_privatekey_content is not None: - self.ca_privatekey_content = self.ca_privatekey_content.encode("utf-8") - self.ca_privatekey_passphrase = module.params["ownca_privatekey_passphrase"] + self.ca_cert_path: str | None = module.params["ownca_path"] + ca_cert_content: str | None = module.params["ownca_content"] + if ca_cert_content is not None: + self.ca_cert_content: bytes | None = ca_cert_content.encode("utf-8") + else: + self.ca_cert_content = None + self.ca_privatekey_path: str | None = module.params["ownca_privatekey_path"] + ca_privatekey_content: str | None = module.params["ownca_privatekey_content"] + if ca_privatekey_content is not None: + self.ca_privatekey_content: bytes | None = ca_privatekey_content.encode( + "utf-8" + ) + else: + self.ca_privatekey_content = None + self.ca_privatekey_passphrase: str | None = module.params[ + "ownca_privatekey_passphrase" + ] - if self.csr_content is None and self.csr_path is None: - raise CertificateError( - "csr_path or csr_content is required for ownca provider" - ) - if self.csr_content is None and not os.path.exists(self.csr_path): - raise CertificateError( - f"The certificate signing request file {self.csr_path} does not exist" - ) - if self.ca_cert_content is None and not os.path.exists(self.ca_cert_path): + if self.csr_content is None: + if self.csr_path is None: + raise CertificateError( + "csr_path or csr_content is required for ownca provider" + ) + if not os.path.exists(self.csr_path): + raise CertificateError( + f"The certificate signing request file {self.csr_path} does not exist" + ) + if self.ca_cert_path is not None and not os.path.exists(self.ca_cert_path): raise CertificateError( f"The CA certificate file {self.ca_cert_path} does not exist" ) - if self.ca_privatekey_content is None and not os.path.exists( + if self.ca_privatekey_path is not None and not os.path.exists( self.ca_privatekey_path ): raise CertificateError( @@ -101,8 +123,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend): path=self.ca_cert_path, content=self.ca_cert_content, ) + if not is_potential_certificate_issuer_public_key(self.ca_cert.public_key()): + raise CertificateError( + "CA certificate's public key cannot be used to sign certificates" + ) try: - self.ca_private_key = load_privatekey( + self.ca_private_key = load_certificate_issuer_privatekey( path=self.ca_privatekey_path, content=self.ca_privatekey_content, passphrase=self.ca_privatekey_passphrase, @@ -125,8 +151,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend): else: self.digest = None - def generate_certificate(self): + def generate_certificate(self) -> None: """(Re-)Generate certificate.""" + if self.csr is None: + raise AssertionError("Contract violation: csr has not been populated") cert_builder = x509.CertificateBuilder() cert_builder = cert_builder.subject_name(self.csr.subject) cert_builder = cert_builder.issuer_name(self.ca_cert.subject) @@ -166,10 +194,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend): critical=False, ) except cryptography.x509.ExtensionNotFound: + public_key = self.ca_cert.public_key() + assert is_potential_certificate_issuer_public_key(public_key) cert_builder = cert_builder.add_extension( - x509.AuthorityKeyIdentifier.from_issuer_public_key( - self.ca_cert.public_key() - ), + x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key), critical=False, ) @@ -180,17 +208,24 @@ class OwnCACertificateBackendCryptography(CertificateBackend): self.cert = certificate - def get_certificate_data(self): + def get_certificate_data(self) -> bytes: """Return bytes for self.cert.""" + if self.cert is None: + raise AssertionError("Contract violation: cert has not been populated") return self.cert.public_bytes(Encoding.PEM) - def needs_regeneration(self): + def needs_regeneration( + self, + not_before: datetime.datetime | None = None, + not_after: datetime.datetime | None = None, + ) -> bool: if super(OwnCACertificateBackendCryptography, self).needs_regeneration( not_before=self.notBefore, not_after=self.notAfter ): return True self._ensure_existing_certificate_loaded() + assert self.existing_certificate is not None # Check whether certificate is signed by CA certificate if not cryptography_verify_certificate_signature( @@ -205,31 +240,33 @@ class OwnCACertificateBackendCryptography(CertificateBackend): # Check AuthorityKeyIdentifier if self.create_authority_key_identifier: try: - ext = self.ca_cert.extensions.get_extension_for_class( + ext_ski = self.ca_cert.extensions.get_extension_for_class( x509.SubjectKeyIdentifier ) expected_ext = ( x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier( - ext.value + ext_ski.value ) ) except cryptography.x509.ExtensionNotFound: + public_key = self.ca_cert.public_key() + assert is_potential_certificate_issuer_public_key(public_key) expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key( - self.ca_cert.public_key() + public_key ) try: - ext = self.existing_certificate.extensions.get_extension_for_class( + ext_aki = self.existing_certificate.extensions.get_extension_for_class( x509.AuthorityKeyIdentifier ) - if ext.value != expected_ext: + if ext_aki.value != expected_ext: return True except cryptography.x509.ExtensionNotFound: return True return False - def dump(self, include_certificate): + def dump(self, include_certificate: bool) -> dict[str, t.Any]: result = super(OwnCACertificateBackendCryptography, self).dump( include_certificate ) @@ -251,6 +288,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend): else: if self.cert is None: self.cert = self.existing_certificate + assert self.cert is not None result.update( { "notBefore": get_not_valid_before(self.cert).strftime( @@ -266,7 +304,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend): return result -def generate_serial_number(): +def generate_serial_number() -> int: """Generate a serial number for a certificate""" while True: result = randrange(0, 1 << 160) @@ -275,7 +313,7 @@ def generate_serial_number(): class OwnCACertificateProvider(CertificateProvider): - def validate_module_args(self, module): + def validate_module_args(self, module: AnsibleModule) -> None: if ( module.params["ownca_path"] is None and module.params["ownca_content"] is None @@ -291,14 +329,16 @@ class OwnCACertificateProvider(CertificateProvider): msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider." ) - def needs_version_two_certs(self, module): + def needs_version_two_certs(self, module: AnsibleModule) -> bool: return module.params["ownca_version"] == 2 - def create_backend(self, module): + def create_backend( + self, module: AnsibleModule + ) -> OwnCACertificateBackendCryptography: return OwnCACertificateBackendCryptography(module) -def add_ownca_provider_to_argument_spec(argument_spec): +def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None: argument_spec.argument_spec["provider"]["choices"].append("ownca") argument_spec.argument_spec.update( dict( diff --git a/plugins/module_utils/crypto/module_backends/certificate_selfsigned.py b/plugins/module_utils/crypto/module_backends/certificate_selfsigned.py index 7714b58f..65e1bf6c 100644 --- a/plugins/module_utils/crypto/module_backends/certificate_selfsigned.py +++ b/plugins/module_utils/crypto/module_backends/certificate_selfsigned.py @@ -6,6 +6,7 @@ from __future__ import annotations import os +import typing as t from random import randrange from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( @@ -14,6 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp cryptography_verify_certificate_signature, get_not_valid_after, get_not_valid_before, + is_potential_certificate_issuer_private_key, set_not_valid_after, set_not_valid_before, ) @@ -30,6 +32,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) +if t.TYPE_CHECKING: + import datetime + + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + ) + + from ...argspec import ArgumentSpec + + try: import cryptography from cryptography import x509 @@ -39,12 +52,14 @@ except ImportError: class SelfSignedCertificateBackendCryptography(CertificateBackend): - def __init__(self, module): + privatekey: CertificateIssuerPrivateKeyTypes + + def __init__(self, module: AnsibleModule) -> None: super(SelfSignedCertificateBackendCryptography, self).__init__(module) - self.create_subject_key_identifier = module.params[ - "selfsigned_create_subject_key_identifier" - ] + self.create_subject_key_identifier: t.Literal[ + "create_if_not_provided", "always_create", "never_create" + ] = module.params["selfsigned_create_subject_key_identifier"] self.notBefore = get_relative_time_option( module.params["selfsigned_not_before"], "selfsigned_not_before", @@ -56,14 +71,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): with_timezone=CRYPTOGRAPHY_TIMEZONE, ) self.digest = select_message_digest(module.params["selfsigned_digest"]) - self.version = module.params["selfsigned_version"] + self.version: int = module.params["selfsigned_version"] self.serial_number = x509.random_serial_number() if self.csr_path is not None and not os.path.exists(self.csr_path): raise CertificateError( f"The certificate signing request file {self.csr_path} does not exist" ) - if self.privatekey_content is None and not os.path.exists(self.privatekey_path): + if self.privatekey_path is not None and not os.path.exists( + self.privatekey_path + ): raise CertificateError( f"The private key file {self.privatekey_path} does not exist" ) @@ -71,20 +88,10 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): self._module = module self._ensure_private_key_loaded() - - self._ensure_csr_loaded() - if self.csr is None: - # Create empty CSR on the fly - csr = cryptography.x509.CertificateSigningRequestBuilder() - csr = csr.subject_name(cryptography.x509.Name([])) - digest = None - if cryptography_key_needs_digest_for_signing(self.privatekey): - digest = self.digest - if digest is None: - self.module.fail_json( - msg=f'Unsupported digest "{module.params["selfsigned_digest"]}"' - ) - self.csr = csr.sign(self.privatekey, digest) + if self.privatekey is None: + raise CertificateError("Private key has not been provided") + if not is_potential_certificate_issuer_private_key(self.privatekey): + raise CertificateError("Private key cannot be used to sign certificates") if cryptography_key_needs_digest_for_signing(self.privatekey): if self.digest is None: @@ -94,8 +101,21 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): else: self.digest = None - def generate_certificate(self): + self._ensure_csr_loaded() + if self.csr is None: + # Create empty CSR on the fly + csr = cryptography.x509.CertificateSigningRequestBuilder() + csr = csr.subject_name(cryptography.x509.Name([])) + self.csr = csr.sign(self.privatekey, self.digest) + + def generate_certificate(self) -> None: """(Re-)Generate certificate.""" + if self.csr is None: + raise AssertionError("Contract violation: csr has not been populated") + if self.privatekey is None: + raise AssertionError( + "Contract violation: privatekey has not been populated" + ) try: cert_builder = x509.CertificateBuilder() cert_builder = cert_builder.subject_name(self.csr.subject) @@ -130,17 +150,26 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): self.cert = certificate - def get_certificate_data(self): + def get_certificate_data(self) -> bytes: """Return bytes for self.cert.""" + if self.cert is None: + raise AssertionError("Contract violation: cert has not been populated") return self.cert.public_bytes(Encoding.PEM) - def needs_regeneration(self): + def needs_regeneration( + self, + not_before: datetime.datetime | None = None, + not_after: datetime.datetime | None = None, + ) -> bool: + assert self.privatekey is not None + if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration( not_before=self.notBefore, not_after=self.notAfter ): return True self._ensure_existing_certificate_loaded() + assert self.existing_certificate is not None # Check whether certificate is signed by private key if not cryptography_verify_certificate_signature( @@ -150,7 +179,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): return False - def dump(self, include_certificate): + def dump(self, include_certificate: bool) -> dict[str, t.Any]: result = super(SelfSignedCertificateBackendCryptography, self).dump( include_certificate ) @@ -166,6 +195,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): else: if self.cert is None: self.cert = self.existing_certificate + assert self.cert is not None result.update( { "notBefore": get_not_valid_before(self.cert).strftime( @@ -181,7 +211,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend): return result -def generate_serial_number(): +def generate_serial_number() -> int: """Generate a serial number for a certificate""" while True: result = randrange(0, 1 << 160) @@ -190,7 +220,7 @@ def generate_serial_number(): class SelfSignedCertificateProvider(CertificateProvider): - def validate_module_args(self, module): + def validate_module_args(self, module: AnsibleModule) -> None: if ( module.params["privatekey_path"] is None and module.params["privatekey_content"] is None @@ -199,14 +229,16 @@ class SelfSignedCertificateProvider(CertificateProvider): msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider." ) - def needs_version_two_certs(self, module): + def needs_version_two_certs(self, module: AnsibleModule) -> bool: return module.params["selfsigned_version"] == 2 - def create_backend(self, module): + def create_backend( + self, module: AnsibleModule + ) -> SelfSignedCertificateBackendCryptography: return SelfSignedCertificateBackendCryptography(module) -def add_selfsigned_provider_to_argument_spec(argument_spec): +def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None: argument_spec.argument_spec["provider"]["choices"].append("selfsigned") argument_spec.argument_spec.update( dict( diff --git a/plugins/module_utils/crypto/module_backends/crl_info.py b/plugins/module_utils/crypto/module_backends/crl_info.py index a52a1ba5..07227cd2 100644 --- a/plugins/module_utils/crypto/module_backends/crl_info.py +++ b/plugins/module_utils/crypto/module_backends/crl_info.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import ( TIMESTAMP_FORMAT, cryptography_decode_revoked_certificate, @@ -22,6 +24,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, + ) + + from ....plugin_utils.action_module import AnsibleActionModule + from ....plugin_utils.filter_module import FilterModuleMock + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock] + + # crypto_utils MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION @@ -33,14 +47,19 @@ except ImportError: class CRLInfoRetrieval: - def __init__(self, module, content, list_revoked_certificates=True): + def __init__( + self, + module: GeneralAnsibleModule, + content: bytes, + list_revoked_certificates: bool = True, + ) -> None: # content must be a bytes string self.module = module self.content = content self.list_revoked_certificates = list_revoked_certificates self.name_encoding = module.params.get("name_encoding", "ignore") - def get_info(self): + def get_info(self) -> dict[str, t.Any]: self.crl_pem = identify_pem_format(self.content) try: if self.crl_pem: @@ -50,7 +69,7 @@ class CRLInfoRetrieval: except ValueError as e: self.module.fail_json(msg=f"Error while decoding CRL: {e}") - result = { + result: dict[str, t.Any] = { "changed": False, "format": "pem" if self.crl_pem else "der", "last_update": None, @@ -61,7 +80,11 @@ class CRLInfoRetrieval: } result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT) - result["next_update"] = self.crl.next_update.strftime(TIMESTAMP_FORMAT) + result["next_update"] = ( + self.crl.next_update.strftime(TIMESTAMP_FORMAT) + if self.crl.next_update + else None + ) result["digest"] = cryptography_oid_to_name( cryptography_get_signature_algorithm_oid_from_crl(self.crl) ) @@ -83,7 +106,9 @@ class CRLInfoRetrieval: return result -def get_crl_info(module, content, list_revoked_certificates=True): +def get_crl_info( + module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True +) -> dict[str, t.Any]: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) diff --git a/plugins/module_utils/crypto/module_backends/csr.py b/plugins/module_utils/crypto/module_backends/csr.py index f25d6b8e..90a40653 100644 --- a/plugins/module_utils/crypto/module_backends/csr.py +++ b/plugins/module_utils/crypto/module_backends/csr.py @@ -7,6 +7,7 @@ from __future__ import annotations import abc import binascii +import typing as t from ansible.module_utils.common.text.converters import to_text from ansible_collections.community.crypto.plugins.module_utils.argspec import ( @@ -26,13 +27,14 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp cryptography_name_to_oid, cryptography_parse_key_usage_params, cryptography_parse_relative_distinguished_name, + is_potential_certificate_issuer_public_key, ) from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import ( get_csr_info, ) from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( + load_certificate_issuer_privatekey, load_certificate_request, - load_privatekey, parse_name_field, parse_ordered_name_field, select_message_digest, @@ -43,6 +45,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + PrivateKeyTypes, + ) + + from ..cryptography_support import CertificatePrivateKeyTypes + + _ET = t.TypeVar("_ET", bound="cryptography.x509.ExtensionType") + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -69,49 +83,58 @@ class CertificateSigningRequestError(OpenSSLObjectError): class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module - self.digest = module.params["digest"] - self.privatekey_path = module.params["privatekey_path"] - self.privatekey_content = module.params["privatekey_content"] - if self.privatekey_content is not None: - self.privatekey_content = self.privatekey_content.encode("utf-8") - self.privatekey_passphrase = module.params["privatekey_passphrase"] - self.version = module.params["version"] - self.subjectAltName = module.params["subject_alt_name"] - self.subjectAltName_critical = module.params["subject_alt_name_critical"] - self.keyUsage = module.params["key_usage"] - self.keyUsage_critical = module.params["key_usage_critical"] - self.extendedKeyUsage = module.params["extended_key_usage"] - self.extendedKeyUsage_critical = module.params["extended_key_usage_critical"] - self.basicConstraints = module.params["basic_constraints"] - self.basicConstraints_critical = module.params["basic_constraints_critical"] - self.ocspMustStaple = module.params["ocsp_must_staple"] - self.ocspMustStaple_critical = module.params["ocsp_must_staple_critical"] - self.name_constraints_permitted = ( + self.digest: str = module.params["digest"] + self.privatekey_path: str | None = module.params["privatekey_path"] + privatekey_content: str | None = module.params["privatekey_content"] + if privatekey_content is not None: + self.privatekey_content: bytes | None = privatekey_content.encode("utf-8") + else: + self.privatekey_content = None + self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"] + self.version: t.Literal[1] = module.params["version"] + self.subjectAltName: list[str] | None = module.params["subject_alt_name"] + self.subjectAltName_critical: bool = module.params["subject_alt_name_critical"] + self.keyUsage: list[str] | None = module.params["key_usage"] + self.keyUsage_critical: bool = module.params["key_usage_critical"] + self.extendedKeyUsage: list[str] | None = module.params["extended_key_usage"] + self.extendedKeyUsage_critical: bool = module.params[ + "extended_key_usage_critical" + ] + self.basicConstraints: list[str] | None = module.params["basic_constraints"] + self.basicConstraints_critical: bool = module.params[ + "basic_constraints_critical" + ] + self.ocspMustStaple: bool = module.params["ocsp_must_staple"] + self.ocspMustStaple_critical: bool = module.params["ocsp_must_staple_critical"] + self.name_constraints_permitted: list[str] = ( module.params["name_constraints_permitted"] or [] ) - self.name_constraints_excluded = ( + self.name_constraints_excluded: list[str] = ( module.params["name_constraints_excluded"] or [] ) - self.name_constraints_critical = module.params["name_constraints_critical"] - self.create_subject_key_identifier = module.params[ + self.name_constraints_critical: bool = module.params[ + "name_constraints_critical" + ] + self.create_subject_key_identifier: bool = module.params[ "create_subject_key_identifier" ] - self.subject_key_identifier = module.params["subject_key_identifier"] - self.authority_key_identifier = module.params["authority_key_identifier"] - self.authority_cert_issuer = module.params["authority_cert_issuer"] - self.authority_cert_serial_number = module.params[ + subject_key_identifier: str | None = module.params["subject_key_identifier"] + authority_key_identifier: str | None = module.params["authority_key_identifier"] + self.authority_cert_issuer: list[str] | None = module.params[ + "authority_cert_issuer" + ] + self.authority_cert_serial_number: int = module.params[ "authority_cert_serial_number" ] - self.crl_distribution_points = module.params["crl_distribution_points"] - self.csr = None - self.privatekey = None + self.crl_distribution_points: ( + list[cryptography.x509.DistributionPoint] | None + ) = None + self.csr: cryptography.x509.CertificateSigningRequest | None = None + self.privatekey: CertificateIssuerPrivateKeyTypes | None = None - if ( - self.create_subject_key_identifier - and self.subject_key_identifier is not None - ): + if self.create_subject_key_identifier and subject_key_identifier is not None: module.fail_json( msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true" ) @@ -153,35 +176,37 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): self.using_common_name_for_san = True break - if self.subject_key_identifier is not None: + self.subject_key_identifier: bytes | None = None + if subject_key_identifier is not None: try: self.subject_key_identifier = binascii.unhexlify( - self.subject_key_identifier.replace(":", "") + subject_key_identifier.replace(":", "") ) except Exception as e: raise CertificateSigningRequestError( f"Cannot parse subject_key_identifier: {e}" ) - if self.authority_key_identifier is not None: + self.authority_key_identifier: bytes | None = None + if authority_key_identifier is not None: try: self.authority_key_identifier = binascii.unhexlify( - self.authority_key_identifier.replace(":", "") + authority_key_identifier.replace(":", "") ) except Exception as e: raise CertificateSigningRequestError( f"Cannot parse authority_key_identifier: {e}" ) - self.existing_csr = None - self.existing_csr_bytes = None + self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None + self.existing_csr_bytes: bytes | None = None self.diff_before = self._get_info(None) self.diff_after = self._get_info(None) - def _get_info(self, data): + def _get_info(self, data: bytes | None) -> dict[str, t.Any]: if data is None: - return dict() + return {} try: result = get_csr_info( self.module, @@ -195,30 +220,28 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): return dict(can_parse_csr=False) @abc.abstractmethod - def generate_csr(self): + def generate_csr(self) -> None: """(Re-)Generate CSR.""" - pass @abc.abstractmethod - def get_csr_data(self): + def get_csr_data(self) -> bytes: """Return bytes for self.csr.""" - pass - def set_existing(self, csr_bytes): + def set_existing(self, csr_bytes: bytes | None) -> None: """Set existing CSR bytes. None indicates that the CSR does not exist.""" self.existing_csr_bytes = csr_bytes self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes) - def has_existing(self): + def has_existing(self) -> bool: """Query whether an existing CSR is/has been there.""" return self.existing_csr_bytes is not None - def _ensure_private_key_loaded(self): + def _ensure_private_key_loaded(self) -> None: """Load the provided private key into self.privatekey.""" if self.privatekey is not None: return try: - self.privatekey = load_privatekey( + self.privatekey = load_certificate_issuer_privatekey( path=self.privatekey_path, content=self.privatekey_content, passphrase=self.privatekey_passphrase, @@ -227,11 +250,10 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): raise CertificateSigningRequestError(exc) @abc.abstractmethod - def _check_csr(self): + def _check_csr(self) -> bool: """Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated.""" - pass - def needs_regeneration(self): + def needs_regeneration(self) -> bool: """Check whether a regeneration is necessary.""" if self.existing_csr_bytes is None: return True @@ -245,9 +267,9 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): self._ensure_private_key_loaded() return not self._check_csr() - def dump(self, include_csr): + def dump(self, include_csr: bool) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - result = { + result: dict[str, t.Any] = { "privatekey": self.privatekey_path, "subject": self.subject, "subjectAltName": self.subjectAltName, @@ -274,44 +296,49 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): return result -def parse_crl_distribution_points(module, crl_distribution_points): +def parse_crl_distribution_points( + module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]] +) -> list[cryptography.x509.DistributionPoint]: result = [] for index, parse_crl_distribution_point in enumerate(crl_distribution_points): try: - params = dict( - full_name=None, - relative_name=None, - crl_issuer=None, - reasons=None, - ) + full_name = None + relative_name = None + crl_issuer = None + reasons = None if parse_crl_distribution_point["full_name"] is not None: if not parse_crl_distribution_point["full_name"]: raise OpenSSLObjectError("full_name must not be empty") - params["full_name"] = [ + full_name = [ cryptography_get_name(name, "full name") for name in parse_crl_distribution_point["full_name"] ] if parse_crl_distribution_point["relative_name"] is not None: if not parse_crl_distribution_point["relative_name"]: raise OpenSSLObjectError("relative_name must not be empty") - params["relative_name"] = ( - cryptography_parse_relative_distinguished_name( - parse_crl_distribution_point["relative_name"] - ) + relative_name = cryptography_parse_relative_distinguished_name( + parse_crl_distribution_point["relative_name"] ) if parse_crl_distribution_point["crl_issuer"] is not None: if not parse_crl_distribution_point["crl_issuer"]: raise OpenSSLObjectError("crl_issuer must not be empty") - params["crl_issuer"] = [ + crl_issuer = [ cryptography_get_name(name, "CRL issuer") for name in parse_crl_distribution_point["crl_issuer"] ] if parse_crl_distribution_point["reasons"] is not None: - reasons = [] + reasons_list = [] for reason in parse_crl_distribution_point["reasons"]: - reasons.append(REVOCATION_REASON_MAP[reason]) - params["reasons"] = frozenset(reasons) - result.append(cryptography.x509.DistributionPoint(**params)) + reasons_list.append(REVOCATION_REASON_MAP[reason]) + reasons = frozenset(reasons_list) + result.append( + cryptography.x509.DistributionPoint( + full_name=full_name, + relative_name=relative_name, + crl_issuer=crl_issuer, + reasons=reasons, + ) + ) except (OpenSSLObjectError, ValueError) as e: raise OpenSSLObjectError( f"Error while parsing CRL distribution point #{index}: {e}" @@ -321,21 +348,25 @@ def parse_crl_distribution_points(module, crl_distribution_points): # Implementation with using cryptography class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(CertificateSigningRequestCryptographyBackend, self).__init__(module) if self.version != 1: module.warn( "The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)" ) - if self.crl_distribution_points: + crl_distribution_points: list[dict[str, t.Any]] | None = module.params[ + "crl_distribution_points" + ] + if crl_distribution_points: self.crl_distribution_points = parse_crl_distribution_points( - module, self.crl_distribution_points + module, crl_distribution_points ) - def generate_csr(self): + def generate_csr(self) -> None: """(Re-)Generate CSR.""" self._ensure_private_key_loaded() + assert self.privatekey is not None csr = cryptography.x509.CertificateSigningRequestBuilder() try: @@ -412,6 +443,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack raise OpenSSLObjectError(f"Error while parsing name constraint: {e}") if self.create_subject_key_identifier: + if not is_potential_certificate_issuer_public_key( + self.privatekey.public_key() + ): + raise OpenSSLObjectError( + "Private key can not be used to create subject key identifier" + ) csr = csr.add_extension( cryptography.x509.SubjectKeyIdentifier.from_public_key( self.privatekey.public_key() @@ -450,7 +487,10 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack critical=False, ) - digest = None + # csr.sign() does not accept some digests we theoretically could have in digest. + # For that reason we use type t.Any here. csr.sign() will complain if + # the digest is not acceptable. + digest: t.Any | None = None if cryptography_key_needs_digest_for_signing(self.privatekey): digest = select_message_digest(self.digest) if digest is None: @@ -482,16 +522,22 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack + "This is probably caused by an invalid Subject Alternative DNS Name." ) - def get_csr_data(self): + def get_csr_data(self) -> bytes: """Return bytes for self.csr.""" + if self.csr is None: + raise AssertionError("Violated contract: csr is not populated") return self.csr.public_bytes( cryptography.hazmat.primitives.serialization.Encoding.PEM ) - def _check_csr(self): + def _check_csr(self) -> bool: """Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated.""" + if self.existing_csr is None: + raise AssertionError("Violated contract: existing_csr is not populated") + if self.privatekey is None: + raise AssertionError("Violated contract: privatekey is not populated") - def _check_subject(csr): + def _check_subject(csr: cryptography.x509.CertificateSigningRequest) -> bool: subject = [ (cryptography_name_to_oid(entry[0]), to_text(entry[1])) for entry in self.subject @@ -502,12 +548,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack else: return set(subject) == set(current_subject) - def _find_extension(extensions, exttype): + def _find_extension( + extensions: cryptography.x509.Extensions, exttype: type[_ET] + ) -> cryptography.x509.Extension[_ET] | None: return next( (ext for ext in extensions if isinstance(ext.value, exttype)), None ) - def _check_subjectAltName(extensions): + def _check_subjectAltName(extensions: cryptography.x509.Extensions) -> bool: current_altnames_ext = _find_extension( extensions, cryptography.x509.SubjectAlternativeName ) @@ -526,12 +574,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack ) if set(altnames) != set(current_altnames): return False - if altnames: + if altnames and current_altnames_ext: if current_altnames_ext.critical != self.subjectAltName_critical: return False return True - def _check_keyUsage(extensions): + def _check_keyUsage(extensions: cryptography.x509.Extensions) -> bool: current_keyusage_ext = _find_extension( extensions, cryptography.x509.KeyUsage ) @@ -547,7 +595,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack return False return True - def _check_extenededKeyUsage(extensions): + def _check_extenededKeyUsage(extensions: cryptography.x509.Extensions) -> bool: current_usages_ext = _find_extension( extensions, cryptography.x509.ExtendedKeyUsage ) @@ -566,12 +614,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack ) if set(current_usages) != set(usages): return False - if usages: + if usages and current_usages_ext: if current_usages_ext.critical != self.extendedKeyUsage_critical: return False return True - def _check_basicConstraints(extensions): + def _check_basicConstraints(extensions: cryptography.x509.Extensions) -> bool: bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints) current_ca = bc_ext.value.ca if bc_ext else False current_path_length = bc_ext.value.path_length if bc_ext else None @@ -591,7 +639,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack else: return bc_ext is None - def _check_ocspMustStaple(extensions): + def _check_ocspMustStaple(extensions: cryptography.x509.Extensions) -> bool: tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature) if self.ocspMustStaple: if ( @@ -606,7 +654,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack else: return tlsfeature_ext is None - def _check_nameConstraints(extensions): + def _check_nameConstraints(extensions: cryptography.x509.Extensions) -> bool: current_nc_ext = _find_extension( extensions, cryptography.x509.NameConstraints ) @@ -638,12 +686,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack current_nc_excl ): return False - if nc_perm or nc_excl: + if (nc_perm or nc_excl) and current_nc_ext: if current_nc_ext.critical != self.name_constraints_critical: return False return True - def _check_subject_key_identifier(extensions): + def _check_subject_key_identifier( + extensions: cryptography.x509.Extensions, + ) -> bool: ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier) if ( self.create_subject_key_identifier @@ -652,6 +702,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack if not ext or ext.critical: return False if self.create_subject_key_identifier: + assert self.privatekey is not None digest = cryptography.x509.SubjectKeyIdentifier.from_public_key( self.privatekey.public_key() ).digest @@ -661,7 +712,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack else: return ext is None - def _check_authority_key_identifier(extensions): + def _check_authority_key_identifier( + extensions: cryptography.x509.Extensions, + ) -> bool: ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier) if ( self.authority_key_identifier is not None @@ -688,7 +741,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack else: return ext is None - def _check_crl_distribution_points(extensions): + def _check_crl_distribution_points( + extensions: cryptography.x509.Extensions, + ) -> bool: ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints) if self.crl_distribution_points is None: return ext is None @@ -696,7 +751,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack return False return list(ext.value) == self.crl_distribution_points - def _check_extensions(csr): + def _check_extensions(csr: cryptography.x509.CertificateSigningRequest) -> bool: extensions = csr.extensions return ( _check_subjectAltName(extensions) @@ -710,7 +765,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack and _check_crl_distribution_points(extensions) ) - def _check_signature(csr): + def _check_signature(csr: cryptography.x509.CertificateSigningRequest) -> bool: if not csr.is_signature_valid: return False # To check whether public key of CSR belongs to private key, @@ -719,6 +774,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack cryptography.hazmat.primitives.serialization.Encoding.PEM, cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo, ) + assert self.privatekey is not None key_b = self.privatekey.public_key().public_bytes( cryptography.hazmat.primitives.serialization.Encoding.PEM, cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo, @@ -732,14 +788,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack ) -def select_backend(module): +def select_backend( + module: AnsibleModule, +) -> CertificateSigningRequestCryptographyBackend: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) return CertificateSigningRequestCryptographyBackend(module) -def get_csr_argument_spec(): +def get_csr_argument_spec() -> ArgumentSpec: return ArgumentSpec( argument_spec=dict( digest=dict(type="str", default="sha256"), diff --git a/plugins/module_utils/crypto/module_backends/csr_info.py b/plugins/module_utils/crypto/module_backends/csr_info.py index 3de6ee68..4de8ac26 100644 --- a/plugins/module_utils/crypto/module_backends/csr_info.py +++ b/plugins/module_utils/crypto/module_backends/csr_info.py @@ -8,6 +8,7 @@ from __future__ import annotations import abc import binascii +import typing as t from ansible.module_utils.common.text.converters import to_native from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( @@ -27,6 +28,19 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificatePublicKeyTypes, + PrivateKeyTypes, + ) + + from ....plugin_utils.action_module import AnsibleActionModule + from ....plugin_utils.filter_module import FilterModuleMock + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock] + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -41,66 +55,69 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ" class CSRInfoRetrieval(metaclass=abc.ABCMeta): - def __init__(self, module, content, validate_signature): - # content must be a bytes string + def __init__( + self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool + ) -> None: self.module = module self.content = content self.validate_signature = validate_signature @abc.abstractmethod - def _get_subject_ordered(self): + def _get_subject_ordered(self) -> list[list[str]]: pass @abc.abstractmethod - def _get_key_usage(self): + def _get_key_usage(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_extended_key_usage(self): + def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_basic_constraints(self): + def _get_basic_constraints(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_ocsp_must_staple(self): + def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]: pass @abc.abstractmethod - def _get_subject_alt_name(self): + def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]: pass @abc.abstractmethod - def _get_name_constraints(self): + def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]: pass @abc.abstractmethod - def _get_public_key_pem(self): + def _get_public_key_pem(self) -> bytes: pass @abc.abstractmethod - def _get_public_key_object(self): + def _get_public_key_object(self) -> CertificatePublicKeyTypes: pass @abc.abstractmethod - def _get_subject_key_identifier(self): + def _get_subject_key_identifier(self) -> bytes | None: pass @abc.abstractmethod - def _get_authority_key_identifier(self): + def _get_authority_key_identifier( + self, + ) -> tuple[bytes | None, list[str] | None, int | None]: pass @abc.abstractmethod - def _get_all_extensions(self): + def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]: pass @abc.abstractmethod - def _is_signature_valid(self): + def _is_signature_valid(self) -> bool: pass - def get_info(self, prefer_one_fingerprint=False): - result = dict() + def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} self.csr = load_certificate_request( None, content=self.content, @@ -145,15 +162,17 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta): } ) - ski = self._get_subject_key_identifier() - if ski is not None: - ski = binascii.hexlify(ski).decode("ascii") + ski_bytes = self._get_subject_key_identifier() + ski = None + if ski_bytes is not None: + ski = binascii.hexlify(ski_bytes).decode("ascii") ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)]) result["subject_key_identifier"] = ski - aki, aci, acsn = self._get_authority_key_identifier() - if aki is not None: - aki = binascii.hexlify(aki).decode("ascii") + aki_bytes, aci, acsn = self._get_authority_key_identifier() + aki = None + if aki_bytes is not None: + aki = binascii.hexlify(aki_bytes).decode("ascii") aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)]) result["authority_key_identifier"] = aki result["authority_cert_issuer"] = aci @@ -170,19 +189,25 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta): class CSRInfoRetrievalCryptography(CSRInfoRetrieval): """Validate the supplied CSR, using the cryptography backend""" - def __init__(self, module, content, validate_signature): + def __init__( + self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool + ) -> None: super(CSRInfoRetrievalCryptography, self).__init__( module, content, validate_signature ) - self.name_encoding = module.params.get("name_encoding", "ignore") + self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get( + "name_encoding", "ignore" + ) - def _get_subject_ordered(self): - result = [] + def _get_subject_ordered(self) -> list[list[str]]: + result: list[list[str]] = [] for attribute in self.csr.subject: - result.append([cryptography_oid_to_name(attribute.oid), attribute.value]) + result.append( + [cryptography_oid_to_name(attribute.oid), to_native(attribute.value)] + ) return result - def _get_key_usage(self): + def _get_key_usage(self) -> tuple[list[str] | None, bool]: try: current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage) current_key_usage = current_key_ext.value @@ -229,7 +254,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_extended_key_usage(self): + def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]: try: ext_keyusage_ext = self.csr.extensions.get_extension_for_class( x509.ExtendedKeyUsage @@ -243,7 +268,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_basic_constraints(self): + def _get_basic_constraints(self) -> tuple[list[str] | None, bool]: try: ext_keyusage_ext = self.csr.extensions.get_extension_for_class( x509.BasicConstraints @@ -255,7 +280,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_ocsp_must_staple(self): + def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]: try: # This only works with cryptography >= 2.1 tlsfeature_ext = self.csr.extensions.get_extension_for_class( @@ -268,7 +293,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_subject_alt_name(self): + def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]: try: san_ext = self.csr.extensions.get_extension_for_class( x509.SubjectAlternativeName @@ -281,7 +306,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, False - def _get_name_constraints(self): + def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]: try: nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints) permitted = [ @@ -296,23 +321,25 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, None, False - def _get_public_key_pem(self): + def _get_public_key_pem(self) -> bytes: return self.csr.public_key().public_bytes( serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo, ) - def _get_public_key_object(self): + def _get_public_key_object(self) -> CertificatePublicKeyTypes: return self.csr.public_key() - def _get_subject_key_identifier(self): + def _get_subject_key_identifier(self) -> bytes | None: try: ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier) return ext.value.digest except cryptography.x509.ExtensionNotFound: return None - def _get_authority_key_identifier(self): + def _get_authority_key_identifier( + self, + ) -> tuple[bytes | None, list[str] | None, int | None]: try: ext = self.csr.extensions.get_extension_for_class( x509.AuthorityKeyIdentifier @@ -331,23 +358,28 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval): except cryptography.x509.ExtensionNotFound: return None, None, None - def _get_all_extensions(self): + def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]: return cryptography_get_extensions_from_csr(self.csr) - def _is_signature_valid(self): + def _is_signature_valid(self) -> bool: return self.csr.is_signature_valid def get_csr_info( - module, content, validate_signature=True, prefer_one_fingerprint=False -): + module: GeneralAnsibleModule, + content: bytes, + validate_signature: bool = True, + prefer_one_fingerprint: bool = False, +) -> dict[str, t.Any]: info = CSRInfoRetrievalCryptography( module, content, validate_signature=validate_signature ) return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint) -def select_backend(module, content, validate_signature=True): +def select_backend( + module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True +) -> CSRInfoRetrieval: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) diff --git a/plugins/module_utils/crypto/module_backends/privatekey.py b/plugins/module_utils/crypto/module_backends/privatekey.py index def4f3b7..60fab40b 100644 --- a/plugins/module_utils/crypto/module_backends/privatekey.py +++ b/plugins/module_utils/crypto/module_backends/privatekey.py @@ -8,6 +8,7 @@ from __future__ import annotations import abc import base64 import traceback +import typing as t from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.argspec import ( @@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, + ) + + from ....plugin_utils.action_module import AnsibleActionModule + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule] + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -64,29 +76,37 @@ class PrivateKeyError(OpenSSLObjectError): class PrivateKeyBackend(metaclass=abc.ABCMeta): - def __init__(self, module): + def __init__(self, module: GeneralAnsibleModule) -> None: self.module = module - self.type = module.params["type"] - self.size = module.params["size"] - self.curve = module.params["curve"] - self.passphrase = module.params["passphrase"] - self.cipher = module.params["cipher"] - self.format = module.params["format"] - self.format_mismatch = module.params.get("format_mismatch", "regenerate") - self.regenerate = module.params.get("regenerate", "full_idempotence") + self.type: t.Literal[ + "DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448" + ] = module.params["type"] + self.size: int = module.params["size"] + self.curve: str | None = module.params["curve"] + self.passphrase: str | None = module.params["passphrase"] + self.cipher: str = module.params["cipher"] + self.format: t.Literal["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"] = ( + module.params["format"] + ) + self.format_mismatch: t.Literal["regenerate", "convert"] = module.params.get( + "format_mismatch", "regenerate" + ) + self.regenerate: t.Literal[ + "never", "fail", "partial_idempotence", "full_idempotence", "always" + ] = module.params.get("regenerate", "full_idempotence") - self.private_key = None + self.private_key: PrivateKeyTypes | None = None - self.existing_private_key = None - self.existing_private_key_bytes = None + self.existing_private_key: PrivateKeyTypes | None = None + self.existing_private_key_bytes: bytes | None = None self.diff_before = self._get_info(None) self.diff_after = self._get_info(None) - def _get_info(self, data): + def _get_info(self, data: bytes | None) -> dict[str, t.Any]: if data is None: - return dict() - result = dict(can_parse_key=False) + return {} + result: dict[str, t.Any] = {"can_parse_key": False} try: result.update( get_privatekey_info( @@ -106,11 +126,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): return result @abc.abstractmethod - def generate_private_key(self): + def generate_private_key(self) -> None: """(Re-)Generate private key.""" pass - def convert_private_key(self): + def convert_private_key(self) -> None: """Convert existing private key (self.existing_private_key) to new private key (self.private_key). This is effectively a copy without active conversion. The conversion is done @@ -121,42 +141,37 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): self.private_key = self.existing_private_key @abc.abstractmethod - def get_private_key_data(self): + def get_private_key_data(self) -> bytes: """Return bytes for self.private_key.""" - pass - def set_existing(self, privatekey_bytes): + def set_existing(self, privatekey_bytes: bytes | None) -> None: """Set existing private key bytes. None indicates that the key does not exist.""" self.existing_private_key_bytes = privatekey_bytes self.diff_after = self.diff_before = self._get_info( self.existing_private_key_bytes ) - def has_existing(self): + def has_existing(self) -> bool: """Query whether an existing private key is/has been there.""" return self.existing_private_key_bytes is not None @abc.abstractmethod - def _check_passphrase(self): + def _check_passphrase(self) -> bool: """Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated.""" - pass @abc.abstractmethod - def _ensure_existing_private_key_loaded(self): + def _ensure_existing_private_key_loaded(self) -> None: """Make sure that self.existing_private_key is populated from self.existing_private_key_bytes.""" - pass @abc.abstractmethod - def _check_size_and_type(self): + def _check_size_and_type(self) -> bool: """Check whether provided size and type matches, assuming self.existing_private_key has been populated.""" - pass @abc.abstractmethod - def _check_format(self): + def _check_format(self) -> bool: """Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated.""" - pass - def needs_regeneration(self): + def needs_regeneration(self) -> bool: """Check whether a regeneration is necessary.""" if self.regenerate == "always": return True @@ -194,7 +209,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): ) return False - def needs_conversion(self): + def needs_conversion(self) -> bool: """Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False.""" # During conversion step, convert if format does not match and format_mismatch == 'convert' self._ensure_existing_private_key_loaded() @@ -204,7 +219,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): and not self._check_format() ) - def _get_fingerprint(self): + def _get_fingerprint(self) -> dict[str, str] | None: if self.private_key: return get_fingerprint_of_privatekey(self.private_key) try: @@ -214,8 +229,9 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): pass if self.existing_private_key: return get_fingerprint_of_privatekey(self.existing_private_key) + return None - def dump(self, include_key): + def dump(self, include_key: bool) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" if not self.private_key: @@ -224,7 +240,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): except Exception: # Ignore errors pass - result = { + result: dict[str, t.Any] = { "type": self.type, "size": self.size, "fingerprint": self._get_fingerprint(), @@ -253,38 +269,57 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta): return result -# Implementation with using cryptography -class PrivateKeyCryptographyBackend(PrivateKeyBackend): +class _Curve: + def __init__( + self, + name: str, + ectype: str, + deprecated: bool, + ) -> None: + self.name = name + self.ectype = ectype + self.deprecated = deprecated - def _get_ec_class(self, ectype): - ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype) + def _get_ec_class( + self, module: GeneralAnsibleModule + ) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]: + ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore if ecclass is None: - self.module.fail_json( - msg=f"Your cryptography version does not support {ectype}" + module.fail_json( + msg=f"Your cryptography version does not support {self.ectype}" ) return ecclass - def _add_curve(self, name, ectype, deprecated=False): - def create(size): - ecclass = self._get_ec_class(ectype) - return ecclass() + def create( + self, size: int, module: GeneralAnsibleModule + ) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve: + ecclass = self._get_ec_class(module) + return ecclass() - def verify(privatekey): - ecclass = self._get_ec_class(ectype) - return isinstance( - privatekey.private_numbers().public_numbers.curve, ecclass - ) + def verify( + self, + privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey, + module: GeneralAnsibleModule, + ) -> bool: + ecclass = self._get_ec_class(module) + return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass) - self.curves[name] = { - "create": create, - "verify": verify, - "deprecated": deprecated, - } - def __init__(self, module): +# Implementation with using cryptography +class PrivateKeyCryptographyBackend(PrivateKeyBackend): + + def _add_curve( + self, + name: str, + ectype: str, + deprecated: bool = False, + ) -> None: + self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated) + + def __init__(self, module: GeneralAnsibleModule) -> None: super(PrivateKeyCryptographyBackend, self).__init__(module=module) - self.curves = dict() + self.curves: dict[str, _Curve] = {} self._add_curve("secp224r1", "SECP224R1") self._add_curve("secp256k1", "SECP256K1") self._add_curve("secp256r1", "SECP256R1") @@ -305,15 +340,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True) self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True) - def _get_wanted_format(self): + def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]: if self.format not in ("auto", "auto_ignore"): - return self.format + return self.format # type: ignore if self.type in ("X25519", "X448", "Ed25519", "Ed448"): return "pkcs8" else: return "pkcs1" - def generate_private_key(self): + def generate_private_key(self) -> None: """(Re-)Generate private key.""" try: if self.type == "RSA": @@ -346,13 +381,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate() ) if self.type == "ECC" and self.curve in self.curves: - if self.curves[self.curve]["deprecated"]: + if self.curves[self.curve].deprecated: self.module.warn( f"Elliptic curves of type {self.curve} should not be used for new keys!" ) self.private_key = ( cryptography.hazmat.primitives.asymmetric.ec.generate_private_key( - curve=self.curves[self.curve]["create"](self.size), + curve=self.curves[self.curve].create( + size=self.size, module=self.module + ), ) ) except cryptography.exceptions.UnsupportedAlgorithm: @@ -360,22 +397,24 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): msg=f"Cryptography backend does not support the algorithm required for {self.type}" ) - def get_private_key_data(self): + def get_private_key_data(self) -> bytes: """Return bytes for self.private_key""" + if self.private_key is None: + raise AssertionError("private_key not set") # Select export format and encoding try: - export_format = self._get_wanted_format() + export_format_txt = self._get_wanted_format() export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM - if export_format == "pkcs1": + if export_format_txt == "pkcs1": # "TraditionalOpenSSL" format is PKCS1 export_format = ( cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL ) - elif export_format == "pkcs8": + elif export_format_txt == "pkcs8": export_format = ( cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8 ) - elif export_format == "raw": + elif export_format_txt == "raw": export_format = ( cryptography.hazmat.primitives.serialization.PrivateFormat.Raw ) @@ -388,9 +427,9 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): ) # Select key encryption - encryption_algorithm = ( - cryptography.hazmat.primitives.serialization.NoEncryption() - ) + encryption_algorithm: ( + cryptography.hazmat.primitives.serialization.KeySerializationEncryption + ) = cryptography.hazmat.primitives.serialization.NoEncryption() if self.cipher and self.passphrase: if self.cipher == "auto": encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption( @@ -418,8 +457,10 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): exception=traceback.format_exc(), ) - def _load_privatekey(self): + def _load_privatekey(self) -> PrivateKeyTypes: data = self.existing_private_key_bytes + if data is None: + raise AssertionError("existing_private_key_bytes not set") try: # Interpret bytes depending on format. format = identify_private_key_format(data) @@ -460,11 +501,13 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): except Exception as e: raise PrivateKeyError(e) - def _ensure_existing_private_key_loaded(self): + def _ensure_existing_private_key_loaded(self) -> None: if self.existing_private_key is None and self.has_existing(): self.existing_private_key = self._load_privatekey() - def _check_passphrase(self): + def _check_passphrase(self) -> bool: + if self.existing_private_key_bytes is None: + raise AssertionError("existing_private_key_bytes not set") try: format = identify_private_key_format(self.existing_private_key_bytes) if format == "raw": @@ -475,7 +518,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): # provided. return self.passphrase is None else: - return ( + return bool( cryptography.hazmat.primitives.serialization.load_pem_private_key( self.existing_private_key_bytes, None if self.passphrase is None else to_bytes(self.passphrase), @@ -484,7 +527,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): except Exception: return False - def _check_size_and_type(self): + def _check_size_and_type(self) -> bool: if isinstance( self.existing_private_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey, @@ -527,11 +570,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): return False if self.curve not in self.curves: return False - return self.curves[self.curve]["verify"](self.existing_private_key) + return self.curves[self.curve].verify( + self.existing_private_key, module=self.module + ) return False - def _check_format(self): + def _check_format(self) -> bool: + if self.existing_private_key_bytes is None: + raise AssertionError("existing_private_key_bytes not set") if self.format == "auto_ignore": return True try: @@ -541,14 +588,14 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend): return False -def select_backend(module): +def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) return PrivateKeyCryptographyBackend(module) -def get_privatekey_argument_spec(): +def get_privatekey_argument_spec() -> ArgumentSpec: return ArgumentSpec( argument_spec=dict( size=dict(type="int", default=4096), @@ -607,6 +654,6 @@ def get_privatekey_argument_spec(): ), ), required_if=[ - ["type", "ECC", ["curve"]], + ("type", "ECC", ["curve"]), ], ) diff --git a/plugins/module_utils/crypto/module_backends/privatekey_convert.py b/plugins/module_utils/crypto/module_backends/privatekey_convert.py index 3cbda2fd..585fb8ae 100644 --- a/plugins/module_utils/crypto/module_backends/privatekey_convert.py +++ b/plugins/module_utils/crypto/module_backends/privatekey_convert.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc import traceback +import typing as t from ansible.module_utils.common.text.converters import to_bytes from ansible_collections.community.crypto.plugins.module_utils.argspec import ( @@ -27,6 +28,13 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep from ansible_collections.community.crypto.plugins.module_utils.io import load_file +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, + ) + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -58,42 +66,48 @@ class PrivateKeyError(OpenSSLObjectError): class PrivateKeyConvertBackend(metaclass=abc.ABCMeta): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module - self.src_path = module.params["src_path"] - self.src_content = module.params["src_content"] - self.src_passphrase = module.params["src_passphrase"] - self.format = module.params["format"] - self.dest_passphrase = module.params["dest_passphrase"] + self.src_path: str | None = module.params["src_path"] + self.src_content: str | None = module.params["src_content"] + self.src_passphrase: str | None = module.params["src_passphrase"] + self.format: t.Literal["pkcs1", "pkcs8", "raw"] = module.params["format"] + self.dest_passphrase: str | None = module.params["dest_passphrase"] - self.src_private_key = None + self.src_private_key: PrivateKeyTypes | None = None if self.src_path is not None: self.src_private_key_bytes = load_file(self.src_path, module) else: + if self.src_content is None: + raise AssertionError("src_content is None") self.src_private_key_bytes = self.src_content.encode("utf-8") - self.dest_private_key = None - self.dest_private_key_bytes = None + self.dest_private_key: PrivateKeyTypes | None = None + self.dest_private_key_bytes: bytes | None = None @abc.abstractmethod - def get_private_key_data(self): + def get_private_key_data(self) -> bytes: """Return bytes for self.src_private_key in output format.""" pass - def set_existing_destination(self, privatekey_bytes): + def set_existing_destination(self, privatekey_bytes: bytes | None) -> None: """Set existing private key bytes. None indicates that the key does not exist.""" self.dest_private_key_bytes = privatekey_bytes - def has_existing_destination(self): + def has_existing_destination(self) -> bool: """Query whether an existing private key is/has been there.""" return self.dest_private_key_bytes is not None @abc.abstractmethod - def _load_private_key(self, data, passphrase, current_hint=None): + def _load_private_key( + self, + data: bytes, + passphrase: str | None, + current_hint: PrivateKeyTypes | None = None, + ) -> tuple[str, PrivateKeyTypes]: """Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key).""" - pass - def needs_conversion(self): + def needs_conversion(self) -> bool: """Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False.""" dummy, self.src_private_key = self._load_private_key( self.src_private_key_bytes, self.src_passphrase @@ -101,6 +115,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta): if not self.has_existing_destination(): return True + assert self.dest_private_key_bytes is not None try: format, self.dest_private_key = self._load_private_key( @@ -115,18 +130,20 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta): self.dest_private_key, self.src_private_key ) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" return {} # Implementation with using cryptography class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module) - def get_private_key_data(self): + def get_private_key_data(self) -> bytes: """Return bytes for self.src_private_key in output format""" + if self.src_private_key is None: + raise AssertionError("src_private_key not set") # Select export format and encoding try: export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM @@ -152,9 +169,9 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend): ) # Select key encryption - encryption_algorithm = ( - cryptography.hazmat.primitives.serialization.NoEncryption() - ) + encryption_algorithm: ( + cryptography.hazmat.primitives.serialization.KeySerializationEncryption + ) = cryptography.hazmat.primitives.serialization.NoEncryption() if self.dest_passphrase: encryption_algorithm = ( cryptography.hazmat.primitives.serialization.BestAvailableEncryption( @@ -179,7 +196,12 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend): exception=traceback.format_exc(), ) - def _load_private_key(self, data, passphrase, current_hint=None): + def _load_private_key( + self, + data: bytes, + passphrase: str | None, + current_hint: PrivateKeyTypes | None = None, + ) -> tuple[str, PrivateKeyTypes]: try: # Interpret bytes depending on format. format = identify_private_key_format(data) @@ -247,14 +269,14 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend): raise PrivateKeyError(e) -def select_backend(module): +def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) return PrivateKeyConvertCryptographyBackend(module) -def get_privatekey_argument_spec(): +def get_privatekey_argument_spec() -> ArgumentSpec: return ArgumentSpec( argument_spec=dict( src_path=dict(type="path"), diff --git a/plugins/module_utils/crypto/module_backends/privatekey_info.py b/plugins/module_utils/crypto/module_backends/privatekey_info.py index 613a894f..271a2bf1 100644 --- a/plugins/module_utils/crypto/module_backends/privatekey_info.py +++ b/plugins/module_utils/crypto/module_backends/privatekey_info.py @@ -7,6 +7,7 @@ from __future__ import annotations import abc +import typing as t from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -29,6 +30,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, + ) + + from ....plugin_utils.action_module import AnsibleActionModule + from ....plugin_utils.filter_module import FilterModuleMock + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock] + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -40,38 +53,49 @@ except ImportError: SIGNATURE_TEST_DATA = b"1234" -def _get_cryptography_private_key_info(key, need_private_key_data=False): +def _get_cryptography_private_key_info( + key: PrivateKeyTypes, need_private_key_data: bool = False +) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: key_type, key_public_data = _get_cryptography_public_key_info(key.public_key()) - key_private_data = dict() + key_private_data: dict[str, t.Any] = {} if need_private_key_data: if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey): - private_numbers = key.private_numbers() - key_private_data["p"] = private_numbers.p - key_private_data["q"] = private_numbers.q - key_private_data["exponent"] = private_numbers.d + rsa_private_numbers = key.private_numbers() + key_private_data["p"] = rsa_private_numbers.p + key_private_data["q"] = rsa_private_numbers.q + key_private_data["exponent"] = rsa_private_numbers.d elif isinstance( key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey ): - private_numbers = key.private_numbers() - key_private_data["x"] = private_numbers.x + dsa_private_numbers = key.private_numbers() + key_private_data["x"] = dsa_private_numbers.x elif isinstance( key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey ): - private_numbers = key.private_numbers() - key_private_data["multiplier"] = private_numbers.private_value + ecc_private_numbers = key.private_numbers() + key_private_data["multiplier"] = ecc_private_numbers.private_value return key_type, key_public_data, key_private_data -def _check_dsa_consistency(key_public_data, key_private_data): +def _check_dsa_consistency( + key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any] +) -> bool | None: # Get parameters - p = key_public_data.get("p") - q = key_public_data.get("q") - g = key_public_data.get("g") - y = key_public_data.get("y") - x = key_private_data.get("x") - for v in (p, q, g, y, x): - if v is None: - return None + p: int | None = key_public_data.get("p") + if p is None: + return None + q: int | None = key_public_data.get("q") + if q is None: + return None + g: int | None = key_public_data.get("g") + if g is None: + return None + y: int | None = key_public_data.get("y") + if y is None: + return None + x: int | None = key_private_data.get("x") + if x is None: + return None # Make sure that g is not 0, 1 or -1 in Z/pZ if g < 2 or g >= p - 1: return False @@ -94,13 +118,16 @@ def _check_dsa_consistency(key_public_data, key_private_data): def _is_cryptography_key_consistent( - key, key_public_data, key_private_data, warn_func=None -): + key: PrivateKeyTypes, + key_public_data: dict[str, t.Any], + key_private_data: dict[str, t.Any], + warn_func: t.Callable[[str], None] | None = None, +) -> bool | None: if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey): # key._backend was removed in cryptography 42.0.0 backend = getattr(key, "_backend", None) if backend is not None: - return bool(backend._lib.RSA_check_key(key._rsa_cdata)) + return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey): result = _check_dsa_consistency(key_public_data, key_private_data) if result is not None: @@ -145,9 +172,9 @@ def _is_cryptography_key_consistent( if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey): has_simple_sign_function = True if has_simple_sign_function: - signature = key.sign(SIGNATURE_TEST_DATA) + signature = key.sign(SIGNATURE_TEST_DATA) # type: ignore try: - key.public_key().verify(signature, SIGNATURE_TEST_DATA) + key.public_key().verify(signature, SIGNATURE_TEST_DATA) # type: ignore return True except cryptography.exceptions.InvalidSignature: return False @@ -158,14 +185,14 @@ def _is_cryptography_key_consistent( class PrivateKeyConsistencyError(OpenSSLObjectError): - def __init__(self, msg, result): + def __init__(self, msg: str, result: dict[str, t.Any]) -> None: super(PrivateKeyConsistencyError, self).__init__(msg) self.error_message = msg self.result = result class PrivateKeyParseError(OpenSSLObjectError): - def __init__(self, msg, result): + def __init__(self, msg: str, result: dict[str, t.Any]) -> None: super(PrivateKeyParseError, self).__init__(msg) self.error_message = msg self.result = result @@ -174,13 +201,12 @@ class PrivateKeyParseError(OpenSSLObjectError): class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta): def __init__( self, - module, - content, - passphrase=None, - return_private_key_data=False, - check_consistency=False, + module: GeneralAnsibleModule, + content: bytes, + passphrase: str | None = None, + return_private_key_data: bool = False, + check_consistency: bool = False, ): - # content must be a bytes string self.module = module self.content = content self.passphrase = passphrase @@ -188,22 +214,26 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta): self.check_consistency = check_consistency @abc.abstractmethod - def _get_public_key(self, binary): + def _get_public_key(self, binary: bool) -> bytes: pass @abc.abstractmethod - def _get_key_info(self, need_private_key_data=False): + def _get_key_info( + self, need_private_key_data: bool = False + ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: pass @abc.abstractmethod - def _is_key_consistent(self, key_public_data, key_private_data): + def _is_key_consistent( + self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any] + ) -> bool | None: pass - def get_info(self, prefer_one_fingerprint=False): - result = dict( - can_parse_key=False, - key_is_consistent=None, - ) + def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]: + result: dict[str, t.Any] = { + "can_parse_key": False, + "key_is_consistent": None, + } priv_key_detail = self.content try: self.key = load_privatekey( @@ -252,35 +282,39 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta): class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval): """Validate the supplied private key, using the cryptography backend""" - def __init__(self, module, content, **kwargs): + def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None: super(PrivateKeyInfoRetrievalCryptography, self).__init__( module, content, **kwargs ) - def _get_public_key(self, binary): + def _get_public_key(self, binary: bool) -> bytes: return self.key.public_key().public_bytes( serialization.Encoding.DER if binary else serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo, ) - def _get_key_info(self, need_private_key_data=False): + def _get_key_info( + self, need_private_key_data: bool = False + ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return _get_cryptography_private_key_info( self.key, need_private_key_data=need_private_key_data ) - def _is_key_consistent(self, key_public_data, key_private_data): + def _is_key_consistent( + self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any] + ) -> bool | None: return _is_cryptography_key_consistent( self.key, key_public_data, key_private_data, warn_func=self.module.warn ) def get_privatekey_info( - module, - content, - passphrase=None, - return_private_key_data=False, - prefer_one_fingerprint=False, -): + module: GeneralAnsibleModule, + content: bytes, + passphrase: str | None = None, + return_private_key_data: bool = False, + prefer_one_fingerprint: bool = False, +) -> dict[str, t.Any]: info = PrivateKeyInfoRetrievalCryptography( module, content, @@ -291,12 +325,12 @@ def get_privatekey_info( def select_backend( - module, - content, - passphrase=None, - return_private_key_data=False, - check_consistency=False, -): + module: GeneralAnsibleModule, + content: bytes, + passphrase: str | None = None, + return_private_key_data: bool = False, + check_consistency: bool = False, +) -> PrivateKeyInfoRetrieval: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) diff --git a/plugins/module_utils/crypto/module_backends/publickey_info.py b/plugins/module_utils/crypto/module_backends/publickey_info.py index ae340bcf..e99a706b 100644 --- a/plugins/module_utils/crypto/module_backends/publickey_info.py +++ b/plugins/module_utils/crypto/module_backends/publickey_info.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import typing as t from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, @@ -19,6 +20,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + PublicKeyTypes, + ) + + from ....plugin_utils.action_module import AnsibleActionModule + from ....plugin_utils.filter_module import FilterModuleMock + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock] + + MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION try: @@ -32,23 +45,25 @@ except ImportError: pass -def _get_cryptography_public_key_info(key): - key_public_data = dict() +def _get_cryptography_public_key_info( + key: PublicKeyTypes, +) -> tuple[str, dict[str, t.Any]]: + key_public_data: dict[str, t.Any] = {} if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey): key_type = "RSA" - public_numbers = key.public_numbers() + rsa_public_numbers = key.public_numbers() key_public_data["size"] = key.key_size - key_public_data["modulus"] = public_numbers.n - key_public_data["exponent"] = public_numbers.e + key_public_data["modulus"] = rsa_public_numbers.n + key_public_data["exponent"] = rsa_public_numbers.e elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey): key_type = "DSA" - parameter_numbers = key.parameters().parameter_numbers() - public_numbers = key.public_numbers() + dsa_parameter_numbers = key.parameters().parameter_numbers() + dsa_public_numbers = key.public_numbers() key_public_data["size"] = key.key_size - key_public_data["p"] = parameter_numbers.p - key_public_data["q"] = parameter_numbers.q - key_public_data["g"] = parameter_numbers.g - key_public_data["y"] = public_numbers.y + key_public_data["p"] = dsa_parameter_numbers.p + key_public_data["q"] = dsa_parameter_numbers.q + key_public_data["g"] = dsa_parameter_numbers.g + key_public_data["y"] = dsa_public_numbers.y elif isinstance( key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey ): @@ -67,10 +82,10 @@ def _get_cryptography_public_key_info(key): key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey ): key_type = "ECC" - public_numbers = key.public_numbers() + ecc_public_numbers = key.public_numbers() key_public_data["curve"] = key.curve.name - key_public_data["x"] = public_numbers.x - key_public_data["y"] = public_numbers.y + key_public_data["x"] = ecc_public_numbers.x + key_public_data["y"] = ecc_public_numbers.y key_public_data["exponent_size"] = key.curve.key_size else: key_type = f"unknown ({type(key)})" @@ -78,29 +93,34 @@ def _get_cryptography_public_key_info(key): class PublicKeyParseError(OpenSSLObjectError): - def __init__(self, msg, result): + def __init__(self, msg: str, result: dict[str, t.Any]) -> None: super(PublicKeyParseError, self).__init__(msg) self.error_message = msg self.result = result class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta): - def __init__(self, module, content=None, key=None): + def __init__( + self, + module: GeneralAnsibleModule, + content: bytes | None = None, + key: PublicKeyTypes | None = None, + ) -> None: # content must be a bytes string self.module = module self.content = content self.key = key @abc.abstractmethod - def _get_public_key(self, binary): + def _get_public_key(self, binary: bool) -> bytes: pass @abc.abstractmethod - def _get_key_info(self): + def _get_key_info(self) -> tuple[str, dict[str, t.Any]]: pass - def get_info(self, prefer_one_fingerprint=False): - result = dict() + def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} if self.key is None: try: self.key = load_publickey(content=self.content) @@ -123,27 +143,45 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta): class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval): """Validate the supplied public key, using the cryptography backend""" - def __init__(self, module, content=None, key=None): + def __init__( + self, + module: GeneralAnsibleModule, + content: bytes | None = None, + key: PublicKeyTypes | None = None, + ) -> None: super(PublicKeyInfoRetrievalCryptography, self).__init__( module, content=content, key=key ) - def _get_public_key(self, binary): + def _get_public_key(self, binary: bool) -> bytes: + if self.key is None: + raise AssertionError("key must be set") return self.key.public_bytes( serialization.Encoding.DER if binary else serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo, ) - def _get_key_info(self): + def _get_key_info(self) -> tuple[str, dict[str, t.Any]]: + if self.key is None: + raise AssertionError("key must be set") return _get_cryptography_public_key_info(self.key) -def get_publickey_info(module, content=None, key=None, prefer_one_fingerprint=False): +def get_publickey_info( + module: GeneralAnsibleModule, + content: bytes | None = None, + key: PublicKeyTypes | None = None, + prefer_one_fingerprint: bool = False, +) -> dict[str, t.Any]: info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key) return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint) -def select_backend(module, content=None, key=None): +def select_backend( + module: GeneralAnsibleModule, + content: bytes | None = None, + key: PublicKeyTypes | None = None, +) -> PublicKeyInfoRetrieval: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) diff --git a/plugins/module_utils/crypto/openssh.py b/plugins/module_utils/crypto/openssh.py index de896a4e..080dc897 100644 --- a/plugins/module_utils/crypto/openssh.py +++ b/plugins/module_utils/crypto/openssh.py @@ -8,3 +8,6 @@ from __future__ import annotations from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( # noqa: F401, pylint: disable=unused-import parse_openssh_version, ) + + +# TODO: delete! diff --git a/plugins/module_utils/crypto/pem.py b/plugins/module_utils/crypto/pem.py index 603cb562..efa26956 100644 --- a/plugins/module_utils/crypto/pem.py +++ b/plugins/module_utils/crypto/pem.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + PEM_START = "-----BEGIN " PEM_END_START = "-----END " @@ -12,7 +14,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY") PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY" -def identify_pem_format(content, encoding="utf-8"): +def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool: """Given the contents of a binary file, tests whether this could be a PEM file.""" try: first_pem = extract_first_pem(content.decode(encoding)) @@ -30,7 +32,9 @@ def identify_pem_format(content, encoding="utf-8"): return False -def identify_private_key_format(content, encoding="utf-8"): +def identify_private_key_format( + content: bytes, encoding: str = "utf-8" +) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]: """Given the contents of a private key file, identifies its format.""" # See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85 # (PEM_read_bio_PrivateKey) @@ -59,12 +63,12 @@ def identify_private_key_format(content, encoding="utf-8"): return "raw" -def split_pem_list(text, keep_inbetween=False): +def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]: """ Split concatenated PEM objects into a list of strings, where each is one PEM object. """ result = [] - current = [] if keep_inbetween else None + current: list[str] | None = [] if keep_inbetween else None for line in text.splitlines(True): if line.strip(): if not keep_inbetween and line.startswith("-----BEGIN "): @@ -77,7 +81,7 @@ def split_pem_list(text, keep_inbetween=False): return result -def extract_first_pem(text): +def extract_first_pem(text: str) -> str | None: """ Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none. """ @@ -87,7 +91,7 @@ def extract_first_pem(text): return all_pems[0] -def _extract_type(line, start=PEM_START): +def _extract_type(line: str, start: str = PEM_START) -> str | None: if not line.startswith(start): return None if not line.endswith(PEM_END): @@ -95,7 +99,7 @@ def _extract_type(line, start=PEM_START): return line[len(start) : -len(PEM_END)] -def extract_pem(content, strict=False): +def extract_pem(content: str, strict: bool = False) -> tuple[str, str]: lines = content.splitlines() if len(lines) < 3: raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}") @@ -117,5 +121,4 @@ def extract_pem(content, strict=False): raise ValueError( f"Last line has length {len(lines[-2])}, should be in (0, 64]" ) - content = lines[1:-1] - return header_type, "".join(content) + return header_type, "".join(lines[1:-1]) diff --git a/plugins/module_utils/crypto/support.py b/plugins/module_utils/crypto/support.py index 999eae8d..4bc7f795 100644 --- a/plugins/module_utils/crypto/support.py +++ b/plugins/module_utils/crypto/support.py @@ -8,8 +8,13 @@ import abc import errno import hashlib import os +import typing as t from ansible.module_utils.common.text.converters import to_bytes +from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( + is_potential_certificate_issuer_private_key, + is_potential_certificate_private_key, +) from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ( identify_pem_format, ) @@ -34,6 +39,17 @@ except ImportError: from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + PrivateKeyTypes, + PublicKeyTypes, + ) + + from .cryptography_support import CertificatePrivateKeyTypes + + # This list of preferred fingerprints is used when prefer_one=True is supplied to the # fingerprinting methods. PREFERRED_FINGERPRINTS = ( @@ -48,18 +64,12 @@ PREFERRED_FINGERPRINTS = ( ) -def get_fingerprint_of_bytes(source, prefer_one=False): +def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]: """Generate the fingerprint of the given bytes.""" fingerprint = {} - try: - algorithms = hashlib.algorithms - except AttributeError: - try: - algorithms = hashlib.algorithms_guaranteed - except AttributeError: - return None + algorithms: t.Iterable[str] = hashlib.algorithms_guaranteed if prefer_one: # Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning @@ -97,7 +107,9 @@ def get_fingerprint_of_bytes(source, prefer_one=False): return fingerprint -def get_fingerprint_of_privatekey(privatekey, prefer_one=False): +def get_fingerprint_of_privatekey( + privatekey: PrivateKeyTypes, prefer_one: bool = False +) -> dict[str, str]: """Generate the fingerprint of the public key.""" publickey = privatekey.public_key().public_bytes( @@ -107,11 +119,16 @@ def get_fingerprint_of_privatekey(privatekey, prefer_one=False): return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one) -def get_fingerprint(path, passphrase=None, content=None, prefer_one=False): +def get_fingerprint( + path: os.PathLike | str | None = None, + passphrase: str | bytes | None = None, + content: bytes | None = None, + prefer_one: bool = False, +) -> dict[str, str]: """Generate the fingerprint of the public key.""" privatekey = load_privatekey( - path, + path=path, passphrase=passphrase, content=content, check_passphrase=False, @@ -121,11 +138,11 @@ def get_fingerprint(path, passphrase=None, content=None, prefer_one=False): def load_privatekey( - path, - passphrase=None, - check_passphrase=True, - content=None, -): + path: os.PathLike | str | None = None, + passphrase: str | bytes | None = None, + check_passphrase: bool = True, + content: bytes | None = None, +) -> PrivateKeyTypes: """Load the specified OpenSSL private key. The content can also be specified via content; in that case, @@ -134,6 +151,8 @@ def load_privatekey( try: if content is None: + if path is None: + raise OpenSSLObjectError("Must provide either path or content") with open(path, "rb") as b_priv_key_fh: priv_key_detail = b_priv_key_fh.read() else: @@ -154,7 +173,55 @@ def load_privatekey( raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key") -def load_publickey(path=None, content=None): +def load_certificate_privatekey( + *, + path: os.PathLike | str | None = None, + content: bytes | None = None, + passphrase: str | bytes | None = None, + check_passphrase: bool = True, +) -> CertificatePrivateKeyTypes: + """ + Load the specified OpenSSL private key that can be used as a private key for certificates. + """ + private_key = load_privatekey( + path=path, + passphrase=passphrase, + check_passphrase=check_passphrase, + content=content, + ) + if not is_potential_certificate_private_key(private_key): + raise OpenSSLObjectError( + f"Key of type {type(private_key)} not supported for certificates" + ) + return private_key + + +def load_certificate_issuer_privatekey( + *, + path: os.PathLike | str | None = None, + content: bytes | None = None, + passphrase: str | bytes | None = None, + check_passphrase: bool = True, +) -> CertificateIssuerPrivateKeyTypes: + """ + Load the specified OpenSSL private key that can be used for issuing certificates. + """ + private_key = load_privatekey( + path=path, + passphrase=passphrase, + check_passphrase=check_passphrase, + content=content, + ) + if not is_potential_certificate_issuer_private_key(private_key): + raise OpenSSLObjectError( + f"Key of type {type(private_key)} not supported for issuing certificates" + ) + return private_key + + +def load_publickey( + path: os.PathLike | str | None = None, content: bytes | None = None +) -> PublicKeyTypes: if content is None: if path is None: raise OpenSSLObjectError("Must provide either path or content") @@ -170,11 +237,17 @@ def load_publickey(path=None, content=None): raise OpenSSLObjectError(f"Error while deserializing key: {e}") -def load_certificate(path, content=None, der_support_enabled=False): +def load_certificate( + path: os.PathLike | str | None = None, + content: bytes | None = None, + der_support_enabled: bool = False, +) -> x509.Certificate: """Load the specified certificate.""" try: if content is None: + if path is None: + raise OpenSSLObjectError("Must provide either path or content") with open(path, "rb") as cert_fh: cert_content = cert_fh.read() else: @@ -193,10 +266,14 @@ def load_certificate(path, content=None, der_support_enabled=False): raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}") -def load_certificate_request(path, content=None): +def load_certificate_request( + path: os.PathLike | str | None = None, content: bytes | None = None +) -> x509.CertificateSigningRequest: """Load the specified certificate signing request.""" try: if content is None: + if path is None: + raise OpenSSLObjectError("Must provide either path or content") with open(path, "rb") as csr_fh: csr_content = csr_fh.read() else: @@ -209,45 +286,44 @@ def load_certificate_request(path, content=None): raise OpenSSLObjectError(exc) -def parse_name_field(input_dict, name_field_name=None): +def parse_name_field( + input_dict: dict[str, list[str | bytes] | str | bytes], + name_field_name: str | None = None, +) -> list[tuple[str, str | bytes]]: """Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" - error_str = "{key}" if name_field_name is None else "{key} in {name}" + + def error_str(key: str) -> str: + if name_field_name is None: + return f"{key}" + return f"{key} in {name_field_name}" result = [] for key, value in input_dict.items(): if isinstance(value, list): for entry in value: if not isinstance(entry, (str, bytes)): - raise TypeError( - f"Values {error_str} must be strings".format( - key=key, name=name_field_name - ) - ) + raise TypeError(f"Values {error_str(key)} must be strings") if not entry: raise ValueError( - f"Values for {error_str} must not be empty strings".format( - key=key, name=name_field_name - ) + f"Values for {error_str(key)} must not be empty strings" ) result.append((key, entry)) elif isinstance(value, (str, bytes)): if not value: raise ValueError( - f"Value for {error_str} must not be an empty string".format( - key=key, name=name_field_name - ) + f"Value for {error_str(key)} must not be an empty string" ) result.append((key, value)) else: raise TypeError( - ( - f"Value for {error_str} must be either a string or a list of strings" - ).format(key=key, name=name_field_name) + f"Value for {error_str(key)} must be either a string or a list of strings" ) return result -def parse_ordered_name_field(input_list, name_field_name): +def parse_ordered_name_field( + input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str +) -> list[tuple[str, str | bytes]]: """Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" result = [] @@ -265,24 +341,39 @@ def parse_ordered_name_field(input_list, name_field_name): return result -def select_message_digest(digest_string): - digest = None +@t.overload +def select_message_digest( + digest_string: t.Literal["sha256", "sha384", "sha512", "sha1", "md5"], +) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5: ... + + +@t.overload +def select_message_digest( + digest_string: str, +) -> ( + hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None +): ... + + +def select_message_digest( + digest_string: str, +) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None: if digest_string == "sha256": - digest = hashes.SHA256() - elif digest_string == "sha384": - digest = hashes.SHA384() - elif digest_string == "sha512": - digest = hashes.SHA512() - elif digest_string == "sha1": - digest = hashes.SHA1() - elif digest_string == "md5": - digest = hashes.MD5() - return digest + return hashes.SHA256() + if digest_string == "sha384": + return hashes.SHA384() + if digest_string == "sha512": + return hashes.SHA512() + if digest_string == "sha1": + return hashes.SHA1() + if digest_string == "md5": + return hashes.MD5() + return None class OpenSSLObject(metaclass=abc.ABCMeta): - def __init__(self, path, state, force, check_mode): + def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None: self.path = path self.state = state self.force = force @@ -290,13 +381,13 @@ class OpenSSLObject(metaclass=abc.ABCMeta): self.changed = False self.check_mode = check_mode - def check(self, module, perms_required=True): + def check(self, module: AnsibleModule, perms_required: bool = True) -> bool: """Ensure the resource is in its desired state.""" - def _check_state(): + def _check_state() -> bool: return os.path.exists(self.path) - def _check_perms(module): + def _check_perms(module: AnsibleModule) -> bool: file_args = module.load_file_common_arguments(module.params) if module.check_file_absent_if_check_mode(file_args["path"]): return False @@ -308,18 +399,14 @@ class OpenSSLObject(metaclass=abc.ABCMeta): return _check_state() and _check_perms(module) @abc.abstractmethod - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - pass - @abc.abstractmethod - def generate(self): + def generate(self, module: AnsibleModule) -> None: """Generate the resource.""" - pass - - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: """Remove the resource from the filesystem.""" if self.check_mode: if os.path.exists(self.path): diff --git a/plugins/module_utils/cryptography_dep.py b/plugins/module_utils/cryptography_dep.py index a5177f17..12d2583d 100644 --- a/plugins/module_utils/cryptography_dep.py +++ b/plugins/module_utils/cryptography_dep.py @@ -11,6 +11,7 @@ Must be kept in sync with plugins/doc_fragments/cryptography_dep.py. from __future__ import annotations import traceback +import typing as t from ansible.module_utils.basic import missing_required_lib from ansible_collections.community.crypto.plugins.module_utils.version import ( @@ -18,19 +19,29 @@ from ansible_collections.community.crypto.plugins.module_utils.version import ( ) -_CRYPTOGRAPHY_IMP_ERR = None +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + from ..plugin_utils.action_module import AnsibleActionModule + from ..plugin_utils.filter_module import FilterModuleMock + + GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock] + + +_CRYPTOGRAPHY_IMP_ERR: str | None = None +_CRYPTOGRAPHY_FILE: str | None = None try: import cryptography from cryptography import x509 # noqa: F401, pylint: disable=unused-import - _CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) + CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) _CRYPTOGRAPHY_FILE = cryptography.__file__ except ImportError: _CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() - _CRYPTOGRAPHY_FOUND = False - _CRYPTOGRAPHY_FILE = None + CRYPTOGRAPHY_FOUND = False + CRYPTOGRAPHY_VERSION = LooseVersion("0.0") else: - _CRYPTOGRAPHY_FOUND = True + CRYPTOGRAPHY_FOUND = True # Corresponds to the community.crypto.cryptography_dep.minimum doc fragment @@ -38,25 +49,27 @@ COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION = "3.3" def assert_required_cryptography_version( - module, + module: GeneralAnsibleModule, *, minimum_cryptography_version: str = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, ) -> None: - if not _CRYPTOGRAPHY_FOUND: + if not CRYPTOGRAPHY_FOUND: module.fail_json( msg=missing_required_lib(f"cryptography >= {minimum_cryptography_version}"), exception=_CRYPTOGRAPHY_IMP_ERR, ) - if _CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version): + if CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version): module.fail_json( msg=( f"Cannot detect the required Python library cryptography (>= {minimum_cryptography_version})." - f" Only found a too old version ({_CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}." + f" Only found a too old version ({CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}." ), ) __all__ = ( "COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION", + "CRYPTOGRAPHY_FOUND", + "CRYPTOGRAPHY_VERSION", "assert_required_cryptography_version", ) diff --git a/plugins/module_utils/ecs/api.py b/plugins/module_utils/ecs/api.py index 7523df28..8436999e 100644 --- a/plugins/module_utils/ecs/api.py +++ b/plugins/module_utils/ecs/api.py @@ -14,6 +14,7 @@ import json import os import re import traceback +import typing as t from urllib.error import HTTPError from urllib.parse import urlencode @@ -34,7 +35,7 @@ else: valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$") -def ecs_client_argument_spec(): +def ecs_client_argument_spec() -> dict[str, t.Any]: return dict( entrust_api_user=dict(type="str", required=True), entrust_api_key=dict(type="str", required=True, no_log=True), @@ -50,19 +51,17 @@ def ecs_client_argument_spec(): class SessionConfigurationException(Exception): """Raised if we cannot configure a session with the API""" - pass - class RestOperationException(Exception): """Encapsulate a REST API error""" - def __init__(self, error): + def __init__(self, error: dict[str, t.Any]) -> None: self.status = to_native(error.get("status", None)) self.errors = [to_native(err.get("message")) for err in error.get("errors", {})] self.message = " ".join(self.errors) -def generate_docstring(operation_spec): +def generate_docstring(operation_spec: dict[str, t.Any]) -> str: """Generate a docstring for an operation defined in operation_spec (swagger)""" # Description of the operation docs = operation_spec.get("description", "No Description") diff --git a/plugins/module_utils/gnupg/cli.py b/plugins/module_utils/gnupg/cli.py index e597fe5d..799d7928 100644 --- a/plugins/module_utils/gnupg/cli.py +++ b/plugins/module_utils/gnupg/cli.py @@ -14,7 +14,9 @@ class GPGError(Exception): class GPGRunner(metaclass=abc.ABCMeta): @abc.abstractmethod - def run_command(self, command, check_rc=True, data=None): + def run_command( + self, command: list[str], check_rc: bool = True, data: bytes | None = None + ) -> tuple[int, str, str]: """ Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``. @@ -29,7 +31,7 @@ class GPGRunner(metaclass=abc.ABCMeta): pass -def get_fingerprint_from_stdout(stdout): +def get_fingerprint_from_stdout(stdout: str) -> str: lines = stdout.splitlines(False) for line in lines: if line.startswith("fpr:"): @@ -42,7 +44,7 @@ def get_fingerprint_from_stdout(stdout): raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"') -def get_fingerprint_from_file(gpg_runner, path): +def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str: if not os.path.exists(path): raise GPGError(f"{path} does not exist") stdout = gpg_runner.run_command( @@ -59,7 +61,7 @@ def get_fingerprint_from_file(gpg_runner, path): return get_fingerprint_from_stdout(stdout) -def get_fingerprint_from_bytes(gpg_runner, content): +def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str: stdout = gpg_runner.run_command( [ "--no-keyring", diff --git a/plugins/module_utils/io.py b/plugins/module_utils/io.py index 3fa166d2..73bb33d5 100644 --- a/plugins/module_utils/io.py +++ b/plugins/module_utils/io.py @@ -7,9 +7,14 @@ from __future__ import annotations import errno import os import tempfile +import typing as t -def load_file(path, module=None): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + +def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes: """ Load the file as a bytes string. """ @@ -22,7 +27,11 @@ def load_file(path, module=None): module.fail_json(f"Error while loading {path} - {exc}") -def load_file_if_exists(path, module=None, ignore_errors=False): +def load_file_if_exists( + path: str | os.PathLike, + module: AnsibleModule | None = None, + ignore_errors: bool = False, +) -> bytes | None: """ Load the file as a bytes string. If the file does not exist, ``None`` is returned. @@ -49,7 +58,12 @@ def load_file_if_exists(path, module=None, ignore_errors=False): module.fail_json(f"Error while loading {path} - {exc}") -def write_file(module, content, default_mode=None, path=None): +def write_file( + module: AnsibleModule, + content: bytes, + default_mode: str | int | None = None, + path: str | os.PathLike | None = None, +) -> None: """ Writes content into destination file as securely as possible. Uses file arguments from module. diff --git a/plugins/module_utils/openssh/backends/common.py b/plugins/module_utils/openssh/backends/common.py index 34190339..a4b4b630 100644 --- a/plugins/module_utils/openssh/backends/common.py +++ b/plugins/module_utils/openssh/backends/common.py @@ -8,14 +8,31 @@ import abc import os import stat import traceback +import typing as t from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( parse_openssh_version, ) -def restore_on_failure(f): - def backup_and_restore(module, path, *args, **kwargs): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + PrivateKeyTypes, + ) + + from ..certificate import OpensshCertificateTimeParameters + + Param = t.ParamSpec("Param") + + +def restore_on_failure( + f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None], +) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]: + def backup_and_restore( + module: AnsibleModule, path: str | os.PathLike, *args, **kwargs + ) -> None: backup_file = module.backup_local(path) if os.path.exists(path) else None try: @@ -31,12 +48,31 @@ def restore_on_failure(f): @restore_on_failure -def safe_atomic_move(module, path, destination): +def safe_atomic_move( + module: AnsibleModule, path: str | os.PathLike, destination: str | os.PathLike +) -> None: module.atomic_move(os.path.abspath(path), os.path.abspath(destination)) -def _restore_all_on_failure(f): - def backup_and_restore(self, sources_and_destinations, *args, **kwargs): +def _restore_all_on_failure( + f: t.Callable[ + t.Concatenate[ + OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param + ], + None, + ], +) -> t.Callable[ + t.Concatenate[ + OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param + ], + None, +]: + def backup_and_restore( + self: OpensshModule, + sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]], + *args, + **kwargs, + ) -> None: backups = [ (d, self.module.backup_local(d)) for s, d in sources_and_destinations @@ -59,13 +95,13 @@ def _restore_all_on_failure(f): class OpensshModule(metaclass=abc.ABCMeta): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module - self.changed = False - self.check_mode = self.module.check_mode + self.changed: bool = False + self.check_mode: bool = self.module.check_mode - def execute(self): + def execute(self) -> t.NoReturn: try: self._execute() except Exception as e: @@ -77,11 +113,11 @@ class OpensshModule(metaclass=abc.ABCMeta): self.module.exit_json(**self.result) @abc.abstractmethod - def _execute(self): + def _execute(self) -> None: pass @property - def result(self): + def result(self) -> dict[str, t.Any]: result = self._result result["changed"] = self.changed @@ -93,31 +129,31 @@ class OpensshModule(metaclass=abc.ABCMeta): @property @abc.abstractmethod - def _result(self): + def _result(self) -> dict[str, t.Any]: pass @property @abc.abstractmethod - def diff(self): + def diff(self) -> dict[str, t.Any]: pass @staticmethod - def skip_if_check_mode(f): - def wrapper(self, *args, **kwargs): + def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]: + def wrapper(self, *args, **kwargs) -> None: if not self.check_mode: f(self, *args, **kwargs) - return wrapper + return wrapper # type: ignore @staticmethod - def trigger_change(f): - def wrapper(self, *args, **kwargs): + def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]: + def wrapper(self, *args, **kwargs) -> None: f(self, *args, **kwargs) self.changed = True - return wrapper + return wrapper # type: ignore - def _check_if_base_dir(self, path): + def _check_if_base_dir(self, path: str | os.PathLike) -> None: base_dir = os.path.dirname(path) or "." if not os.path.isdir(base_dir): self.module.fail_json( @@ -125,16 +161,19 @@ class OpensshModule(metaclass=abc.ABCMeta): msg=f"The directory {base_dir} does not exist or the file is not a directory", ) - def _get_ssh_version(self): + def _get_ssh_version(self) -> str | None: ssh_bin = self.module.get_bin_path("ssh") if not ssh_bin: - return "" + return None return parse_openssh_version( self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip() ) @_restore_all_on_failure - def _safe_secure_move(self, sources_and_destinations): + def _safe_secure_move( + self, + sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]], + ) -> None: """Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure. If 'destination' does not already exist, then 'source' permissions are preserved to prevent exposing protected data ('atomic_move' uses the 'destination' base directory mask for @@ -148,7 +187,7 @@ class OpensshModule(metaclass=abc.ABCMeta): else: self.module.preserved_copy(source, destination) - def _update_permissions(self, path): + def _update_permissions(self, path: str | os.PathLike) -> None: file_args = self.module.load_file_common_arguments(self.module.params) file_args["path"] = path @@ -161,25 +200,25 @@ class OpensshModule(metaclass=abc.ABCMeta): class KeygenCommand: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self._bin_path = module.get_bin_path("ssh-keygen", True) self._run_command = module.run_command def generate_certificate( self, - certificate_path, - identifier, - options, - pkcs11_provider, - principals, - serial_number, - signature_algorithm, - signing_key_path, - type, - time_parameters, - use_agent, + certificate_path: str, + identifier: str, + options: list[str] | None, + pkcs11_provider: str | None, + principals: list[str] | None, + serial_number: int | None, + signature_algorithm: str | None, + signing_key_path: str, + type: t.Literal["host", "user"] | None, + time_parameters: OpensshCertificateTimeParameters, + use_agent: bool, **kwargs, - ): + ) -> tuple[int, str, str]: args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier] if options: @@ -203,7 +242,9 @@ class KeygenCommand: return self._run_command(args, **kwargs) - def generate_keypair(self, private_key_path, size, type, comment, **kwargs): + def generate_keypair( + self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs + ) -> tuple[int, str, str]: args = [ self._bin_path, "-q", @@ -224,32 +265,40 @@ class KeygenCommand: return self._run_command(args, data=data, **kwargs) - def get_certificate_info(self, certificate_path, **kwargs): + def get_certificate_info( + self, certificate_path: str, **kwargs + ) -> tuple[int, str, str]: return self._run_command( [self._bin_path, "-L", "-f", certificate_path], **kwargs ) - def get_matching_public_key(self, private_key_path, **kwargs): + def get_matching_public_key( + self, private_key_path: str, **kwargs + ) -> tuple[int, str, str]: return self._run_command( [self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs ) - def get_private_key(self, private_key_path, **kwargs): + def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]: return self._run_command( [self._bin_path, "-l", "-f", private_key_path], **kwargs ) def update_comment( - self, private_key_path, comment, force_new_format=True, **kwargs - ): + self, + private_key_path: str, + comment: str, + force_new_format: bool = True, + **kwargs, + ) -> tuple[int, str, str]: if os.path.exists(private_key_path) and not os.access( private_key_path, os.W_OK ): try: os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR) except (IOError, OSError) as e: - raise e( - f"The private key at {private_key_path} is not writeable preventing a comment update" + raise ValueError( + f"The private key at {private_key_path} is not writeable preventing a comment update ({e})" ) command = [self._bin_path, "-q"] @@ -259,31 +308,36 @@ class KeygenCommand: return self._run_command(command, **kwargs) +_PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey") + + class PrivateKey: - def __init__(self, size, key_type, fingerprint, format=""): + def __init__( + self, size: int, key_type: str, fingerprint: str, format: str = "" + ) -> None: self._size = size self._type = key_type self._fingerprint = fingerprint self._format = format @property - def size(self): + def size(self) -> int: return self._size @property - def type(self): + def type(self) -> str: return self._type @property - def fingerprint(self): + def fingerprint(self) -> str: return self._fingerprint @property - def format(self): + def format(self) -> str: return self._format @classmethod - def from_string(cls, string): + def from_string(cls: t.Type[_PrivateKey], string: str) -> _PrivateKey: properties = string.split() return cls( @@ -292,7 +346,7 @@ class PrivateKey: fingerprint=properties[1], ) - def to_dict(self): + def to_dict(self) -> dict[str, t.Any]: return { "size": self._size, "type": self._type, @@ -301,13 +355,16 @@ class PrivateKey: } +_PublicKey = t.TypeVar("_PublicKey", bound="PublicKey") + + class PublicKey: - def __init__(self, type_string, data, comment): + def __init__(self, type_string: str, data: str, comment: str | None) -> None: self._type_string = type_string self._data = data self._comment = comment - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented @@ -323,30 +380,30 @@ class PublicKey: ] ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other - def __str__(self): + def __str__(self) -> str: return f"{self._type_string} {self._data}" @property - def comment(self): + def comment(self) -> str | None: return self._comment @comment.setter - def comment(self, value): + def comment(self, value: str | None) -> None: self._comment = value @property - def data(self): + def data(self) -> str: return self._data @property - def type_string(self): + def type_string(self) -> str: return self._type_string @classmethod - def from_string(cls, string): + def from_string(cls: t.Type[_PublicKey], string: str) -> _PublicKey: properties = string.strip("\n").split(" ", 2) return cls( @@ -356,7 +413,7 @@ class PublicKey: ) @classmethod - def load(cls, path): + def load(cls: t.Type[_PublicKey], path: str | os.PathLike) -> _PublicKey | None: try: with open(path, "r") as f: properties = f.read().strip(" \n").split(" ", 2) @@ -372,14 +429,16 @@ class PublicKey: comment="" if len(properties) <= 2 else properties[2], ) - def to_dict(self): + def to_dict(self) -> dict[str, t.Any]: return { "comment": self._comment, "public_key": self._data, } -def parse_private_key_format(path): +def parse_private_key_format( + path: str | os.PathLike, +) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]: with open(path, "r") as file: header = file.readline().strip() diff --git a/plugins/module_utils/openssh/backends/keypair_backend.py b/plugins/module_utils/openssh/backends/keypair_backend.py index b1935063..c29e4462 100644 --- a/plugins/module_utils/openssh/backends/keypair_backend.py +++ b/plugins/module_utils/openssh/backends/keypair_backend.py @@ -7,6 +7,7 @@ from __future__ import annotations import abc import os +import typing as t from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.common.text.converters import to_bytes, to_text @@ -39,31 +40,43 @@ from ansible_collections.community.crypto.plugins.module_utils.version import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + PrivateKeyTypes, + ) + + class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(KeypairBackend, self).__init__(module) - self.comment = self.module.params["comment"] - self.private_key_path = self.module.params["path"] + self.comment: str | None = self.module.params["comment"] + self.private_key_path: str = self.module.params["path"] self.public_key_path = self.private_key_path + ".pub" - self.regenerate = ( + self.regenerate: t.Literal[ + "never", "fail", "partial_idempotence", "full_idempotence", "always" + ] = ( self.module.params["regenerate"] if not self.module.params["force"] else "always" ) - self.state = self.module.params["state"] - self.type = self.module.params["type"] + self.state: t.Literal["present", "absent"] = self.module.params["state"] + self.type: t.Literal["rsa", "dsa", "rsa1", "ecdsa", "ed25519"] = ( + self.module.params["type"] + ) - self.size = self._get_size(self.module.params["size"]) + self.size: int = self._get_size(self.module.params["size"]) self._validate_path() - self.original_private_key = None - self.original_public_key = None - self.private_key = None - self.public_key = None + self.original_private_key: PrivateKey | None = None + self.original_public_key: PublicKey | None = None + self.private_key: PrivateKey | None = None + self.public_key: PublicKey | None = None - def _get_size(self, size): + def _get_size(self, size: int | None) -> int: if self.type in ("rsa", "rsa1"): result = 4096 if size is None else size if result < 1024: @@ -96,7 +109,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): return result - def _validate_path(self): + def _validate_path(self) -> None: self._check_if_base_dir(self.private_key_path) if os.path.isdir(self.private_key_path): @@ -104,7 +117,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): msg=f"{self.private_key_path} is a directory. Please specify a path to a file." ) - def _execute(self): + def _execute(self) -> None: self.original_private_key = self._load_private_key() self.original_public_key = self._load_public_key() @@ -125,7 +138,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): if self._should_remove(): self._remove() - def _load_private_key(self): + def _load_private_key(self) -> PrivateKey | None: result = None if self._private_key_exists(): try: @@ -135,14 +148,14 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): return result - def _private_key_exists(self): + def _private_key_exists(self) -> bool: return os.path.exists(self.private_key_path) @abc.abstractmethod - def _get_private_key(self): + def _get_private_key(self) -> PrivateKey: pass - def _load_public_key(self): + def _load_public_key(self) -> PublicKey | None: result = None if self._public_key_exists(): try: @@ -151,10 +164,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): pass return result - def _public_key_exists(self): + def _public_key_exists(self) -> bool: return os.path.exists(self.public_key_path) - def _validate_key_load(self): + def _validate_key_load(self) -> None: if ( self._private_key_exists() and self.regenerate in ("never", "fail", "partial_idempotence") @@ -167,10 +180,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): ) @abc.abstractmethod - def _private_key_readable(self): + def _private_key_readable(self) -> bool: pass - def _should_generate(self): + def _should_generate(self) -> bool: if self.original_private_key is None: return True elif self.regenerate == "never": @@ -188,7 +201,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): else: return True - def _private_key_valid(self): + def _private_key_valid(self) -> bool: if self.original_private_key is None: return False @@ -196,17 +209,17 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): [ self.size == self.original_private_key.size, self.type == self.original_private_key.type, - self._private_key_valid_backend(), + self._private_key_valid_backend(self.original_private_key), ] ) @abc.abstractmethod - def _private_key_valid_backend(self): + def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool: pass @OpensshModule.trigger_change @OpensshModule.skip_if_check_mode - def _generate(self): + def _generate(self) -> None: temp_private_key, temp_public_key = self._generate_temp_keypair() try: @@ -219,7 +232,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): except OSError as e: self.module.fail_json(msg=str(e)) - def _generate_temp_keypair(self): + def _generate_temp_keypair(self) -> tuple[str, str]: temp_private_key = os.path.join( self.module.tmpdir, os.path.basename(self.private_key_path) ) @@ -236,25 +249,26 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): return temp_private_key, temp_public_key @abc.abstractmethod - def _generate_keypair(self, private_key_path): + def _generate_keypair(self, private_key_path: str) -> None: pass - def _public_key_valid(self): + def _public_key_valid(self) -> bool: if self.original_public_key is None: return False valid_public_key = self._get_public_key() - valid_public_key.comment = self.comment + if valid_public_key: + valid_public_key.comment = self.comment return self.original_public_key == valid_public_key @abc.abstractmethod - def _get_public_key(self): + def _get_public_key(self) -> PublicKey | t.Literal[""]: pass @OpensshModule.trigger_change @OpensshModule.skip_if_check_mode - def _restore_public_key(self): + def _restore_public_key(self) -> None: try: temp_public_key = self._create_temp_public_key( str(self._get_public_key()) + "\n" @@ -269,7 +283,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): if self.comment: self._update_comment() - def _create_temp_public_key(self, content): + def _create_temp_public_key(self, content: str | bytes) -> str: temp_public_key = os.path.join( self.module.tmpdir, os.path.basename(self.public_key_path) ) @@ -290,15 +304,15 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): return temp_public_key @abc.abstractmethod - def _update_comment(self): + def _update_comment(self) -> None: pass - def _should_remove(self): + def _should_remove(self) -> bool: return self._private_key_exists() or self._public_key_exists() @OpensshModule.trigger_change @OpensshModule.skip_if_check_mode - def _remove(self): + def _remove(self) -> None: try: if self._private_key_exists(): os.remove(self.private_key_path) @@ -308,7 +322,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): self.module.fail_json(msg=str(e)) @property - def _result(self): + def _result(self) -> dict[str, t.Any]: private_key = self.private_key or self.original_private_key public_key = self.public_key or self.original_public_key @@ -322,7 +336,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): } @property - def diff(self): + def diff(self) -> dict[str, t.Any]: before = ( self.original_private_key.to_dict() if self.original_private_key else {} ) @@ -340,7 +354,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): class KeypairBackendOpensshBin(KeypairBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(KeypairBackendOpensshBin, self).__init__(module) if self.module.params["private_key_format"] != "auto": @@ -350,12 +364,12 @@ class KeypairBackendOpensshBin(KeypairBackend): self.ssh_keygen = KeygenCommand(self.module) - def _generate_keypair(self, private_key_path): + def _generate_keypair(self, private_key_path: str) -> None: self.ssh_keygen.generate_keypair( private_key_path, self.size, self.type, self.comment, check_rc=True ) - def _get_private_key(self): + def _get_private_key(self) -> PrivateKey: rc, private_key_content, err = self.ssh_keygen.get_private_key( self.private_key_path, check_rc=False ) @@ -363,13 +377,13 @@ class KeypairBackendOpensshBin(KeypairBackend): raise ValueError(err) return PrivateKey.from_string(private_key_content) - def _get_public_key(self): + def _get_public_key(self) -> PublicKey | t.Literal[""]: public_key_content = self.ssh_keygen.get_matching_public_key( self.private_key_path, check_rc=True )[1] return PublicKey.from_string(public_key_content) - def _private_key_readable(self): + def _private_key_readable(self) -> bool: rc, stdout, stderr = self.ssh_keygen.get_matching_public_key( self.private_key_path, check_rc=False ) @@ -383,7 +397,7 @@ class KeypairBackendOpensshBin(KeypairBackend): ) ) - def _update_comment(self): + def _update_comment(self) -> None: try: ssh_version = self._get_ssh_version() or "7.8" force_new_format = ( @@ -391,19 +405,19 @@ class KeypairBackendOpensshBin(KeypairBackend): ) self.ssh_keygen.update_comment( self.private_key_path, - self.comment, + self.comment or "", force_new_format=force_new_format, check_rc=True, ) except (IOError, OSError) as e: self.module.fail_json(msg=str(e)) - def _private_key_valid_backend(self): + def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool: return True class KeypairBackendCryptography(KeypairBackend): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(KeypairBackendCryptography, self).__init__(module) if self.type == "rsa1": @@ -416,12 +430,15 @@ class KeypairBackendCryptography(KeypairBackend): if module.params["passphrase"] else None ) - self.private_key_format = self._get_key_format( - module.params["private_key_format"] - ) + key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] = module.params[ + "private_key_format" + ] + self.private_key_format = self._get_key_format(key_format) - def _get_key_format(self, key_format): - result = "SSH" + def _get_key_format( + self, key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] + ) -> t.Literal["SSH", "PKCS1", "PKCS8"]: + result: t.Literal["SSH", "PKCS1", "PKCS8"] = "SSH" if key_format == "auto": # Default to OpenSSH 7.8 compatibility when OpenSSH is not installed @@ -435,11 +452,12 @@ class KeypairBackendCryptography(KeypairBackend): # but still defaulted to PKCS1 format with the exception of ed25519 keys result = "PKCS1" else: - result = key_format.upper() + result = key_format.upper() # type: ignore return result - def _generate_keypair(self, private_key_path): + def _generate_keypair(self, private_key_path: str) -> None: + assert self.type != "rsa1" keypair = OpensshKeypair.generate( keytype=self.type, size=self.size, @@ -455,7 +473,7 @@ class KeypairBackendCryptography(KeypairBackend): public_key_path = private_key_path + ".pub" secure_write(public_key_path, 0o644, keypair.public_key) - def _get_private_key(self): + def _get_private_key(self) -> PrivateKey: keypair = OpensshKeypair.load( path=self.private_key_path, passphrase=self.passphrase, no_public_key=True ) @@ -467,7 +485,7 @@ class KeypairBackendCryptography(KeypairBackend): format=parse_private_key_format(self.private_key_path), ) - def _get_public_key(self): + def _get_public_key(self) -> PublicKey | t.Literal[""]: try: keypair = OpensshKeypair.load( path=self.private_key_path, @@ -480,7 +498,7 @@ class KeypairBackendCryptography(KeypairBackend): return PublicKey.from_string(to_text(keypair.public_key)) - def _private_key_readable(self): + def _private_key_readable(self) -> bool: try: OpensshKeypair.load( path=self.private_key_path, @@ -504,7 +522,7 @@ class KeypairBackendCryptography(KeypairBackend): return True - def _update_comment(self): + def _update_comment(self) -> None: keypair = OpensshKeypair.load( path=self.private_key_path, passphrase=self.passphrase, no_public_key=True ) @@ -519,16 +537,18 @@ class KeypairBackendCryptography(KeypairBackend): except (IOError, OSError) as e: self.module.fail_json(msg=str(e)) - def _private_key_valid_backend(self): + def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool: # avoids breaking behavior and prevents # automatic conversions with OpenSSH upgrades if self.module.params["private_key_format"] == "auto": return True - return self.private_key_format == self.original_private_key.format + return self.private_key_format == original_private_key.format -def select_backend(module, backend): +def select_backend( + module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"] +) -> KeypairBackend: can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion( CRYPTOGRAPHY_VERSION ) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION) diff --git a/plugins/module_utils/openssh/certificate.py b/plugins/module_utils/openssh/certificate.py index a6e2207d..40d89c59 100644 --- a/plugins/module_utils/openssh/certificate.py +++ b/plugins/module_utils/openssh/certificate.py @@ -8,6 +8,7 @@ import abc import binascii import datetime as _datetime import os +import typing as t from base64 import b64encode from datetime import datetime from hashlib import sha256 @@ -26,6 +27,16 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) +if t.TYPE_CHECKING: + from .cryptography import KeyType + + DateFormat = t.Literal["human_readable", "openssh", "timestamp"] + DateFormatStr = t.Literal["human_readable", "openssh"] + DateFormatInt = t.Literal["timestamp"] +else: + KeyType = None + + # Protocol References # ------------------- # https://datatracker.ietf.org/doc/html/rfc4251 @@ -44,7 +55,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( _USER_TYPE = 1 _HOST_TYPE = 2 -_SSH_TYPE_STRINGS = { +_SSH_TYPE_STRINGS: dict[KeyType | str, bytes] = { "rsa": b"ssh-rsa", "dsa": b"ssh-dss", "ecdsa-nistp256": b"ecdsa-sha2-nistp256", @@ -94,16 +105,18 @@ _EXTENSIONS = ( class OpensshCertificateTimeParameters: - def __init__(self, valid_from, valid_to): + def __init__( + self, valid_from: str | bytes | int, valid_to: str | bytes | int + ) -> None: self._valid_from = self.to_datetime(valid_from) self._valid_to = self.to_datetime(valid_to) if self._valid_from > self._valid_to: raise ValueError( - f"Valid from: {valid_from} must not be greater than Valid to: {valid_to}" + f"Valid from: {valid_from!r} must not be greater than Valid to: {valid_to!r}" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented else: @@ -112,55 +125,83 @@ class OpensshCertificateTimeParameters: and self._valid_to == other._valid_to ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other @property - def validity_string(self): + def validity_string(self) -> str: if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER): return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}" return "" - def valid_from(self, date_format): + @t.overload + def valid_from(self, date_format: DateFormatStr) -> str: ... + + @t.overload + def valid_from(self, date_format: DateFormatInt) -> int: ... + + @t.overload + def valid_from(self, date_format: DateFormat) -> str | int: ... + + def valid_from(self, date_format: DateFormat) -> str | int: return self.format_datetime(self._valid_from, date_format) - def valid_to(self, date_format): + @t.overload + def valid_to(self, date_format: DateFormatStr) -> str: ... + + @t.overload + def valid_to(self, date_format: DateFormatInt) -> int: ... + + @t.overload + def valid_to(self, date_format: DateFormat) -> str | int: ... + + def valid_to(self, date_format: DateFormat) -> str | int: return self.format_datetime(self._valid_to, date_format) - def within_range(self, valid_at): + def within_range(self, valid_at: str | bytes | int | None) -> bool: if valid_at is not None: valid_at_datetime = self.to_datetime(valid_at) return self._valid_from <= valid_at_datetime <= self._valid_to return True + @t.overload @staticmethod - def format_datetime(dt, date_format): + def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ... + + @t.overload + @staticmethod + def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ... + + @t.overload + @staticmethod + def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ... + + @staticmethod + def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: if date_format in ("human_readable", "openssh"): if dt == _ALWAYS: - result = "always" - elif dt == _FOREVER: - result = "forever" + return "always" + if dt == _FOREVER: + return "forever" else: - result = ( + return ( dt.isoformat().replace("+00:00", "") if date_format == "human_readable" else dt.strftime("%Y%m%d%H%M%S") ) - elif date_format == "timestamp": + if date_format == "timestamp": td = dt - _ALWAYS - result = int( + return int( (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6 ) - else: - raise ValueError(f"{date_format} is not a valid format") - return result + raise ValueError(f"{date_format} is not a valid format") @staticmethod - def to_datetime(time_string_or_timestamp): + def to_datetime(time_string_or_timestamp: str | bytes | int) -> datetime: try: if isinstance(time_string_or_timestamp, (str, bytes)): result = OpensshCertificateTimeParameters._time_string_to_datetime( - time_string_or_timestamp.strip() + to_text(time_string_or_timestamp.strip()) ) elif isinstance(time_string_or_timestamp, int): result = OpensshCertificateTimeParameters._timestamp_to_datetime( @@ -175,43 +216,53 @@ class OpensshCertificateTimeParameters: return result @staticmethod - def _timestamp_to_datetime(timestamp): + def _timestamp_to_datetime(timestamp: int) -> datetime: if timestamp == 0x0: - result = _ALWAYS - elif timestamp == 0xFFFFFFFFFFFFFFFF: - result = _FOREVER - else: - try: - result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc) - except OverflowError: - raise ValueError - return result + return _ALWAYS + if timestamp == 0xFFFFFFFFFFFFFFFF: + return _FOREVER + try: + return datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc) + except OverflowError: + raise ValueError @staticmethod - def _time_string_to_datetime(time_string): - result = None + def _time_string_to_datetime(time_string: str) -> datetime: if time_string == "always": - result = _ALWAYS - elif time_string == "forever": - result = _FOREVER - elif is_relative_time_string(time_string): + return _ALWAYS + if time_string == "forever": + return _FOREVER + if is_relative_time_string(time_string): result = convert_relative_to_datetime(time_string, with_timezone=True) - else: - for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): - try: - result = _add_or_remove_timezone( - datetime.strptime(time_string, time_format), - with_timezone=True, - ) - except ValueError: - pass if result is None: raise ValueError + return result + result = None + for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): + try: + result = _add_or_remove_timezone( + datetime.strptime(time_string, time_format), + with_timezone=True, + ) + except ValueError: + pass + if result is None: + raise ValueError return result +_OpensshCertificateOption = t.TypeVar( + "_OpensshCertificateOption", bound="OpensshCertificateOption" +) + + class OpensshCertificateOption: - def __init__(self, option_type, name, data): + def __init__( + self, + option_type: t.Literal["critical", "extension"], + name: str | bytes, + data: str | bytes, + ): if option_type not in ("critical", "extension"): raise ValueError("type must be either 'critical' or 'extension'") @@ -225,7 +276,7 @@ class OpensshCertificateOption: self._name = name.lower() self._data = data - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented @@ -237,32 +288,34 @@ class OpensshCertificateOption: ] ) - def __hash__(self): + def __hash__(self) -> int: return hash((self._option_type, self._name, self._data)) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other - def __str__(self): + def __str__(self) -> str: if self._data: - return f"{self._name}={self._data}" - return self._name + return f"{self._name!r}={self._data!r}" + return f"{self._name!r}" @property - def data(self): + def data(self) -> str | bytes: return self._data @property - def name(self): + def name(self) -> str | bytes: return self._name @property - def type(self): + def type(self) -> t.Literal["critical", "extension"]: return self._option_type @classmethod - def from_string(cls, option_string): - if not isinstance(option_string, (str, bytes)): + def from_string( + cls: t.Type[_OpensshCertificateOption], option_string: str + ) -> _OpensshCertificateOption: + if not isinstance(option_string, str): raise ValueError( f"option_string must be a string not {type(option_string)}" ) @@ -280,7 +333,8 @@ class OpensshCertificateOption: name, data = option_string.strip(), "" return cls( - option_type=option_type or get_option_type(name.lower()), + # We have str, but we're expecting a specific literal: + option_type=option_type or get_option_type(name.lower()), # type: ignore name=name, data=data, ) @@ -291,21 +345,21 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta): def __init__( self, - nonce=None, - serial=None, - cert_type=None, - key_id=None, - principals=None, - valid_after=None, - valid_before=None, - critical_options=None, - extensions=None, - reserved=None, - signing_key=None, + nonce: bytes | None = None, + serial: int | None = None, + cert_type: int | None = None, + key_id: bytes | None = None, + principals: list[bytes] | None = None, + valid_after: int | None = None, + valid_before: int | None = None, + critical_options: list[tuple[bytes, bytes]] | None = None, + extensions: list[tuple[bytes, bytes]] | None = None, + reserved: bytes | None = None, + signing_key: bytes | None = None, ): self.nonce = nonce self.serial = serial - self._cert_type = cert_type + self._cert_type: int | None = cert_type self.key_id = key_id self.principals = principals self.valid_after = valid_after @@ -315,10 +369,10 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta): self.reserved = reserved self.signing_key = signing_key - self.type_string = None + self.type_string: bytes | None = None @property - def cert_type(self): + def cert_type(self) -> t.Literal["user", "host", ""]: if self._cert_type == _USER_TYPE: return "user" elif self._cert_type == _HOST_TYPE: @@ -327,7 +381,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta): return "" @cert_type.setter - def cert_type(self, cert_type): + def cert_type(self, cert_type: t.Literal["user", "host"] | int) -> None: if cert_type == "user" or cert_type == _USER_TYPE: self._cert_type = _USER_TYPE elif cert_type == "host" or cert_type == _HOST_TYPE: @@ -335,28 +389,30 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta): else: raise ValueError(f"{cert_type} is not a valid certificate type") - def signing_key_fingerprint(self): + def signing_key_fingerprint(self) -> bytes: + if self.signing_key is None: + raise ValueError("signing_key not present") return fingerprint(self.signing_key) @abc.abstractmethod - def public_key_fingerprint(self): + def public_key_fingerprint(self) -> bytes: pass @abc.abstractmethod - def parse_public_numbers(self, parser): + def parse_public_numbers(self, parser: OpensshParser) -> None: pass class OpensshRSACertificateInfo(OpensshCertificateInfo): - def __init__(self, e=None, n=None, **kwargs): + def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None: super(OpensshRSACertificateInfo, self).__init__(**kwargs) self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01 self.e = e self.n = n # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 - def public_key_fingerprint(self): - if any([self.e is None, self.n is None]): + def public_key_fingerprint(self) -> bytes: + if self.e is None or self.n is None: return b"" writer = _OpensshWriter() @@ -366,13 +422,20 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo): return fingerprint(writer.bytes()) - def parse_public_numbers(self, parser): + def parse_public_numbers(self, parser: OpensshParser) -> None: self.e = parser.mpint() self.n = parser.mpint() class OpensshDSACertificateInfo(OpensshCertificateInfo): - def __init__(self, p=None, q=None, g=None, y=None, **kwargs): + def __init__( + self, + p: int | None = None, + q: int | None = None, + g: int | None = None, + y: int | None = None, + **kwargs, + ) -> None: super(OpensshDSACertificateInfo, self).__init__(**kwargs) self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01 self.p = p @@ -381,8 +444,8 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo): self.y = y # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 - def public_key_fingerprint(self): - if any([self.p is None, self.q is None, self.g is None, self.y is None]): + def public_key_fingerprint(self) -> bytes: + if self.p is None or self.q is None or self.g is None or self.y is None: return b"" writer = _OpensshWriter() @@ -394,7 +457,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo): return fingerprint(writer.bytes()) - def parse_public_numbers(self, parser): + def parse_public_numbers(self, parser: OpensshParser) -> None: self.p = parser.mpint() self.q = parser.mpint() self.g = parser.mpint() @@ -402,7 +465,9 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo): class OpensshECDSACertificateInfo(OpensshCertificateInfo): - def __init__(self, curve=None, public_key=None, **kwargs): + def __init__( + self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs + ): super(OpensshECDSACertificateInfo, self).__init__(**kwargs) self._curve = None if curve is not None: @@ -411,11 +476,11 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo): self.public_key = public_key @property - def curve(self): + def curve(self) -> bytes | None: return self._curve @curve.setter - def curve(self, curve): + def curve(self, curve: bytes) -> None: if curve in _ECDSA_CURVE_IDENTIFIERS.values(): self._curve = curve self.type_string = ( @@ -428,8 +493,8 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo): ) # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 - def public_key_fingerprint(self): - if any([self.curve is None, self.public_key is None]): + def public_key_fingerprint(self) -> bytes: + if self.curve is None or self.public_key is None: return b"" writer = _OpensshWriter() @@ -439,18 +504,18 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo): return fingerprint(writer.bytes()) - def parse_public_numbers(self, parser): + def parse_public_numbers(self, parser: OpensshParser) -> None: self.curve = parser.string() self.public_key = parser.string() class OpensshED25519CertificateInfo(OpensshCertificateInfo): - def __init__(self, pk=None, **kwargs): + def __init__(self, pk: bytes | None = None, **kwargs) -> None: super(OpensshED25519CertificateInfo, self).__init__(**kwargs) self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01 self.pk = pk - def public_key_fingerprint(self): + def public_key_fingerprint(self) -> bytes: if self.pk is None: return b"" @@ -460,21 +525,26 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo): return fingerprint(writer.bytes()) - def parse_public_numbers(self, parser): + def parse_public_numbers(self, parser: OpensshParser) -> None: self.pk = parser.string() +_OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate") + + # See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD class OpensshCertificate: """Encapsulates a formatted OpenSSH certificate including signature and signing key""" - def __init__(self, cert_info, signature): + def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes): self._cert_info = cert_info self.signature = signature @classmethod - def load(cls, path): + def load( + cls: t.Type[_OpensshCertificate], path: str | os.PathLike + ) -> _OpensshCertificate: if not os.path.exists(path): raise ValueError(f"{path} is not a valid path.") @@ -492,11 +562,11 @@ class OpensshCertificate: for key_type, string in _SSH_TYPE_STRINGS.items(): if format_identifier == string + _CERT_SUFFIX_V01: - pub_key_type = key_type + pub_key_type = t.cast(KeyType, key_type) break else: raise ValueError( - f"Invalid certificate format identifier: {format_identifier}" + f"Invalid certificate format identifier: {format_identifier!r}" ) parser = OpensshParser(cert) @@ -521,75 +591,97 @@ class OpensshCertificate: ) @property - def type_string(self): + def type_string(self) -> str: return to_text(self._cert_info.type_string) @property - def nonce(self): + def nonce(self) -> bytes: + if self._cert_info.nonce is None: + raise ValueError return self._cert_info.nonce @property - def public_key(self): + def public_key(self) -> str: return to_text(self._cert_info.public_key_fingerprint()) @property - def serial(self): + def serial(self) -> int: + if self._cert_info.serial is None: + raise ValueError return self._cert_info.serial @property - def type(self): - return self._cert_info.cert_type + def type(self) -> t.Literal["user", "host"]: + result = self._cert_info.cert_type + if result == "": + raise ValueError + return result @property - def key_id(self): + def key_id(self) -> str: return to_text(self._cert_info.key_id) @property - def principals(self): + def principals(self) -> list[str]: + if self._cert_info.principals is None: + raise ValueError return [to_text(p) for p in self._cert_info.principals] @property - def valid_after(self): + def valid_after(self) -> int: + if self._cert_info.valid_after is None: + raise ValueError return self._cert_info.valid_after @property - def valid_before(self): + def valid_before(self) -> int: + if self._cert_info.valid_before is None: + raise ValueError return self._cert_info.valid_before @property - def critical_options(self): + def critical_options(self) -> list[OpensshCertificateOption]: + if self._cert_info.critical_options is None: + raise ValueError return [ OpensshCertificateOption("critical", to_text(n), to_text(d)) for n, d in self._cert_info.critical_options ] @property - def extensions(self): + def extensions(self) -> list[OpensshCertificateOption]: + if self._cert_info.extensions is None: + raise ValueError return [ OpensshCertificateOption("extension", to_text(n), to_text(d)) for n, d in self._cert_info.extensions ] @property - def reserved(self): + def reserved(self) -> bytes: + if self._cert_info.reserved is None: + raise ValueError return self._cert_info.reserved @property - def signing_key(self): + def signing_key(self) -> str: return to_text(self._cert_info.signing_key_fingerprint()) @property - def signature_type(self): + def signature_type(self) -> str: signature_data = OpensshParser.signature_data(self.signature) return to_text(signature_data["signature_type"]) @staticmethod - def _parse_cert_info(pub_key_type, parser): + def _parse_cert_info( + pub_key_type: KeyType, parser: OpensshParser + ) -> OpensshCertificateInfo: cert_info = get_cert_info_object(pub_key_type) cert_info.nonce = parser.string() cert_info.parse_public_numbers(parser) cert_info.serial = parser.uint64() - cert_info.cert_type = parser.uint32() + # mypy doesn't understand that the setter accepts other types than the getter: + cert_info.cert_type = parser.uint32() # type: ignore cert_info.key_id = parser.string() cert_info.principals = parser.string_list() cert_info.valid_after = parser.uint64() @@ -601,7 +693,7 @@ class OpensshCertificate: return cert_info - def to_dict(self): + def to_dict(self) -> dict[str, t.Any]: time_parameters = OpensshCertificateTimeParameters( valid_from=self.valid_after, valid_to=self.valid_before ) @@ -624,7 +716,7 @@ class OpensshCertificate: } -def apply_directives(directives): +def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOption]: if any(d not in _DIRECTIVES for d in directives): raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}") @@ -650,50 +742,47 @@ def apply_directives(directives): ) -def default_options(): +def default_options() -> list[OpensshCertificateOption]: return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS] -def fingerprint(public_key): +def fingerprint(public_key: bytes) -> bytes: """Generates a SHA256 hash and formats output to resemble ``ssh-keygen``""" h = sha256() h.update(public_key) return b"SHA256:" + b64encode(h.digest()).rstrip(b"=") -def get_cert_info_object(key_type): +def get_cert_info_object(key_type: KeyType) -> OpensshCertificateInfo: if key_type == "rsa": - cert_info = OpensshRSACertificateInfo() - elif key_type == "dsa": - cert_info = OpensshDSACertificateInfo() - elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"): - cert_info = OpensshECDSACertificateInfo() - elif key_type == "ed25519": - cert_info = OpensshED25519CertificateInfo() - else: - raise ValueError(f"{key_type} is not a valid key type") - - return cert_info + return OpensshRSACertificateInfo() + if key_type == "dsa": + return OpensshDSACertificateInfo() + if key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"): + return OpensshECDSACertificateInfo() + if key_type == "ed25519": + return OpensshED25519CertificateInfo() + raise ValueError(f"{key_type} is not a valid key type") -def get_option_type(name): +def get_option_type(name: str) -> t.Literal["critical", "extension"]: if name in _CRITICAL_OPTIONS: - result = "critical" - elif name in _EXTENSIONS: - result = "extension" - else: - raise ValueError( - f"{name} is not a valid option. " - "Custom options must start with 'critical:' or 'extension:' to indicate type" - ) - return result + return "critical" + if name in _EXTENSIONS: + return "extension" + raise ValueError( + f"{name} is not a valid option. " + "Custom options must start with 'critical:' or 'extension:' to indicate type" + ) -def is_relative_time_string(time_string): +def is_relative_time_string(time_string: str) -> bool: return time_string.startswith("+") or time_string.startswith("-") -def parse_option_list(option_list): +def parse_option_list( + option_list: t.Iterable[str], +) -> tuple[list[OpensshCertificateOption], list[OpensshCertificateOption]]: critical_options = [] directives = [] extensions = [] diff --git a/plugins/module_utils/openssh/cryptography.py b/plugins/module_utils/openssh/cryptography.py index 24653d68..d0c3808f 100644 --- a/plugins/module_utils/openssh/cryptography.py +++ b/plugins/module_utils/openssh/cryptography.py @@ -5,6 +5,7 @@ from __future__ import annotations import os +import typing as t from base64 import b64decode, b64encode from getpass import getuser from socket import gethostname @@ -64,6 +65,27 @@ except ImportError: CRYPTOGRAPHY_VERSION = "0.0" _ALGORITHM_PARAMETERS = {} + +if t.TYPE_CHECKING: + KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"] + KeySerializationFormat = t.Literal["PEM", "DER", "SSH"] + KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"] + + PrivateKeyTypes = t.Union[ + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ec.EllipticCurvePrivateKey, + Ed25519PrivateKey, + ] + PublicKeyTypes = t.Union[ + rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey + ] + + from cryptography.hazmat.primitives.asymmetric.types import ( + PublicKeyTypes as AllPublicKeyTypes, + ) + + _TEXT_ENCODING = "UTF-8" @@ -111,11 +133,19 @@ class InvalidSignatureError(OpenSSHError): pass +_AsymmetricKeypair = t.TypeVar("_AsymmetricKeypair", bound="AsymmetricKeypair") + + class AsymmetricKeypair: """Container for newly generated asymmetric key pairs or those loaded from existing files""" @classmethod - def generate(cls, keytype="rsa", size=None, passphrase=None): + def generate( + cls: t.Type[_AsymmetricKeypair], + keytype: KeyType = "rsa", + size: int | None = None, + passphrase: bytes | None = None, + ) -> _AsymmetricKeypair: """Returns an Asymmetric_Keypair object generated with the supplied parameters or defaults to an unencrypted RSA-2048 key @@ -124,19 +154,21 @@ class AsymmetricKeypair: :passphrase: Secret of type Bytes used to encrypt the private key being generated """ - if keytype not in _ALGORITHM_PARAMETERS.keys(): + if keytype not in _ALGORITHM_PARAMETERS: raise InvalidKeyTypeError( f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}" ) if not size: - size = _ALGORITHM_PARAMETERS[keytype]["default_size"] + size = _ALGORITHM_PARAMETERS[keytype]["default_size"] # type: ignore else: - if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: + if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: # type: ignore raise InvalidKeySizeError( f"{size} is not a valid key size for {keytype} keys" ) + size = t.cast(int, size) + privatekey: PrivateKeyTypes if passphrase: encryption_algorithm = get_encryption_algorithm(passphrase) else: @@ -157,7 +189,7 @@ class AsymmetricKeypair: privatekey = Ed25519PrivateKey.generate() elif keytype == "ecdsa": privatekey = ec.generate_private_key( - _ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], + _ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], # type: ignore ) publickey = privatekey.public_key() @@ -172,13 +204,13 @@ class AsymmetricKeypair: @classmethod def load( - cls, - path, - passphrase=None, - private_key_format="PEM", - public_key_format="PEM", - no_public_key=False, - ): + cls: t.Type[_AsymmetricKeypair], + path: str | os.PathLike, + passphrase: bytes | None = None, + private_key_format: KeySerializationFormat = "PEM", + public_key_format: KeySerializationFormat = "PEM", + no_public_key: bool = False, + ) -> _AsymmetricKeypair: """Returns an Asymmetric_Keypair object loaded from the supplied file path :path: A path to an existing private key to be loaded @@ -197,14 +229,17 @@ class AsymmetricKeypair: if no_public_key: publickey = privatekey.public_key() else: - publickey = load_publickey(path + ".pub", public_key_format) + # TODO: BUG: load_publickey() can return unsupported key types + # (Also we should check whether the public key fits the private key...) + publickey = load_publickey(path + ".pub", public_key_format) # type: ignore # Ed25519 keys are always of size 256 and do not have a key_size attribute if isinstance(privatekey, Ed25519PrivateKey): - size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] + size: int = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] # type: ignore else: size = privatekey.key_size + keytype: KeyType if isinstance(privatekey, rsa.RSAPrivateKey): keytype = "rsa" elif isinstance(privatekey, dsa.DSAPrivateKey): @@ -224,7 +259,14 @@ class AsymmetricKeypair: encryption_algorithm=encryption_algorithm, ) - def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm): + def __init__( + self, + keytype: KeyType, + size: int, + privatekey: PrivateKeyTypes, + publickey: PublicKeyTypes, + encryption_algorithm: serialization.KeySerializationEncryption, + ) -> None: """ :keytype: One of rsa, dsa, ecdsa, ed25519 :size: The key length for the private key of this key pair @@ -246,7 +288,7 @@ class AsymmetricKeypair: "The private key and public key of this keypair do not match" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, AsymmetricKeypair): return NotImplemented @@ -256,55 +298,53 @@ class AsymmetricKeypair: self.encryption_algorithm, other.encryption_algorithm ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other @property - def private_key(self): + def private_key(self) -> PrivateKeyTypes: """Returns the private key of this key pair""" return self.__privatekey @property - def public_key(self): + def public_key(self) -> PublicKeyTypes: """Returns the public key of this key pair""" return self.__publickey @property - def size(self): + def size(self) -> int: """Returns the size of the private key of this key pair""" return self.__size @property - def key_type(self): + def key_type(self) -> KeyType: """Returns the key type of this key pair""" return self.__keytype @property - def encryption_algorithm(self): + def encryption_algorithm(self) -> serialization.KeySerializationEncryption: """Returns the key encryption algorithm of this key pair""" return self.__encryption_algorithm - def sign(self, data): + def sign(self, data: bytes) -> bytes: """Returns signature of data signed with the private key of this key pair :data: byteslike data to sign """ try: - signature = self.__privatekey.sign( - data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] + return self.__privatekey.sign( + data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] # type: ignore ) except TypeError as e: raise InvalidDataError(e) - return signature - - def verify(self, signature, data): + def verify(self, signature: bytes, data: bytes) -> None: """Verifies that the signature associated with the provided data was signed by the private key of this key pair. @@ -312,15 +352,15 @@ class AsymmetricKeypair: :data: byteslike data signed by the provided signature """ try: - return self.__publickey.verify( + self.__publickey.verify( signature, data, - **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], + **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], # type: ignore ) except InvalidSignature: raise InvalidSignatureError - def update_passphrase(self, passphrase=None): + def update_passphrase(self, passphrase: bytes | None = None) -> None: """Updates the encryption algorithm of this key pair :passphrase: Byte secret used to encrypt this key pair @@ -332,11 +372,20 @@ class AsymmetricKeypair: self.__encryption_algorithm = serialization.NoEncryption() +_OpensshKeypair = t.TypeVar("_OpensshKeypair", bound="OpensshKeypair") + + class OpensshKeypair: """Container for OpenSSH encoded asymmetric key pairs""" @classmethod - def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None): + def generate( + cls: t.Type[_OpensshKeypair], + keytype: KeyType = "rsa", + size: int | None = None, + passphrase: bytes | None = None, + comment: str | None = None, + ) -> _OpensshKeypair: """Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key :keytype: One of rsa, dsa, ecdsa, ed25519 @@ -362,7 +411,12 @@ class OpensshKeypair: ) @classmethod - def load(cls, path, passphrase=None, no_public_key=False): + def load( + cls: t.Type[_OpensshKeypair], + path: str | os.PathLike, + passphrase: bytes | None = None, + no_public_key: bool = False, + ) -> _OpensshKeypair: """Returns an Openssh_Keypair object loaded from the supplied file path :path: A path to an existing private key to be loaded @@ -373,7 +427,7 @@ class OpensshKeypair: if no_public_key: comment = "" else: - comment = extract_comment(path + ".pub") + comment = extract_comment(str(path) + ".pub") asym_keypair = AsymmetricKeypair.load( path, passphrase, "SSH", "SSH", no_public_key @@ -391,7 +445,9 @@ class OpensshKeypair: ) @staticmethod - def encode_openssh_privatekey(asym_keypair, key_format): + def encode_openssh_privatekey( + asym_keypair: AsymmetricKeypair, key_format: KeyFormat + ) -> bytes: """Returns an OpenSSH encoded private key for a given keypair :asym_keypair: Asymmetric_Keypair from the private key is extracted @@ -422,7 +478,9 @@ class OpensshKeypair: return encoded_privatekey @staticmethod - def encode_openssh_publickey(asym_keypair, comment): + def encode_openssh_publickey( + asym_keypair: AsymmetricKeypair, comment: str + ) -> bytes: """Returns an OpenSSH encoded public key for a given keypair :asym_keypair: Asymmetric_Keypair from the public key is extracted @@ -436,14 +494,19 @@ class OpensshKeypair: validate_comment(comment) encoded_publickey += ( - f" {comment}".encode(encoding=_TEXT_ENCODING) if comment else b"" + (b" " + comment.encode(encoding=_TEXT_ENCODING)) if comment else b"" ) return encoded_publickey def __init__( - self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment - ): + self, + asym_keypair: AsymmetricKeypair, + openssh_privatekey: bytes, + openssh_publickey: bytes, + fingerprint: str, + comment: str | None, + ) -> None: """ :asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived :openssh_privatekey: An OpenSSH encoded private key @@ -458,7 +521,7 @@ class OpensshKeypair: self.__fingerprint = fingerprint self.__comment = comment - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, OpensshKeypair): return NotImplemented @@ -468,49 +531,49 @@ class OpensshKeypair: ) @property - def asymmetric_keypair(self): + def asymmetric_keypair(self) -> AsymmetricKeypair: """Returns the underlying asymmetric key pair of this OpenSSH encoded key pair""" return self.__asym_keypair @property - def private_key(self): + def private_key(self) -> bytes: """Returns the OpenSSH formatted private key of this key pair""" return self.__openssh_privatekey @property - def public_key(self): + def public_key(self) -> bytes: """Returns the OpenSSH formatted public key of this key pair""" return self.__openssh_publickey @property - def size(self): + def size(self) -> int: """Returns the size of the private key of this key pair""" return self.__asym_keypair.size @property - def key_type(self): + def key_type(self) -> KeyType: """Returns the key type of this key pair""" return self.__asym_keypair.key_type @property - def fingerprint(self): + def fingerprint(self) -> str: """Returns the fingerprint (SHA256 Hash) of the public key of this key pair""" return self.__fingerprint @property - def comment(self): + def comment(self) -> str | None: """Returns the comment applied to the OpenSSH formatted public key of this key pair""" return self.__comment @comment.setter - def comment(self, comment): + def comment(self, comment: str) -> bytes: """Updates the comment applied to the OpenSSH formatted public key of this key pair :comment: Text to update the OpenSSH public key comment @@ -529,7 +592,7 @@ class OpensshKeypair: ) return self.__openssh_publickey - def update_passphrase(self, passphrase): + def update_passphrase(self, passphrase: bytes | None) -> None: """Updates the passphrase used to encrypt the private key of this keypair :passphrase: Text secret used for encryption @@ -541,18 +604,17 @@ class OpensshKeypair: ) -def load_privatekey(path, passphrase, key_format): +def load_privatekey( + path: str | os.PathLike, + passphrase: bytes | None, + key_format: KeySerializationFormat, +) -> PrivateKeyTypes: privatekey_loaders = { "PEM": serialization.load_pem_private_key, "DER": serialization.load_der_private_key, + "SSH": serialization.load_ssh_private_key, } - # OpenSSH formatted private keys are not available in Cryptography <3.0 - if hasattr(serialization, "load_ssh_private_key"): - privatekey_loaders["SSH"] = serialization.load_ssh_private_key - else: - privatekey_loaders["SSH"] = serialization.load_pem_private_key - try: privatekey_loader = privatekey_loaders[key_format] except KeyError: @@ -567,16 +629,16 @@ def load_privatekey(path, passphrase, key_format): with open(path, "rb") as f: content = f.read() - privatekey = privatekey_loader( + privatekey = privatekey_loader( # type: ignore data=content, password=passphrase, ) - except ValueError as e: + except ValueError as exc: # Revert to PEM if key could not be loaded in SSH format if key_format == "SSH": try: - privatekey = privatekey_loaders["PEM"]( + privatekey = privatekey_loaders["PEM"]( # type: ignore data=content, password=passphrase, ) @@ -587,7 +649,7 @@ def load_privatekey(path, passphrase, key_format): except UnsupportedAlgorithm as e: raise InvalidAlgorithmError(e) else: - raise InvalidPrivateKeyFileError(e) + raise InvalidPrivateKeyFileError(exc) except TypeError as e: raise InvalidPassphraseError(e) except UnsupportedAlgorithm as e: @@ -596,7 +658,9 @@ def load_privatekey(path, passphrase, key_format): return privatekey -def load_publickey(path, key_format): +def load_publickey( + path: str | os.PathLike, key_format: KeySerializationFormat +) -> AllPublicKeyTypes: publickey_loaders = { "PEM": serialization.load_pem_public_key, "DER": serialization.load_der_public_key, @@ -628,20 +692,27 @@ def load_publickey(path, key_format): return publickey -def compare_publickeys(pk1, pk2): +def compare_publickeys(pk1: PublicKeyTypes, pk2: PublicKeyTypes) -> bool: a = isinstance(pk1, Ed25519PublicKey) b = isinstance(pk2, Ed25519PublicKey) if a or b: if not a or not b: return False - a = pk1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) - b = pk2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) - return a == b + a_bytes = pk1.public_bytes( + serialization.Encoding.Raw, serialization.PublicFormat.Raw + ) + b_bytes = pk2.public_bytes( + serialization.Encoding.Raw, serialization.PublicFormat.Raw + ) + return a_bytes == b_bytes else: - return pk1.public_numbers() == pk2.public_numbers() + return pk1.public_numbers() == pk2.public_numbers() # type: ignore -def compare_encryption_algorithms(ea1, ea2): +def compare_encryption_algorithms( + ea1: serialization.KeySerializationEncryption, + ea2: serialization.KeySerializationEncryption, +) -> bool: if isinstance(ea1, serialization.NoEncryption) and isinstance( ea2, serialization.NoEncryption ): @@ -654,19 +725,21 @@ def compare_encryption_algorithms(ea1, ea2): return False -def get_encryption_algorithm(passphrase): +def get_encryption_algorithm( + passphrase: bytes, +) -> serialization.KeySerializationEncryption: try: return serialization.BestAvailableEncryption(passphrase) except ValueError as e: raise InvalidPassphraseError(e) -def validate_comment(comment): +def validate_comment(comment: str) -> None: if not hasattr(comment, "encode"): raise InvalidCommentError(f"{comment} cannot be encoded to text") -def extract_comment(path): +def extract_comment(path: str | os.PathLike) -> str: if not os.path.exists(path): raise InvalidPublicKeyFileError(f"No file was found at {path}") @@ -684,7 +757,7 @@ def extract_comment(path): return comment -def calculate_fingerprint(openssh_publickey): +def calculate_fingerprint(openssh_publickey: bytes) -> str: digest = hashes.Hash(hashes.SHA256()) decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1]) digest.update(decoded_pubkey) diff --git a/plugins/module_utils/openssh/utils.py b/plugins/module_utils/openssh/utils.py index d58c917d..39520a1e 100644 --- a/plugins/module_utils/openssh/utils.py +++ b/plugins/module_utils/openssh/utils.py @@ -7,6 +7,7 @@ from __future__ import annotations import os import re +import typing as t from contextlib import contextmanager from struct import Struct @@ -38,17 +39,20 @@ _UINT64 = Struct(b"!Q") _UINT64_MAX = 0xFFFFFFFFFFFFFFFF -def any_in(sequence, *elements): +_T = t.TypeVar("_T") + + +def any_in(sequence: t.Iterable[_T], *elements: _T) -> bool: return any(e in sequence for e in elements) -def file_mode(path): +def file_mode(path: str | os.PathLike) -> int: if not os.path.exists(path): return 0o000 return os.stat(path).st_mode & 0o777 -def parse_openssh_version(version_string): +def parse_openssh_version(version_string: str) -> str | None: """Parse the version output of ssh -V and return version numbers that can be compared""" parsed_result = re.match( @@ -63,7 +67,7 @@ def parse_openssh_version(version_string): @contextmanager -def secure_open(path, mode): +def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]: fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode) try: yield fd @@ -71,7 +75,7 @@ def secure_open(path, mode): os.close(fd) -def secure_write(path, mode, content): +def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None: with secure_open(path, mode) as fd: os.write(fd, content) @@ -84,35 +88,35 @@ class OpensshParser: UINT32_OFFSET = 4 UINT64_OFFSET = 8 - def __init__(self, data): + def __init__(self, data: bytes | bytearray) -> None: if not isinstance(data, (bytes, bytearray)): raise TypeError(f"Data must be bytes-like not {type(data)}") self._data = memoryview(data) self._pos = 0 - def boolean(self): + def boolean(self) -> bool: next_pos = self._check_position(self.BOOLEAN_OFFSET) value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0] self._pos = next_pos return value - def uint32(self): + def uint32(self) -> int: next_pos = self._check_position(self.UINT32_OFFSET) value = _UINT32.unpack(self._data[self._pos : next_pos])[0] self._pos = next_pos return value - def uint64(self): + def uint64(self) -> int: next_pos = self._check_position(self.UINT64_OFFSET) value = _UINT64.unpack(self._data[self._pos : next_pos])[0] self._pos = next_pos return value - def string(self): + def string(self) -> bytes: length = self.uint32() next_pos = self._check_position(length) @@ -122,15 +126,15 @@ class OpensshParser: # Cast to bytes is required as a memoryview slice is itself a memoryview return bytes(value) - def mpint(self): + def mpint(self) -> int: return self._big_int(self.string(), "big", signed=True) - def name_list(self): + def name_list(self) -> list[str]: raw_string = self.string() return raw_string.decode("ASCII").split(",") # Convenience function, but not an official data type from SSH - def string_list(self): + def string_list(self) -> list[bytes]: result = [] raw_string = self.string() @@ -142,7 +146,7 @@ class OpensshParser: return result # Convenience function, but not an official data type from SSH - def option_list(self): + def option_list(self) -> list[tuple[bytes, bytes]]: result = [] raw_string = self.string() @@ -159,15 +163,15 @@ class OpensshParser: return result - def seek(self, offset): + def seek(self, offset: int) -> int: self._pos = self._check_position(offset) return self._pos - def remaining_bytes(self): + def remaining_bytes(self) -> int: return len(self._data) - self._pos - def _check_position(self, offset): + def _check_position(self, offset: int) -> int: if self._pos + offset > len(self._data): raise ValueError(f"Insufficient data remaining at position: {self._pos}") elif self._pos + offset < 0: @@ -176,8 +180,8 @@ class OpensshParser: return self._pos + offset @classmethod - def signature_data(cls, signature_string): - signature_data = {} + def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]: + signature_data: dict[str, bytes | int] = {} parser = cls(signature_string) signature_type = parser.string() @@ -205,14 +209,19 @@ class OpensshParser: signature_data["R"] = cls._big_int(signature_blob[:32], "little") signature_data["S"] = cls._big_int(signature_blob[32:], "little") else: - raise ValueError(f"{signature_type} is not a valid signature type") + raise ValueError(f"{signature_type!r} is not a valid signature type") signature_data["signature_type"] = signature_type return signature_data @classmethod - def _big_int(cls, raw_string, byte_order, signed=False): + def _big_int( + cls, + raw_string: bytes, + byte_order: t.Literal["big", "little"], + signed: bool = False, + ) -> int: if byte_order not in ("big", "little"): raise ValueError( f"Byte_order must be one of (big, little) not {byte_order}" @@ -230,18 +239,16 @@ class _OpensshWriter: in validating parsed material. """ - def __init__(self, buffer=None): + def __init__(self, buffer: bytearray | None = None): if buffer is not None: - if not isinstance(buffer, (bytes, bytearray)): - raise TypeError( - f"Buffer must be a bytes-like object not {type(buffer)}" - ) + if not isinstance(buffer, bytearray): + raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}") else: buffer = bytearray() - self._buff = buffer + self._buff: bytearray = buffer - def boolean(self, value): + def boolean(self, value: bool) -> t.Self: if not isinstance(value, bool): raise TypeError(f"Value must be of type bool not {type(value)}") @@ -249,7 +256,7 @@ class _OpensshWriter: return self - def uint32(self, value): + def uint32(self, value: int) -> t.Self: if not isinstance(value, int): raise TypeError(f"Value must be of type int not {type(value)}") if value < 0 or value > _UINT32_MAX: @@ -261,7 +268,7 @@ class _OpensshWriter: return self - def uint64(self, value): + def uint64(self, value: int) -> t.Self: if not isinstance(value, int): raise TypeError(f"Value must be of type int not {type(value)}") if value < 0 or value > _UINT64_MAX: @@ -273,7 +280,7 @@ class _OpensshWriter: return self - def string(self, value): + def string(self, value: bytes | bytearray) -> t.Self: if not isinstance(value, (bytes, bytearray)): raise TypeError(f"Value must be bytes-like not {type(value)}") self.uint32(len(value)) @@ -281,7 +288,7 @@ class _OpensshWriter: return self - def mpint(self, value): + def mpint(self, value: int) -> t.Self: if not isinstance(value, int): raise TypeError(f"Value must be of type int not {type(value)}") @@ -289,7 +296,7 @@ class _OpensshWriter: return self - def name_list(self, value): + def name_list(self, value: list[str]) -> t.Self: if not isinstance(value, list): raise TypeError(f"Value must be a list of byte strings not {type(value)}") @@ -300,7 +307,7 @@ class _OpensshWriter: return self - def string_list(self, value): + def string_list(self, value: list[bytes]) -> t.Self: if not isinstance(value, list): raise TypeError(f"Value must be a list of byte string not {type(value)}") @@ -312,7 +319,7 @@ class _OpensshWriter: return self - def option_list(self, value): + def option_list(self, value: list[tuple[bytes, bytes]]) -> t.Self: if not isinstance(value, list) or (value and not isinstance(value[0], tuple)): raise TypeError("Value must be a list of tuples") @@ -327,7 +334,7 @@ class _OpensshWriter: return self @staticmethod - def _int_to_mpint(num): + def _int_to_mpint(num: int) -> bytes: byte_length = (num.bit_length() + 7) // 8 try: return num.to_bytes(byte_length, "big", signed=True) @@ -335,5 +342,5 @@ class _OpensshWriter: except OverflowError: return num.to_bytes(byte_length + 1, "big", signed=True) - def bytes(self): + def bytes(self) -> bytes: return bytes(self._buff) diff --git a/plugins/module_utils/serial.py b/plugins/module_utils/serial.py index 875b6544..8445cfc7 100644 --- a/plugins/module_utils/serial.py +++ b/plugins/module_utils/serial.py @@ -10,7 +10,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor ) -def th(number): +def th(number: int) -> str: abs_number = abs(number) mod_10 = abs_number % 10 mod_100 = abs_number % 100 @@ -24,13 +24,13 @@ def th(number): return "th" -def parse_serial(value): +def parse_serial(value: str | bytes) -> int: """ Given a colon-separated string of hexadecimal byte values, converts it to an integer. """ - value = to_native(value) + value_str = to_native(value) result = 0 - for i, part in enumerate(value.split(":")): + for i, part in enumerate(value_str.split(":")): try: part_value = int(part, 16) if part_value < 0 or part_value > 255: @@ -43,11 +43,11 @@ def parse_serial(value): return result -def to_serial(value): +def to_serial(value: int) -> str: """ Given an integer, converts its absolute value to a colon-separated string of hexadecimal byte values. """ - value = convert_int_to_hex(value).upper() - if len(value) % 2 != 0: - value = "0" + value - return ":".join(value[i : i + 2] for i in range(0, len(value), 2)) + value_str = convert_int_to_hex(value).upper() + if len(value_str) % 2 != 0: + value_str = f"0{value_str}" + return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2)) diff --git a/plugins/module_utils/time.py b/plugins/module_utils/time.py index 0ccf859d..29e157c4 100644 --- a/plugins/module_utils/time.py +++ b/plugins/module_utils/time.py @@ -13,37 +13,16 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.basic impo ) -try: - UTC = datetime.timezone.utc -except AttributeError: - _DURATION_ZERO = datetime.timedelta(0) - - class _UTCClass(datetime.tzinfo): - def utcoffset(self, dt): - return _DURATION_ZERO - - def dst(self, dt): - return _DURATION_ZERO - - def tzname(self, dt): - return "UTC" - - def fromutc(self, dt): - return dt - - def __repr__(self): - return "UTC" - - UTC = _UTCClass() +UTC = datetime.timezone.utc -def get_now_datetime(with_timezone): +def get_now_datetime(with_timezone: bool) -> datetime.datetime: if with_timezone: return datetime.datetime.now(tz=UTC) return datetime.datetime.utcnow() -def ensure_utc_timezone(timestamp): +def ensure_utc_timezone(timestamp: datetime.datetime) -> datetime.datetime: if timestamp.tzinfo is UTC: return timestamp if timestamp.tzinfo is None: @@ -52,7 +31,7 @@ def ensure_utc_timezone(timestamp): return timestamp.astimezone(UTC) -def remove_timezone(timestamp): +def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime: # Convert to native datetime object if timestamp.tzinfo is None: return timestamp @@ -61,26 +40,34 @@ def remove_timezone(timestamp): return timestamp.replace(tzinfo=None) -def add_or_remove_timezone(timestamp, with_timezone): +def add_or_remove_timezone( + timestamp: datetime.datetime, with_timezone: bool +) -> datetime.datetime: return ( ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp) ) -def get_epoch_seconds(timestamp): +def get_epoch_seconds(timestamp: datetime.datetime) -> float: if timestamp.tzinfo is None: # timestamp.timestamp() is offset by the local timezone if timestamp has no timezone timestamp = ensure_utc_timezone(timestamp) return timestamp.timestamp() -def from_epoch_seconds(timestamp, with_timezone): +def from_epoch_seconds( + timestamp: int | float, with_timezone: bool +) -> datetime.datetime: if with_timezone: return datetime.datetime.fromtimestamp(timestamp, UTC) return datetime.datetime.utcfromtimestamp(timestamp) -def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=None): +def convert_relative_to_datetime( + relative_time_string: str, + with_timezone: bool = False, + now: datetime.datetime | None = None, +) -> datetime.datetime | None: """Get a datetime.datetime or None from a string in the time format described in sshd_config(5)""" parsed_result = re.match( @@ -115,7 +102,12 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now= return now - offset -def get_relative_time_option(input_string, input_name, with_timezone=False, now=None): +def get_relative_time_option( + input_string: str, + input_name: str, + with_timezone: bool = False, + now: datetime.datetime | None = None, +) -> datetime.datetime: """ Return an absolute timespec if a relative timespec or an ASN1 formatted string is provided. @@ -129,9 +121,12 @@ def get_relative_time_option(input_string, input_name, with_timezone=False, now= ) # Relative time if result.startswith("+") or result.startswith("-"): - return convert_relative_to_datetime( - result, with_timezone=with_timezone, now=now - ) + res = convert_relative_to_datetime(result, with_timezone=with_timezone, now=now) + if res is None: + raise OpenSSLObjectError( + f'The timespec "{input_string}" for {input_name} is invalid' + ) + return res # Absolute time for date_fmt, length in [ ( diff --git a/plugins/modules/acme_account.py b/plugins/modules/acme_account.py index a846c812..05ed2562 100644 --- a/plugins/modules/acme_account.py +++ b/plugins/modules/acme_account.py @@ -165,6 +165,7 @@ account_uri: """ import base64 +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( ACMEAccount, @@ -180,7 +181,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec() argument_spec.update_argspec( terms_agreed=dict(type="bool", default=False), @@ -204,24 +205,24 @@ def main(): ), ) argument_spec.update( - mutually_exclusive=(["new_account_key_src", "new_account_key_content"],), - required_if=( + mutually_exclusive=[("new_account_key_src", "new_account_key_content")], + required_if=[ # Make sure that for state == changed_key, one of # new_account_key_src and new_account_key_content are specified - [ + ( "state", "changed_key", ["new_account_key_src", "new_account_key_content"], True, - ], - ), + ), + ], ) module = argument_spec.create_ansible_module(supports_check_mode=True) backend = create_backend(module, True) if module.params["external_account_binding"]: # Make sure padding is there - key = module.params["external_account_binding"]["key"] + key: str = module.params["external_account_binding"]["key"] if len(key) % 4 != 0: key = key + ("=" * (4 - (len(key) % 4))) # Make sure key is Base64 encoded @@ -237,24 +238,25 @@ def main(): client = ACMEClient(module, backend) account = ACMEAccount(client) changed = False - state = module.params.get("state") - diff_before = {} - diff_after = {} + state: t.Literal["present", "absent", "changed_key"] = module.params["state"] + diff_before: dict[str, t.Any] = {} + diff_after: dict[str, t.Any] = {} if state == "absent": created, account_data = account.setup_account(allow_creation=False) if account_data: diff_before = dict(account_data) - diff_before["public_account_key"] = client.account_key_data["jwk"] + if client.account_key_data: + diff_before["public_account_key"] = client.account_key_data["jwk"] if created: raise AssertionError("Unwanted account creation") if account_data is not None: # Account is not yet deactivated if not module.check_mode: # Deactivate it - payload = {"status": "deactivated"} + deactivate_payload = {"status": "deactivated"} result, info = client.send_signed_request( - client.account_uri, - payload, + t.cast(str, client.account_uri), + deactivate_payload, error_msg="Failed to deactivate account", expected_status_codes=[200], ) @@ -278,13 +280,15 @@ def main(): diff_before = {} else: diff_before = dict(account_data) - diff_before["public_account_key"] = client.account_key_data["jwk"] + if client.account_key_data: + diff_before["public_account_key"] = client.account_key_data["jwk"] updated = False if not created: updated, account_data = account.update_account(account_data, contact) changed = created or updated diff_after = dict(account_data) - diff_after["public_account_key"] = client.account_key_data["jwk"] + if client.account_key_data: + diff_after["public_account_key"] = client.account_key_data["jwk"] elif state == "changed_key": # Parse new account key try: @@ -306,7 +310,8 @@ def main(): msg="Account does not exist or is deactivated." ) diff_before = dict(account_data) - diff_before["public_account_key"] = client.account_key_data["jwk"] + if client.account_key_data: + diff_before["public_account_key"] = client.account_key_data["jwk"] # Now we can start the account key rollover if not module.check_mode: # Compose inner signed message @@ -317,12 +322,12 @@ def main(): "jwk": new_key_data["jwk"], "url": url, } - payload = { + change_key_payload = { "account": client.account_uri, "newKey": new_key_data["jwk"], # specified in draft 12 and older "oldKey": client.account_jwk, # specified in draft 13 and newer } - data = client.sign_request(protected, payload, new_key_data) + data = client.sign_request(protected, change_key_payload, new_key_data) # Send request and verify result result, info = client.send_signed_request( url, @@ -332,8 +337,9 @@ def main(): ) if module._diff: client.account_key_data = new_key_data - client.account_jws_header["alg"] = new_key_data["alg"] - diff_after = account.get_account_data() + if client.account_jws_header: + client.account_jws_header["alg"] = new_key_data["alg"] + diff_after = account.get_account_data() or {} elif module._diff: # Kind of fake diff_after diff_after = dict(diff_before) diff --git a/plugins/modules/acme_account_info.py b/plugins/modules/acme_account_info.py index bf47cf0e..fd215cff 100644 --- a/plugins/modules/acme_account_info.py +++ b/plugins/modules/acme_account_info.py @@ -204,6 +204,8 @@ order_uris: version_added: 1.5.0 """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( ACMEAccount, ) @@ -220,48 +222,55 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ) -def get_orders_list(module, client, orders_url): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + +def get_orders_list( + module: AnsibleModule, client: ACMEClient, orders_url: str +) -> list[str]: """ - Retrieves orders list (handles pagination). + Retrieves order URL list (handles pagination). """ - orders = [] - while orders_url: + orders: list[str] = [] + next_orders_url: str | None = orders_url + while next_orders_url: # Get part of orders list res, info = client.get_request( - orders_url, parse_json_result=True, fail_on_error=True + next_orders_url, parse_json_result=True, fail_on_error=True ) if not res.get("orders"): if orders: module.warn( - f"When retrieving orders list part {orders_url}, got empty result list" + f"When retrieving orders list part {next_orders_url}, got empty result list" ) break # Add order URLs to result list orders.extend(res["orders"]) # Extract URL of next part of results list - new_orders_url = [] + new_orders_url: list[str | None] = [] - def f(link, relation): + def f(link: str, relation: str) -> None: if relation == "next": new_orders_url.append(link) process_links(info, f) new_orders_url.append(None) - previous_orders_url, orders_url = orders_url, new_orders_url.pop(0) - if orders_url == previous_orders_url: + previous_orders_url, next_orders_url = next_orders_url, new_orders_url.pop(0) + if next_orders_url == previous_orders_url: # Prevent infinite loop - orders_url = None + next_orders_url = None return orders -def get_order(client, order_url): +def get_order(client: ACMEClient, order_url: str) -> dict[str, t.Any]: """ Retrieve order data. """ return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0] -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec() argument_spec.update_argspec( retrieve_orders=dict( @@ -282,16 +291,19 @@ def main(): ) if created: raise AssertionError("Unwanted account creation") - result = { + result: dict[str, t.Any] = { "changed": False, - "exists": client.account_uri is not None, - "account_uri": client.account_uri, + "exists": False, + "account_uri": None, } - if client.account_uri is not None: + if client.account_uri is not None and account_data: + result["account_uri"] = client.account_uri + result["exists"] = True # Make sure promised data is there if "contact" not in account_data: account_data["contact"] = [] - account_data["public_account_key"] = client.account_key_data["jwk"] + if client.account_key_data: + account_data["public_account_key"] = client.account_key_data["jwk"] result["account"] = account_data # Retrieve orders list if ( diff --git a/plugins/modules/acme_ari_info.py b/plugins/modules/acme_ari_info.py index 1860b8a5..eec21555 100644 --- a/plugins/modules/acme_ari_info.py +++ b/plugins/modules/acme_ari_info.py @@ -94,6 +94,8 @@ renewal_info: sample: '2024-04-29T01:17:10.236921+00:00' """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( ACMEClient, create_backend, @@ -104,15 +106,15 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_account=False) argument_spec.update_argspec( certificate_path=dict(type="path"), certificate_content=dict(type="str"), ) argument_spec.update( - required_one_of=(["certificate_path", "certificate_content"],), - mutually_exclusive=(["certificate_path", "certificate_content"],), + required_one_of=[("certificate_path", "certificate_content")], + mutually_exclusive=[("certificate_path", "certificate_content")], ) module = argument_spec.create_ansible_module(supports_check_mode=True) backend = create_backend(module, True) diff --git a/plugins/modules/acme_certificate.py b/plugins/modules/acme_certificate.py index 68d3503b..0df02a14 100644 --- a/plugins/modules/acme_certificate.py +++ b/plugins/modules/acme_certificate.py @@ -562,6 +562,7 @@ all_chains: """ import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( ACMEAccount, @@ -592,6 +593,17 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( + CertificateInformation, + CryptoBackend, + ) + from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import ( + Authorization, + ) + + NO_CHALLENGE = "no challenge" @@ -602,7 +614,7 @@ class ACMECertificateClient: certificates. """ - def __init__(self, module, backend): + def __init__(self, module: AnsibleModule, backend: CryptoBackend): self.module = module self.version = module.params["acme_version"] self.challenge = module.params["challenge"] @@ -618,9 +630,9 @@ class ACMECertificateClient: self.account = ACMEAccount(self.client) self.directory = self.client.directory self.data = module.params["data"] - self.authorizations = None + self.authorizations: dict[str, Authorization] | None = None self.cert_days = -1 - self.order = None + self.order: Order | None = None self.order_uri = self.data.get("order_uri") if self.data else None self.all_chains = None self.select_chain_matcher = [] @@ -662,7 +674,6 @@ class ACMECertificateClient: contact.append("mailto:" + module.params["account_email"]) created, account_data = self.account.setup_account( contact, - agreement=module.params.get("agreement"), terms_agreed=module.params.get("terms_agreed"), allow_creation=modify_account, ) @@ -681,7 +692,7 @@ class ACMECertificateClient: csr_filename=self.csr, csr_content=self.csr_content ) - def is_first_step(self): + def is_first_step(self) -> bool: """ Return True if this is the first execution of this module, i.e. if a sufficient data object from a first run has not been provided. @@ -692,7 +703,7 @@ class ACMECertificateClient: # stored in self.order_uri by the constructor). return self.order_uri is None - def _get_cert_info_or_none(self): + def _get_cert_info_or_none(self) -> CertificateInformation | None: if self.module.params.get("dest"): filename = self.module.params["dest"] else: @@ -701,7 +712,7 @@ class ACMECertificateClient: return None return self.client.backend.get_cert_information(cert_filename=filename) - def start_challenges(self): + def start_challenges(self) -> None: """ Create new authorizations for all identifiers of the CSR, respectively start a new order for ACME v2. @@ -733,13 +744,16 @@ class ACMECertificateClient: self.authorizations.update(self.order.authorizations) self.changed = True - def get_challenges_data(self, first_step): + def get_challenges_data( + self, first_step: bool + ) -> tuple[dict[str, t.Any], dict[str, list[str]]]: """ Get challenge details for the chosen challenge type. Return a tuple of generic challenge details, and specialized DNS challenge details. """ - # Get general challenge data - data = {} + assert self.authorizations is not None + data: dict[str, t.Any] = {} + data_dns: dict[str, list[str]] = {} for type_identifier, authz in self.authorizations.items(): identifier_type, identifier = split_identifier(type_identifier) # Skip valid authentications: their challenges are already valid @@ -747,7 +761,9 @@ class ACMECertificateClient: if authz.status == "valid": continue # We drop the type from the key to preserve backwards compatibility - data[authz.identifier] = authz.get_challenge_data(self.client) + challenges = authz.get_challenge_data(self.client) + assert authz.identifier is not None + data[authz.identifier] = challenges if ( first_step and self.challenge is not None @@ -756,10 +772,7 @@ class ACMECertificateClient: raise ModuleFailException( f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!" ) - # Get DNS challenge data - data_dns = {} - if self.challenge == "dns-01": - for identifier, challenges in data.items(): + if self.challenge == "dns-01": if self.challenge in challenges: values = data_dns.get(challenges[self.challenge]["record"]) if values is None: @@ -768,7 +781,7 @@ class ACMECertificateClient: values.append(challenges[self.challenge]["resource_value"]) return data, data_dns - def finish_challenges(self): + def finish_challenges(self) -> None: """ Verify challenges for all identifiers of the CSR. """ @@ -777,6 +790,7 @@ class ACMECertificateClient: # Step 1: obtain challenge information # For ACME v2, we obtain the order object by fetching the # order URI, and extract the information from there. + assert self.order_uri is not None self.order = Order.from_url(self.client, self.order_uri) self.order.load_authorizations(self.client) self.authorizations.update(self.order.authorizations) @@ -799,7 +813,9 @@ class ACMECertificateClient: # Step 3: wait for authzs to validate wait_for_validation(authzs_to_wait_for, self.client) - def download_alternate_chains(self, cert): + def download_alternate_chains( + self, cert: CertificateChain + ) -> list[CertificateChain]: alternate_chains = [] for alternate in cert.alternates: try: @@ -812,7 +828,9 @@ class ACMECertificateClient: alternate_chains.append(alt_cert) return alternate_chains - def find_matching_chain(self, chains): + def find_matching_chain( + self, chains: t.Iterable[CertificateChain] + ) -> CertificateChain | None: for criterium_idx, matcher in enumerate(self.select_chain_matcher): for chain in chains: if matcher.match(chain): @@ -822,12 +840,13 @@ class ACMECertificateClient: return chain return None - def get_certificate(self): + def get_certificate(self) -> None: """ Request a new certificate and write it to the destination file. First verifies whether all authorizations are valid; if not, aborts with an error. """ + assert self.authorizations is not None for identifier_type, identifier in self.identifiers: authz = self.authorizations.get( normalize_combined_identifier( @@ -844,7 +863,9 @@ class ACMECertificateClient: module=self.module, ) + assert self.order is not None self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content)) + assert self.order.certificate_uri is not None cert = CertificateChain.download(self.client, self.order.certificate_uri) if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher: # Retrieve alternate chains @@ -887,12 +908,13 @@ class ACMECertificateClient: ): self.changed = True - def deactivate_authzs(self): + def deactivate_authzs(self) -> None: """ Deactivates all valid authz's. Does not raise exceptions. https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://tools.ietf.org/html/rfc8555#section-7.5.2 """ + assert self.authorizations is not None for authz in self.authorizations.values(): try: authz.deactivate(self.client) @@ -905,7 +927,7 @@ class ACMECertificateClient: ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_certificate=True) argument_spec.argument_spec["csr"]["aliases"] = ["src"] argument_spec.update_argspec( @@ -981,7 +1003,7 @@ def main(): else: client = ACMECertificateClient(module, backend) client.cert_days = cert_days - other = dict() + other: dict[str, t.Any] = {} is_first_step = client.is_first_step() if is_first_step: # First run: start challenges / start new order @@ -998,6 +1020,7 @@ def main(): client.deactivate_authzs() data, data_dns = client.get_challenges_data(first_step=is_first_step) auths = dict() + assert client.authorizations is not None for k, v in client.authorizations.items(): # Remove "type:" from key auths[v.identifier] = v.to_json() diff --git a/plugins/modules/acme_certificate_deactivate_authz.py b/plugins/modules/acme_certificate_deactivate_authz.py index 77fc26f3..842e8a45 100644 --- a/plugins/modules/acme_certificate_deactivate_authz.py +++ b/plugins/modules/acme_certificate_deactivate_authz.py @@ -51,6 +51,8 @@ EXAMPLES = r""" RETURN = """#""" +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( ACMEAccount, ) @@ -65,7 +67,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec() argument_spec.update_argspec( order_uri=dict(type="str", required=True), diff --git a/plugins/modules/acme_certificate_order_create.py b/plugins/modules/acme_certificate_order_create.py index 3193c80f..02c248af 100644 --- a/plugins/modules/acme_certificate_order_create.py +++ b/plugins/modules/acme_certificate_order_create.py @@ -371,6 +371,8 @@ account_uri: type: str """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( create_backend, create_default_argspec, @@ -383,7 +385,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_certificate=True) argument_spec.update_argspec( deactivate_authzs=dict(type="bool", default=True), diff --git a/plugins/modules/acme_certificate_order_finalize.py b/plugins/modules/acme_certificate_order_finalize.py index 0cad2fcb..c2cc8d78 100644 --- a/plugins/modules/acme_certificate_order_finalize.py +++ b/plugins/modules/acme_certificate_order_finalize.py @@ -317,6 +317,8 @@ selected_chain: returned: always """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( create_backend, create_default_argspec, @@ -329,7 +331,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.acme.certificates import ( + CertificateChain, + ) + + +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_certificate=True) argument_spec.update_argspec( order_uri=dict(type="str", required=True), @@ -375,6 +383,7 @@ def main(): or module.params["retrieve_all_alternates"] ) changed = False + alternate_chains: list[CertificateChain] | None if order.status == "valid": # Step 2 and 3: download certificate(s) and chain(s) cert, alternate_chains = client.download_certificate( diff --git a/plugins/modules/acme_certificate_order_info.py b/plugins/modules/acme_certificate_order_info.py index bde4bd88..bd6f291a 100644 --- a/plugins/modules/acme_certificate_order_info.py +++ b/plugins/modules/acme_certificate_order_info.py @@ -357,6 +357,8 @@ authorizations_by_status: returned: always """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( create_backend, create_default_argspec, @@ -369,7 +371,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_certificate=False) argument_spec.update_argspec( order_uri=dict(type="str", required=True), @@ -381,8 +383,8 @@ def main(): try: client = ACMECertificateClient(module, backend) order = client.load_order() - authorizations_by_identifier = dict() - authorizations_by_status = { + authorizations_by_identifier: dict[str, dict[str, t.Any]] = {} + authorizations_by_status: dict[str, list[str]] = { "pending": [], "invalid": [], "valid": [], @@ -392,7 +394,8 @@ def main(): } for identifier, authz in order.authorizations.items(): authorizations_by_identifier[identifier] = authz.to_json() - authorizations_by_status[authz.status].append(identifier) + if authz.status is not None: + authorizations_by_status[authz.status].append(identifier) module.exit_json( changed=False, account_uri=client.client.account_uri, diff --git a/plugins/modules/acme_certificate_order_validate.py b/plugins/modules/acme_certificate_order_validate.py index 22b07704..981a3dec 100644 --- a/plugins/modules/acme_certificate_order_validate.py +++ b/plugins/modules/acme_certificate_order_validate.py @@ -229,6 +229,8 @@ validating_challenges: returned: always """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( create_backend, create_default_argspec, @@ -241,7 +243,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import ( + Authorization, + ) + + +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_certificate=False) argument_spec.update_argspec( order_uri=dict(type="str", required=True), @@ -271,10 +279,12 @@ def main(): missing_challenge_authzs = [k for k, v in challenges.items() if v is None] if missing_challenge_authzs: - missing_challenge_authzs = ", ".join(sorted(missing_challenge_authzs)) + missing_challenge_authzs_str = ", ".join( + sorted(missing_challenge_authzs) + ) raise ModuleFailException( "The challenge parameter must be supplied if there are pending authorizations." - f" The following authorizations are pending: {missing_challenge_authzs}" + f" The following authorizations are pending: {missing_challenge_authzs_str}" ) bad_challenge_authzs = [ @@ -293,11 +303,13 @@ def main(): f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}" ) + def is_pending(authz: Authorization) -> bool: + challenge_name = challenges[authz.combined_identifier] + challenge_obj = authz.find_challenge(challenge_name) + return challenge_obj is not None and challenge_obj.status == "pending" + really_pending_authzs = [ - authz - for authz in pending_authzs - if authz.find_challenge(challenges[authz.combined_identifier]).status - == "pending" + authz for authz in pending_authzs if is_pending(authz) ] # Step 4: validate pending authorizations @@ -320,7 +332,7 @@ def main(): identifier_type=authz.identifier_type, authz_url=authz.url, challenge_type=challenge_type, - challenge_url=challenge.url, + challenge_url=challenge.url if challenge else None, ) for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for ], diff --git a/plugins/modules/acme_certificate_renewal_info.py b/plugins/modules/acme_certificate_renewal_info.py index fc46f2f5..cdf838e3 100644 --- a/plugins/modules/acme_certificate_renewal_info.py +++ b/plugins/modules/acme_certificate_renewal_info.py @@ -160,6 +160,7 @@ cert_id: import os import random +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( ACMEClient, @@ -175,7 +176,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(with_account=False) argument_spec.update_argspec( certificate_path=dict(type="path"), @@ -190,7 +191,7 @@ def main(): treat_parsing_error_as_non_existing=dict(type="bool", default=False), ) argument_spec.update( - mutually_exclusive=(["certificate_path", "certificate_content"],), + mutually_exclusive=[("certificate_path", "certificate_content")], ) module = argument_spec.create_ansible_module(supports_check_mode=True) backend = create_backend(module, True) @@ -203,7 +204,7 @@ def main(): supports_ari=False, ) - def complete(should_renew, **kwargs): + def complete(should_renew: bool, **kwargs) -> t.NoReturn: result["should_renew"] = should_renew result.update(kwargs) module.exit_json(**result) diff --git a/plugins/modules/acme_certificate_revoke.py b/plugins/modules/acme_certificate_revoke.py index 5f9d729e..95428952 100644 --- a/plugins/modules/acme_certificate_revoke.py +++ b/plugins/modules/acme_certificate_revoke.py @@ -110,6 +110,8 @@ EXAMPLES = r""" RETURN = """#""" +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( ACMEAccount, ) @@ -129,7 +131,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(require_account_key=False) argument_spec.update_argspec( private_key_src=dict(type="path"), @@ -139,22 +141,22 @@ def main(): revoke_reason=dict(type="int"), ) argument_spec.update( - required_one_of=( - [ + required_one_of=[ + ( "account_key_src", "account_key_content", "private_key_src", "private_key_content", - ], - ), - mutually_exclusive=( - [ + ), + ], + mutually_exclusive=[ + ( "account_key_src", "account_key_content", "private_key_src", "private_key_content", - ], - ), + ), + ], ) module = argument_spec.create_ansible_module() backend = create_backend(module, False) @@ -164,9 +166,9 @@ def main(): account = ACMEAccount(client) # Load certificate certificate = pem_to_der(module.params.get("certificate")) - certificate = nopad_b64(certificate) + certificate_b64 = nopad_b64(certificate) # Construct payload - payload = {"certificate": certificate} + payload = {"certificate": certificate_b64} if module.params.get("revoke_reason") is not None: payload["reason"] = module.params.get("revoke_reason") endpoint = client.directory["revokeCert"] diff --git a/plugins/modules/acme_challenge_cert_helper.py b/plugins/modules/acme_challenge_cert_helper.py index 5c303bb9..8ad70c2f 100644 --- a/plugins/modules/acme_challenge_cert_helper.py +++ b/plugins/modules/acme_challenge_cert_helper.py @@ -149,6 +149,7 @@ regular_certificate: import base64 import datetime import ipaddress +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes, to_text @@ -173,10 +174,13 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( try: import cryptography import cryptography.hazmat.backends + import cryptography.hazmat.primitives.asymmetric.dh import cryptography.hazmat.primitives.asymmetric.ec import cryptography.hazmat.primitives.asymmetric.padding import cryptography.hazmat.primitives.asymmetric.rsa import cryptography.hazmat.primitives.asymmetric.utils + import cryptography.hazmat.primitives.asymmetric.x448 + import cryptography.hazmat.primitives.asymmetric.x25519 import cryptography.hazmat.primitives.hashes import cryptography.hazmat.primitives.serialization import cryptography.x509 @@ -186,7 +190,7 @@ except ImportError: # Convert byte string to ASN1 encoded octet string -def encode_octet_string(octet_string): +def encode_octet_string(octet_string: bytes) -> bytes: if len(octet_string) >= 128: raise ModuleFailException( "Cannot handle octet strings with more than 128 bytes" @@ -194,7 +198,7 @@ def encode_octet_string(octet_string): return bytes([0x4, len(octet_string)]) + octet_string -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( challenge=dict(type="str", required=True, choices=["tls-alpn-01"]), @@ -213,16 +217,16 @@ def main(): try: # Get parameters - challenge = module.params["challenge"] - challenge_data = module.params["challenge_data"] + challenge: t.Literal["tls-alpn-01"] = module.params["challenge"] + challenge_data: dict[str, t.Any] = module.params["challenge_data"] # Get hold of private key - private_key_content = module.params.get("private_key_content") - private_key_passphrase = module.params.get("private_key_passphrase") - if private_key_content is None: + private_key_content_str: str | None = module.params["private_key_content"] + private_key_passphrase: str | None = module.params["private_key_passphrase"] + if private_key_content_str is None: private_key_content = read_file(module.params["private_key_src"]) else: - private_key_content = to_bytes(private_key_content) + private_key_content = to_bytes(private_key_content_str) try: private_key = ( cryptography.hazmat.primitives.serialization.load_pem_private_key( @@ -236,6 +240,17 @@ def main(): ) except Exception as e: raise ModuleFailException(f"Error while loading private key: {e}") + if isinstance( + private_key, + ( + cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey, + cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey, + cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey, + ), + ): + raise ModuleFailException( + f"Cannot use private key type {type(private_key)}" + ) # Some common attributes domain = to_text(challenge_data["resource"]) @@ -246,6 +261,7 @@ def main(): now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE) not_valid_before = now not_valid_after = now + datetime.timedelta(days=10) + san: cryptography.x509.GeneralName if identifier_type == "dns": san = cryptography.x509.DNSName(identifier) elif identifier_type == "ip": diff --git a/plugins/modules/acme_inspect.py b/plugins/modules/acme_inspect.py index 41cb8cd9..ead2c5ab 100644 --- a/plugins/modules/acme_inspect.py +++ b/plugins/modules/acme_inspect.py @@ -223,6 +223,8 @@ output_json: - '...' """ +import typing as t + from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( ACMEClient, @@ -235,7 +237,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def main(): +def main() -> t.NoReturn: argument_spec = create_default_argspec(require_account_key=False) argument_spec.update_argspec( url=dict(type="str"), @@ -246,17 +248,17 @@ def main(): fail_on_acme_error=dict(type="bool", default=True), ) argument_spec.update( - required_if=( - ["method", "get", ["url"]], - ["method", "post", ["url", "content"]], - ["method", "get", ["account_key_src", "account_key_content"], True], - ["method", "post", ["account_key_src", "account_key_content"], True], - ), + required_if=[ + ("method", "get", ["url"]), + ("method", "post", ["url", "content"]), + ("method", "get", ["account_key_src", "account_key_content"], True), + ("method", "post", ["account_key_src", "account_key_content"], True), + ], ) module = argument_spec.create_ansible_module() backend = create_backend(module, False) - result = dict() + result: dict[str, t.Any] = {} changed = False try: # Get hold of ACMEClient and ACMEAccount objects (includes directory) diff --git a/plugins/modules/certificate_complete_chain.py b/plugins/modules/certificate_complete_chain.py index ee981e7e..3372ec17 100644 --- a/plugins/modules/certificate_complete_chain.py +++ b/plugins/modules/certificate_complete_chain.py @@ -121,6 +121,7 @@ complete_chain: """ import os +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes @@ -153,14 +154,18 @@ class Certificate: Stores PEM with parsed certificate. """ - def __init__(self, pem, cert): + def __init__(self, pem: str, cert: cryptography.x509.Certificate) -> None: if not (pem.endswith("\n") or pem.endswith("\r")): pem = pem + "\n" self.pem = pem self.cert = cert -def is_parent(module, cert, potential_parent): +def is_parent( + module: AnsibleModule, + cert: Certificate, + potential_parent: Certificate, +) -> bool: """ Tests whether the given certificate has been issued by the potential parent certificate. """ @@ -173,6 +178,10 @@ def is_parent(module, cert, potential_parent): if isinstance( public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey ): + if cert.cert.signature_hash_algorithm is None: + raise AssertionError( + "signature_hash_algorithm should be present for RSA certificates" + ) public_key.verify( cert.cert.signature, cert.cert.tbs_certificate_bytes, @@ -183,6 +192,10 @@ def is_parent(module, cert, potential_parent): public_key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, ): + if cert.cert.signature_hash_algorithm is None: + raise AssertionError( + "signature_hash_algorithm should be present for EC certificates" + ) public_key.verify( cert.cert.signature, cert.cert.tbs_certificate_bytes, @@ -213,11 +226,16 @@ def is_parent(module, cert, potential_parent): module.fail_json(msg=f"Unknown error on signature validation: {e}") -def parse_PEM_list(module, text, source, fail_on_error=True): +def parse_PEM_list( + module: AnsibleModule, + text: str, + source: str | os.PathLike, + fail_on_error: bool = True, +) -> list[Certificate]: """ Parse concatenated PEM certificates. Return list of ``Certificate`` objects. """ - result = [] + result: list[Certificate] = [] for cert_pem in split_pem_list(text): # Try to load PEM certificate try: @@ -232,7 +250,9 @@ def parse_PEM_list(module, text, source, fail_on_error=True): return result -def load_PEM_list(module, path, fail_on_error=True): +def load_PEM_list( + module: AnsibleModule, path: str | os.PathLike, fail_on_error: bool = True +) -> list[Certificate]: """ Load concatenated PEM certificates from file. Return list of ``Certificate`` objects. """ @@ -258,13 +278,15 @@ class CertificateSet: Stores a set of certificates. Allows to search for parent (issuer of a certificate). """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module - self.certificates = set() - self.certificates_by_issuer = dict() - self.certificate_by_cert = dict() + self.certificates: set[Certificate] = set() + self.certificates_by_issuer: dict[cryptography.x509.Name, list[Certificate]] = ( + {} + ) + self.certificate_by_cert: dict[cryptography.x509.Certificate, Certificate] = {} - def _load_file(self, path): + def _load_file(self, path: str | os.PathLike) -> None: certs = load_PEM_list(self.module, path, fail_on_error=False) for cert in certs: self.certificates.add(cert) @@ -273,7 +295,7 @@ class CertificateSet: self.certificates_by_issuer[cert.cert.subject].append(cert) self.certificate_by_cert[cert.cert] = cert - def load(self, path): + def load(self, path: str | os.PathLike) -> None: """ Load lists of PEM certificates from a file or a directory. """ @@ -285,7 +307,7 @@ class CertificateSet: else: self._load_file(b_path) - def find_parent(self, cert): + def find_parent(self, cert: Certificate) -> Certificate | None: """ Search for the parent (issuer) of a certificate. Return ``None`` if none was found. """ @@ -296,14 +318,18 @@ class CertificateSet: return None -def format_cert(cert): +def format_cert(cert: Certificate) -> str: """ Return human readable representation of certificate for error messages. """ return str(cert.cert) -def check_cycle(module, occured_certificates, next): +def check_cycle( + module: AnsibleModule, + occured_certificates: set[cryptography.x509.Certificate], + next: Certificate, +) -> None: """ Make sure that next is not in occured_certificates so far, and add it. """ @@ -313,7 +339,7 @@ def check_cycle(module, occured_certificates, next): occured_certificates.add(next_cert) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( input_chain=dict(type="str", required=True), @@ -354,10 +380,10 @@ def main(): roots.load(path) # Try to complete chain - current = chain[-1] + current: Certificate | None = chain[-1] completed = [] occured_certificates = set([cert.cert for cert in chain]) - if current.cert in roots.certificate_by_cert: + if current and current.cert in roots.certificate_by_cert: # Do not try to complete the chain when it is already ending with a root certificate current = None while current: diff --git a/plugins/modules/crypto_info.py b/plugins/modules/crypto_info.py index dc6f10d6..eaedcd27 100644 --- a/plugins/modules/crypto_info.py +++ b/plugins/modules/crypto_info.py @@ -152,10 +152,13 @@ openssl: """ import traceback +import typing as t from ansible.module_utils.basic import AnsibleModule +CRYPTOGRAPHY_VERSION: str | None +CRYPTOGRAPHY_IMP_ERR: str | None try: import cryptography from cryptography.exceptions import UnsupportedAlgorithm @@ -165,10 +168,10 @@ try: # only got added in 0.2, so let's guard the import from cryptography.exceptions import InternalError as CryptographyInternalError except ImportError: - CryptographyInternalError = Exception + CryptographyInternalError = Exception # type: ignore except ImportError: - UnsupportedAlgorithm = Exception - CryptographyInternalError = Exception + UnsupportedAlgorithm = Exception # type: ignore + CryptographyInternalError = Exception # type: ignore HAS_CRYPTOGRAPHY = False CRYPTOGRAPHY_VERSION = None CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() @@ -201,8 +204,8 @@ CURVES = ( ) -def add_crypto_information(module): - result = {} +def add_crypto_information(module: AnsibleModule) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY if not HAS_CRYPTOGRAPHY: result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR @@ -397,9 +400,9 @@ def add_crypto_information(module): return result -def add_openssl_information(module): +def add_openssl_information(module: AnsibleModule) -> dict[str, t.Any]: openssl_binary = module.get_bin_path("openssl") - result = { + result: dict[str, t.Any] = { "openssl_present": openssl_binary is not None, } if openssl_binary is None: @@ -426,9 +429,9 @@ INFO_FUNCTIONS = ( ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule(argument_spec={}, supports_check_mode=True) - result = {} + result: dict[str, t.Any] = {} for fn in INFO_FUNCTIONS: result.update(fn(module)) module.exit_json(**result) diff --git a/plugins/modules/ecs_certificate.py b/plugins/modules/ecs_certificate.py index 009fb44b..dfa7bc24 100644 --- a/plugins/modules/ecs_certificate.py +++ b/plugins/modules/ecs_certificate.py @@ -550,6 +550,7 @@ import datetime import os import re import time +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes @@ -572,7 +573,7 @@ from ansible_collections.community.crypto.plugins.module_utils.io import write_f MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION -def validate_cert_expiry(cert_expiry): +def validate_cert_expiry(cert_expiry: str) -> bool: search_string_partial = re.compile( r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z" ) @@ -587,7 +588,7 @@ def validate_cert_expiry(cert_expiry): return False -def calculate_cert_days(expires_after): +def calculate_cert_days(expires_after: str | None) -> int: cert_days = 0 if expires_after: expires_after_datetime = datetime.datetime.strptime( @@ -600,7 +601,9 @@ def calculate_cert_days(expires_after): # Populate the value of body[dict_param_name] with the JSON equivalent of # module parameter of param_name if that parameter is present, otherwise leave field # out of resulting dict -def convert_module_param_to_json_bool(module, dict_param_name, param_name): +def convert_module_param_to_json_bool( + module: AnsibleModule, dict_param_name: str, param_name: str +) -> dict[str, str]: body = {} if module.params[param_name] is not None: if module.params[param_name]: @@ -886,7 +889,7 @@ class EcsCertificate: return result -def custom_fields_spec(): +def custom_fields_spec() -> dict[str, dict[str, str]]: return dict( text1=dict(type="str"), text2=dict(type="str"), @@ -926,7 +929,7 @@ def custom_fields_spec(): ) -def ecs_certificate_argument_spec(): +def ecs_certificate_argument_spec() -> dict[str, dict[str, t.Any]]: return dict( backup=dict(type="bool", default=False), force=dict(type="bool", default=False), @@ -979,7 +982,7 @@ def ecs_certificate_argument_spec(): ) -def main(): +def main() -> t.NoReturn: ecs_argument_spec = ecs_client_argument_spec() ecs_argument_spec.update(ecs_certificate_argument_spec()) module = AnsibleModule( diff --git a/plugins/modules/ecs_domain.py b/plugins/modules/ecs_domain.py index ff014d2f..c5761051 100644 --- a/plugins/modules/ecs_domain.py +++ b/plugins/modules/ecs_domain.py @@ -218,6 +218,7 @@ ev_days_remaining: import datetime import time +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.ecs.api import ( @@ -228,7 +229,7 @@ from ansible_collections.community.crypto.plugins.module_utils.ecs.api import ( ) -def calculate_days_remaining(expiry_date): +def calculate_days_remaining(expiry_date: str | None) -> int | None: days_remaining = None if expiry_date: expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ") @@ -403,8 +404,8 @@ class EcsDomain: msg=f"Failed to request domain validation from Entrust (ECS) {e.message}" ) - def dump(self): - result = { + def dump(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = { "changed": self.changed, "client_id": self.client_id, "domain_status": self.domain_status, @@ -436,7 +437,7 @@ class EcsDomain: return result -def ecs_domain_argument_spec(): +def ecs_domain_argument_spec() -> dict[str, dict[str, t.Any]]: return dict( client_id=dict(type="int", default=1), domain_name=dict(type="str", required=True), @@ -447,7 +448,7 @@ def ecs_domain_argument_spec(): ) -def main(): +def main() -> t.NoReturn: ecs_argument_spec = ecs_client_argument_spec() ecs_argument_spec.update(ecs_domain_argument_spec()) module = AnsibleModule( diff --git a/plugins/modules/get_certificate.py b/plugins/modules/get_certificate.py index 98f827d1..6048b5a7 100644 --- a/plugins/modules/get_certificate.py +++ b/plugins/modules/get_certificate.py @@ -268,6 +268,7 @@ import atexit import base64 import ssl import sys +import typing as t from os.path import isfile from socket import create_connection, setdefaulttimeout, socket from ssl import ( @@ -305,7 +306,7 @@ except ImportError: pass -def send_starttls_packet(sock, server_type): +def send_starttls_packet(sock: socket, server_type: t.Literal["mysql"]) -> None: if server_type == "mysql": ssl_request_packet = ( b"\x20\x00\x00\x01\x85\xae\x7f\x00" @@ -321,7 +322,7 @@ def send_starttls_packet(sock, server_type): sock.send(ssl_request_packet) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( ca_cert=dict(type="path"), @@ -342,18 +343,18 @@ def main(): ), ) - ca_cert = module.params.get("ca_cert") - host = module.params.get("host") - port = module.params.get("port") - proxy_host = module.params.get("proxy_host") - proxy_port = module.params.get("proxy_port") - timeout = module.params.get("timeout") - server_name = module.params.get("server_name") - start_tls_server_type = module.params.get("starttls") - ciphers = module.params.get("ciphers") - asn1_base64 = module.params["asn1_base64"] - tls_ctx_options = module.params["tls_ctx_options"] - get_certificate_chain = module.params["get_certificate_chain"] + ca_cert: str | None = module.params.get("ca_cert") + host: str = module.params.get("host") + port: int = module.params.get("port") + proxy_host: str | None = module.params.get("proxy_host") + proxy_port: int | None = module.params.get("proxy_port") + timeout: int = module.params.get("timeout") + server_name: str | None = module.params.get("server_name") + start_tls_server_type: t.Literal["mysql"] | None = module.params.get("starttls") + ciphers: list[str] | None = module.params.get("ciphers") + asn1_base64: bool = module.params["asn1_base64"] + tls_ctx_options: list[str | bytes | int] | None = module.params["tls_ctx_options"] + get_certificate_chain: bool = module.params["get_certificate_chain"] if get_certificate_chain and sys.version_info < (3, 10): module.fail_json( @@ -365,9 +366,9 @@ def main(): module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) - result = dict( - changed=False, - ) + result: dict[str, t.Any] = { + "changed": False, + } if timeout: setdefaulttimeout(timeout) @@ -409,7 +410,7 @@ def main(): if tls_ctx_options is not None: # Clear default ctx options - ctx.options = 0 + ctx.options = 0 # type: ignore # For each item in the tls_ctx_options list for tls_ctx_option in tls_ctx_options: @@ -450,8 +451,10 @@ def main(): ) tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host) - cert = tls_sock.getpeercert(True) - cert = DER_cert_to_PEM_cert(cert) + cert_der = tls_sock.getpeercert(True) + if cert_der is None: + raise Exception("Unexpected error: no peer certificate has been returned") + cert: str = DER_cert_to_PEM_cert(cert_der) if get_certificate_chain: if sys.version_info < (3, 13): @@ -474,7 +477,7 @@ def main(): # Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to # be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves # if they are not byte strings to work around this. - def _convert_chain(chain): + def _convert_chain(chain: list[bytes]) -> list[bytes]: return [ ( c @@ -514,13 +517,13 @@ def main(): result["extensions"] = [] for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items(): oid = cryptography.x509.oid.ObjectIdentifier(dotted_number) - ext = { + ext: dict[str, t.Any] = { "critical": entry["critical"], "asn1_data": entry["value"], "name": cryptography_oid_to_name(oid, short=True), } if not asn1_base64: - ext["asn1_data"] = base64.b64decode(ext["asn1_data"]) + ext["asn1_data"] = base64.b64decode(entry["value"]) # type: ignore result["extensions"].append(ext) result["issuer"] = {} diff --git a/plugins/modules/luks_device.py b/plugins/modules/luks_device.py index e70ba490..f8dcb7f7 100644 --- a/plugins/modules/luks_device.py +++ b/plugins/modules/luks_device.py @@ -420,16 +420,13 @@ name: import os import re import stat +import typing as t from base64 import b64decode from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes, to_native -RETURN_CODE = 0 -STDOUT = 1 -STDERR = 2 - # used to get out of lsblk output in format 'crypt ' # regex takes care of any possible blank characters LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$") @@ -456,7 +453,7 @@ LUKS2_HEADER_OFFSETS = [ LUKS2_HEADER2 = b"SKUL\xba\xbe" -def wipe_luks_headers(device): +def wipe_luks_headers(device: str) -> None: wipe_offsets = [] with open(device, "rb") as f: # f.seek(0) @@ -478,12 +475,12 @@ def wipe_luks_headers(device): class Handler: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self._module = module self._lsblk_bin = self._module.get_bin_path("lsblk", True) self._passphrase_encoding = module.params["passphrase_encoding"] - def get_passphrase_from_module_params(self, parameter_name): + def get_passphrase_from_module_params(self, parameter_name: str) -> bytes | None: passphrase = self._module.params[parameter_name] if passphrase is None: return None @@ -496,88 +493,91 @@ class Handler: f"Error while base64-decoding '{parameter_name}': {exc}" ) - def _run_command(self, command, data=None): + def _run_command( + self, command: list[str], data: bytes | None = None + ) -> tuple[int, str, str]: return self._module.run_command(command, data=data, binary_data=True) - def get_device_by_uuid(self, uuid): + def get_device_by_uuid(self, uuid: str | None) -> str | None: """Returns the device that holds UUID passed by user""" self._blkid_bin = self._module.get_bin_path("blkid", True) - uuid = self._module.params["uuid"] if uuid is None: return None - result = self._run_command([self._blkid_bin, "--uuid", uuid]) - if result[RETURN_CODE] != 0: + rc, stdout, dummy = self._run_command([self._blkid_bin, "--uuid", uuid]) + if rc != 0: return None - return result[STDOUT].strip() + return stdout.strip() - def get_device_by_label(self, label): + def get_device_by_label(self, label: str) -> str | None: """Returns the device that holds label passed by user""" self._blkid_bin = self._module.get_bin_path("blkid", True) label = self._module.params["label"] if label is None: return None - result = self._run_command([self._blkid_bin, "--label", label]) - if result[RETURN_CODE] != 0: + rc, stdout, dummy = self._run_command([self._blkid_bin, "--label", label]) + if rc != 0: return None - return result[STDOUT].strip() + return stdout.strip() - def generate_luks_name(self, device): + def generate_luks_name(self, device: str) -> str: """Generate name for luks based on device UUID ('luks-'). Raises ValueError when obtaining of UUID fails. """ - result = self._run_command([self._lsblk_bin, "-n", device, "-o", "UUID"]) + rc, stdout, stderr = self._run_command( + [self._lsblk_bin, "-n", device, "-o", "UUID"] + ) - if result[RETURN_CODE] != 0: - raise ValueError( - f"Error while generating LUKS name for {device}: {result[STDERR]}" - ) - dev_uuid = result[STDOUT].strip() + if rc != 0: + raise ValueError(f"Error while generating LUKS name for {device}: {stderr}") + dev_uuid = stdout.strip() return f"luks-{dev_uuid}" class CryptHandler(Handler): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(CryptHandler, self).__init__(module) self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True) - def get_container_name_by_device(self, device): + def get_container_name_by_device(self, device: str) -> str | None: """obtain LUKS container name based on the device where it is located return None if not found raise ValueError if lsblk command fails """ - result = self._run_command([self._lsblk_bin, device, "-nlo", "type,name"]) - if result[RETURN_CODE] != 0: - raise ValueError( - f"Error while obtaining LUKS name for {device}: {result[STDERR]}" - ) + rc, stdout, stderr = self._run_command( + [self._lsblk_bin, device, "-nlo", "type,name"] + ) + if rc != 0: + raise ValueError(f"Error while obtaining LUKS name for {device}: {stderr}") - for line in result[STDOUT].splitlines(False): + for line in stdout.splitlines(False): m = LUKS_NAME_REGEX.match(line) if m: return m.group(1) return None - def get_container_device_by_name(self, name): + def get_container_device_by_name(self, name: str) -> str | None: """obtain device name based on the LUKS container name return None if not found raise ValueError if lsblk command fails """ # apparently each device can have only one LUKS container on it - result = self._run_command([self._cryptsetup_bin, "status", name]) - if result[RETURN_CODE] != 0: + rc, stdout, dummy = self._run_command([self._cryptsetup_bin, "status", name]) + if rc != 0: return None - m = LUKS_DEVICE_REGEX.search(result[STDOUT]) + m = LUKS_DEVICE_REGEX.search(stdout) + if not m: + return None device = m.group(1) return device - def is_luks(self, device): + def is_luks(self, device: str) -> bool: """check if the LUKS container does exist""" - result = self._run_command([self._cryptsetup_bin, "isLuks", device]) - return result[RETURN_CODE] == 0 + rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "isLuks", device]) + return rc == 0 - def get_luks_type(self, device): + def get_luks_type(self, device: str) -> t.Literal["luks1", "luks2"] | None: """get the luks type of a device""" if self.is_luks(device): with open(device, "rb") as f: @@ -589,16 +589,18 @@ class CryptHandler(Handler): return "luks1" return None - def is_luks_slot_set(self, device, keyslot): + def is_luks_slot_set(self, device: str, keyslot: int) -> bool: """check if a keyslot is set""" - result = self._run_command([self._cryptsetup_bin, "luksDump", device]) - if result[RETURN_CODE] != 0: + rc, stdout, dummy = self._run_command( + [self._cryptsetup_bin, "luksDump", device] + ) + if rc != 0: raise ValueError(f"Error while dumping LUKS header from {device}") - result_luks1 = f"Key Slot {keyslot}: ENABLED" in result[STDOUT] - result_luks2 = f" {keyslot}: luks2" in result[STDOUT] + result_luks1 = f"Key Slot {keyslot}: ENABLED" in stdout + result_luks2 = f" {keyslot}: luks2" in stdout return result_luks1 or result_luks2 - def _add_pbkdf_options(self, options, pbkdf): + def _add_pbkdf_options(self, options: list[str], pbkdf: dict[str, t.Any]) -> None: if pbkdf["iteration_time"] is not None: options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))]) if pbkdf["iteration_count"] is not None: @@ -612,16 +614,16 @@ class CryptHandler(Handler): def run_luks_create( self, - device, - keyfile, - passphrase, - keyslot, - keysize, - cipher, - hash_, - sector_size, - pbkdf, - ): + device: str, + keyfile: str | None, + passphrase: bytes | None, + keyslot: int | None, + keysize: int | None, + cipher: str | None, + hash_: str | None, + sector_size: str | None, + pbkdf: dict[str, t.Any] | None, + ) -> None: # create a new luks container; use batch mode to auto confirm luks_type = self._module.params["type"] label = self._module.params["label"] @@ -653,23 +655,23 @@ class CryptHandler(Handler): else: args.append("-") - result = self._run_command(args, data=passphrase) - if result[RETURN_CODE] != 0: - raise ValueError(f"Error while creating LUKS on {device}: {result[STDERR]}") + rc, dummy, stderr = self._run_command(args, data=passphrase) + if rc != 0: + raise ValueError(f"Error while creating LUKS on {device}: {stderr}") def run_luks_open( self, - device, - keyfile, - passphrase, - perf_same_cpu_crypt, - perf_submit_from_crypt_cpus, - perf_no_read_workqueue, - perf_no_write_workqueue, - persistent, - allow_discards, - name, - ): + device: str, + keyfile: str | None, + passphrase: bytes | None, + perf_same_cpu_crypt: bool, + perf_submit_from_crypt_cpus: bool, + perf_no_read_workqueue: bool, + perf_no_write_workqueue: bool, + persistent: bool, + allow_discards: bool, + name: str, + ) -> None: args = [self._cryptsetup_bin] if keyfile: args.extend(["--key-file", keyfile]) @@ -689,27 +691,27 @@ class CryptHandler(Handler): args.extend(["--allow-discards"]) args.extend(["open", "--type", "luks", device, name]) - result = self._run_command(args, data=passphrase) - if result[RETURN_CODE] != 0: + rc, dummy, stderr = self._run_command(args, data=passphrase) + if rc != 0: raise ValueError( - f"Error while opening LUKS container on {device}: {result[STDERR]}" + f"Error while opening LUKS container on {device}: {stderr}" ) - def run_luks_close(self, name): - result = self._run_command([self._cryptsetup_bin, "close", name]) - if result[RETURN_CODE] != 0: + def run_luks_close(self, name: str) -> None: + rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "close", name]) + if rc != 0: raise ValueError(f"Error while closing LUKS container {name}") - def run_luks_remove(self, device): + def run_luks_remove(self, device: str) -> None: wipefs_bin = self._module.get_bin_path("wipefs", True) name = self.get_container_name_by_device(device) if name is not None: self.run_luks_close(name) - result = self._run_command([wipefs_bin, "--all", device]) - if result[RETURN_CODE] != 0: + rc, dummy, stderr = self._run_command([wipefs_bin, "--all", device]) + if rc != 0: raise ValueError( - f"Error while wiping LUKS container signatures for {device}: {result[STDERR]}" + f"Error while wiping LUKS container signatures for {device}: {stderr}" ) # For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not** @@ -724,14 +726,14 @@ class CryptHandler(Handler): def run_luks_add_key( self, - device, - keyfile, - passphrase, - new_keyfile, - new_passphrase, - new_keyslot, - pbkdf, - ): + device: str, + keyfile: str | None, + passphrase: bytes | None, + new_keyfile: str | None, + new_passphrase: bytes | None, + new_keyslot: int | None, + pbkdf: dict[str, t.Any] | None, + ) -> None: """Add new key from a keyfile or passphrase to given 'device'; authentication done using 'keyfile' or 'passphrase'. Raises ValueError when command fails. @@ -746,36 +748,47 @@ class CryptHandler(Handler): if keyfile: args.extend(["--key-file", keyfile]) - else: + elif passphrase is not None: args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))]) data.append(passphrase) + else: + raise ValueError("Need passphrase or keyfile") if new_keyfile: args.append(new_keyfile) - else: + elif new_passphrase is not None: args.append("-") data.append(new_passphrase) + else: + raise ValueError("Need new passphrase or new keyfile") - result = self._run_command(args, data=b"".join(data) or None) - if result[RETURN_CODE] != 0: + rc, dummy, stderr = self._run_command(args, data=b"".join(data) or None) + if rc != 0: raise ValueError( - f"Error while adding new LUKS keyslot to {device}: {result[STDERR]}" + f"Error while adding new LUKS keyslot to {device}: {stderr}" ) def run_luks_remove_key( - self, device, keyfile, passphrase, keyslot, force_remove_last_key=False - ): + self, + device: str, + keyfile: str | None, + passphrase: bytes | None, + keyslot: int | None, + force_remove_last_key: bool = False, + ) -> None: """Remove key from given device Raises ValueError when command fails """ if not force_remove_last_key: - result = self._run_command([self._cryptsetup_bin, "luksDump", device]) - if result[RETURN_CODE] != 0: + rc, stdout, dummy = self._run_command( + [self._cryptsetup_bin, "luksDump", device] + ) + if rc != 0: raise ValueError(f"Error while dumping LUKS header from {device}") keyslot_count = 0 keyslot_area = False keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED") - for line in result[STDOUT].splitlines(): + for line in stdout.splitlines(): if line.startswith("Keyslots:"): keyslot_area = True elif line.startswith(" "): @@ -808,13 +821,17 @@ class CryptHandler(Handler): # Since we supply -q no passphrase is needed args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)] passphrase = None - result = self._run_command(args, data=passphrase) - if result[RETURN_CODE] != 0: - raise ValueError( - f"Error while removing LUKS key from {device}: {result[STDERR]}" - ) + rc, dummy, stderr = self._run_command(args, data=passphrase) + if rc != 0: + raise ValueError(f"Error while removing LUKS key from {device}: {stderr}") - def luks_test_key(self, device, keyfile, passphrase, keyslot=None): + def luks_test_key( + self, + device: str, + keyfile: str | None, + passphrase: bytes | None, + keyslot: int | None = None, + ) -> bool: """Check whether the keyfile or passphrase works. Raises ValueError when command fails. """ @@ -830,42 +847,37 @@ class CryptHandler(Handler): if keyslot is not None: args.extend(["--key-slot", str(keyslot)]) - result = self._run_command(args, data=data) - if result[RETURN_CODE] == 0: + rc, stdout, stderr = self._run_command(args, data=data) + if rc == 0: return True - for output in (STDOUT, STDERR): - if "No key available with this passphrase" in result[output]: + for output in (stdout, stderr): + if "No key available with this passphrase" in output: return False - if "No usable keyslot is available." in result[output]: + if "No usable keyslot is available." in output: return False # This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available' # when using the --key-slot parameter in combination with --test-passphrase - if ( - result[RETURN_CODE] == 1 - and keyslot is not None - and result[STDOUT] == "" - and result[STDERR] == "" - ): + if rc == 1 and keyslot is not None and stdout == "" and stderr == "": return False raise ValueError( - f"Error while testing whether keyslot exists on {device}: {result[STDERR]}" + f"Error while testing whether keyslot exists on {device}: {stderr}" ) class ConditionsHandler(Handler): - def __init__(self, module, crypthandler): + def __init__(self, module: AnsibleModule, crypthandler: CryptHandler) -> None: super(ConditionsHandler, self).__init__(module) self._crypthandler = crypthandler self.device = self.get_device_name() - def get_device_name(self): - device = self._module.params.get("device") - label = self._module.params.get("label") - uuid = self._module.params.get("uuid") - name = self._module.params.get("name") + def get_device_name(self) -> str | None: + device: str | None = self._module.params.get("device") + label: str | None = self._module.params.get("label") + uuid: str | None = self._module.params.get("uuid") + name: str | None = self._module.params.get("name") if device is None and label is not None: device = self.get_device_by_label(label) @@ -876,7 +888,7 @@ class ConditionsHandler(Handler): return device - def luks_create(self): + def luks_create(self) -> bool: return ( self.device is not None and ( @@ -887,7 +899,7 @@ class ConditionsHandler(Handler): and not self._crypthandler.is_luks(self.device) ) - def opened_luks_name(self): + def opened_luks_name(self, device: str) -> str | None: """If luks is already opened, return its name. If 'name' parameter is specified and differs from obtained value, fail. @@ -897,7 +909,7 @@ class ConditionsHandler(Handler): return None # try to obtain luks name - it may be already opened - name = self._crypthandler.get_container_name_by_device(self.device) + name = self._crypthandler.get_container_name_by_device(device) if name is None: # container is not open @@ -917,7 +929,7 @@ class ConditionsHandler(Handler): # container is opened and the names match return name - def luks_open(self): + def luks_open(self) -> bool: if ( ( self._module.params["keyfile"] is None @@ -929,13 +941,13 @@ class ConditionsHandler(Handler): # conditions for open not fulfilled return False - name = self.opened_luks_name() + name = self.opened_luks_name(self.device) if name is None: return True return False - def luks_close(self): + def luks_close(self) -> bool: if ( self._module.params["name"] is None and self.device is None ) or self._module.params["state"] != "closed": @@ -948,15 +960,17 @@ class ConditionsHandler(Handler): luks_is_open = name is not None if self._module.params["name"] is not None: - self.device = self._crypthandler.get_container_device_by_name( + device = self._crypthandler.get_container_device_by_name( self._module.params["name"] ) # successfully getting device based on name means that luks is open - luks_is_open = self.device is not None + luks_is_open = device is not None + if device is not None: + self.device = device return luks_is_open - def luks_add_key(self): + def luks_add_key(self) -> bool: if ( self.device is None or ( @@ -995,7 +1009,7 @@ class ConditionsHandler(Handler): return not key_present - def luks_remove_key(self): + def luks_remove_key(self) -> bool: if self.device is None or ( self._module.params["remove_keyfile"] is None and self._module.params["remove_passphrase"] is None @@ -1037,14 +1051,16 @@ class ConditionsHandler(Handler): self.get_passphrase_from_module_params("remove_passphrase"), ) - def luks_remove(self): + def luks_remove(self) -> bool: return ( self.device is not None and self._module.params["state"] == "absent" and self._crypthandler.is_luks(self.device) ) - def validate_keyslot(self, param, luks_type): + def validate_keyslot( + self, param: str, luks_type: t.Literal["luks1", "luks2"] | None + ) -> None: if self._module.params[param] is not None: if luks_type is None and param == "keyslot": if 8 <= self._module.params[param] <= 31: @@ -1066,7 +1082,7 @@ class ConditionsHandler(Handler): ) -def run_module(): +def run_module() -> t.NoReturn: # available arguments/parameters that a user can pass module_args = dict( state=dict( @@ -1122,7 +1138,7 @@ def run_module(): ] # seed the result dict in the object - result = dict(changed=False, name=None) + result: dict[str, t.Any] = {"changed": False, "name": None} module = AnsibleModule( argument_spec=module_args, @@ -1142,19 +1158,26 @@ def run_module(): except Exception as e: module.fail_json(msg=str(e)) - crypt = CryptHandler(module) - conditions = ConditionsHandler(module, crypt) - # conditions not allowed to run if module.params["label"] is not None and module.params["type"] == "luks1": module.fail_json(msg="You cannot combine type luks1 with the label option.") + crypt = CryptHandler(module) + try: + conditions = ConditionsHandler(module, crypt) + except ValueError as exc: + module.fail_json(msg=str(exc)) + if ( module.params["keyslot"] is not None or module.params["new_keyslot"] is not None or module.params["remove_keyslot"] is not None ): - luks_type = crypt.get_luks_type(conditions.get_device_name()) + luks_type = ( + crypt.get_luks_type(conditions.device) + if conditions.device is not None + else None + ) if luks_type is None and module.params["type"] is not None: luks_type = module.params["type"] for param in ["keyslot", "new_keyslot", "remove_keyslot"]: @@ -1175,6 +1198,7 @@ def run_module(): # luks create if conditions.luks_create(): + assert conditions.device # ensured in conditions.luks_create() if not module.check_mode: try: crypt.run_luks_create( @@ -1196,11 +1220,13 @@ def run_module(): # luks open - name = conditions.opened_luks_name() - if name is not None: - result["name"] = name + if conditions.device is not None: + name = conditions.opened_luks_name(conditions.device) + if name is not None: + result["name"] = name if conditions.luks_open(): + assert conditions.device # ensured in conditions.luks_open() name = module.params["name"] if name is None: try: @@ -1237,6 +1263,8 @@ def run_module(): module.fail_json(msg=f"luks_device error: {e}") else: name = module.params["name"] + if name is None: + module.fail_json(msg="Cannot determine name to close device") if not module.check_mode: try: crypt.run_luks_close(name) @@ -1249,6 +1277,7 @@ def run_module(): # luks add key if conditions.luks_add_key(): + assert conditions.device # ensured in conditions.luks_add_key() if not module.check_mode: try: crypt.run_luks_add_key( @@ -1268,6 +1297,7 @@ def run_module(): # luks remove key if conditions.luks_remove_key(): + assert conditions.device # ensured in conditions.luks_remove_key() if not module.check_mode: try: last_key = module.params["force_remove_last_key"] @@ -1286,6 +1316,7 @@ def run_module(): # luks remove if conditions.luks_remove(): + assert conditions.device # ensured in conditions.luks_remove() if not module.check_mode: try: crypt.run_luks_remove(conditions.device) @@ -1299,7 +1330,7 @@ def run_module(): module.exit_json(**result) -def main(): +def main() -> t.NoReturn: run_module() diff --git a/plugins/modules/openssh_cert.py b/plugins/modules/openssh_cert.py index 6cd1b91d..24fd2205 100644 --- a/plugins/modules/openssh_cert.py +++ b/plugins/modules/openssh_cert.py @@ -284,6 +284,7 @@ info: """ import os +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import ( @@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import ( class Certificate(OpensshModule): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(Certificate, self).__init__(module) self.ssh_keygen = KeygenCommand(self.module) - self.identifier = self.module.params["identifier"] or "" - self.options = self.module.params["options"] or [] - self.path = self.module.params["path"] - self.pkcs11_provider = self.module.params["pkcs11_provider"] - self.principals = self.module.params["principals"] or [] - self.public_key = self.module.params["public_key"] - self.regenerate = ( + self.identifier: str = self.module.params["identifier"] or "" + self.options: list[str] = self.module.params["options"] or [] + self.path: str = self.module.params["path"] + self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"] + self.principals: list[str] = self.module.params["principals"] or [] + self.public_key: str | None = self.module.params["public_key"] + self.regenerate: t.Literal[ + "never", + "fail", + "partial_idempotence", + "full_idempotence", + "always", + ] = ( self.module.params["regenerate"] if not self.module.params["force"] else "always" ) - self.serial_number = self.module.params["serial_number"] - self.signature_algorithm = self.module.params["signature_algorithm"] - self.signing_key = self.module.params["signing_key"] - self.state = self.module.params["state"] - self.type = self.module.params["type"] - self.use_agent = self.module.params["use_agent"] - self.valid_at = self.module.params["valid_at"] - self.ignore_timestamps = self.module.params["ignore_timestamps"] + self.serial_number: int | None = self.module.params["serial_number"] + self.signature_algorithm: ( + t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None + ) = self.module.params["signature_algorithm"] + self.signing_key: str | None = self.module.params["signing_key"] + self.state: t.Literal["absent", "present"] = self.module.params["state"] + self.type: t.Literal["host", "user"] | None = self.module.params["type"] + self.use_agent: bool = self.module.params["use_agent"] + self.valid_at: str | None = self.module.params["valid_at"] + self.ignore_timestamps: bool = self.module.params["ignore_timestamps"] self._check_if_base_dir(self.path) if self.state == "present": self._validate_parameters() - self.data = None - self.original_data = None + self.data: OpensshCertificate | None = None + self.original_data: OpensshCertificate | None = None if self._exists(): self._load_certificate() - self.time_parameters = None + self.time_parameters: OpensshCertificateTimeParameters | None = None if self.state == "present": self._set_time_parameters() - def _validate_parameters(self): + def _validate_parameters(self) -> None: for path in (self.public_key, self.signing_key): - self._check_if_base_dir(path) + if ( + path is not None + ): # should never be None, but the type checker doesn't know + self._check_if_base_dir(path) if self.options and self.type == "host": self.module.fail_json( @@ -352,7 +364,7 @@ class Certificate(OpensshModule): if self.use_agent: self._use_agent_available() - def _use_agent_available(self): + def _use_agent_available(self) -> None: ssh_version = self._get_ssh_version() if not ssh_version: self.module.fail_json(msg="Failed to determine ssh version") @@ -362,10 +374,10 @@ class Certificate(OpensshModule): + f" Your version is: {ssh_version}" ) - def _exists(self): + def _exists(self) -> bool: return os.path.exists(self.path) - def _load_certificate(self): + def _load_certificate(self) -> None: try: self.original_data = OpensshCertificate.load(self.path) except (TypeError, ValueError) as e: @@ -373,7 +385,7 @@ class Certificate(OpensshModule): self.module.fail_json(msg=f"Unable to read existing certificate: {e}") self.module.warn(f"Unable to read existing certificate: {e}") - def _set_time_parameters(self): + def _set_time_parameters(self) -> None: try: self.time_parameters = OpensshCertificateTimeParameters( valid_from=self.module.params["valid_from"], @@ -382,7 +394,7 @@ class Certificate(OpensshModule): except ValueError as e: self.module.fail_json(msg=str(e)) - def _execute(self): + def _execute(self) -> None: if self.state == "present": if self._should_generate(): self._generate() @@ -391,7 +403,7 @@ class Certificate(OpensshModule): if self._exists(): self._remove() - def _should_generate(self): + def _should_generate(self) -> bool: if self.regenerate == "never": return self.original_data is None elif self.regenerate == "fail": @@ -408,7 +420,13 @@ class Certificate(OpensshModule): else: return True - def _is_fully_valid(self): + def _is_fully_valid(self) -> bool: + if self.original_data is None: + raise AssertionError("Contract violation original_data not provided") + if self.public_key is None: + raise AssertionError("Contract violation public_key not provided") + if self.signing_key is None: + raise AssertionError("Contract violation signing_key not provided") return self._is_partially_valid() and all( [ self._compare_options() if self.original_data.type == "user" else True, @@ -420,7 +438,9 @@ class Certificate(OpensshModule): ] ) - def _is_partially_valid(self): + def _is_partially_valid(self) -> bool: + if self.original_data is None: + raise AssertionError("Contract violation original_data not provided") return all( [ set(self.original_data.principals) == set(self.principals), @@ -439,7 +459,9 @@ class Certificate(OpensshModule): ] ) - def _compare_time_parameters(self): + def _compare_time_parameters(self) -> bool: + if self.original_data is None: + raise AssertionError("Contract violation original_data not provided") try: original_time_parameters = OpensshCertificateTimeParameters( valid_from=self.original_data.valid_after, @@ -458,7 +480,9 @@ class Certificate(OpensshModule): ] ) - def _compare_options(self): + def _compare_options(self) -> bool: + if self.original_data is None: + raise AssertionError("Contract violation original_data not provided") try: critical_options, extensions = parse_option_list(self.options) except ValueError as e: @@ -471,13 +495,13 @@ class Certificate(OpensshModule): ] ) - def _get_key_fingerprint(self, path): + def _get_key_fingerprint(self, path: str) -> str: private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1] return PrivateKey.from_string(private_key_content).fingerprint @OpensshModule.trigger_change @OpensshModule.skip_if_check_mode - def _generate(self): + def _generate(self) -> None: try: temp_certificate = self._generate_temp_certificate() self._safe_secure_move([(temp_certificate, self.path)]) @@ -491,7 +515,14 @@ class Certificate(OpensshModule): except (TypeError, ValueError) as e: self.module.fail_json(msg=f"Unable to read new certificate: {e}") - def _generate_temp_certificate(self): + def _generate_temp_certificate(self) -> str: + if self.public_key is None: + raise AssertionError("Contract violation public_key not provided") + if self.signing_key is None: + raise AssertionError("Contract violation signing_key not provided") + if self.time_parameters is None: + raise AssertionError("Contract violation time_parameters not provided") + key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key)) try: @@ -523,14 +554,14 @@ class Certificate(OpensshModule): @OpensshModule.trigger_change @OpensshModule.skip_if_check_mode - def _remove(self): + def _remove(self) -> None: try: os.remove(self.path) except OSError as e: self.module.fail_json(msg=f"Unable to remove existing certificate: {e}") @property - def _result(self): + def _result(self) -> dict[str, t.Any]: if self.state != "present": return {} @@ -546,14 +577,14 @@ class Certificate(OpensshModule): } @property - def diff(self): + def diff(self) -> dict[str, t.Any]: return { "before": get_cert_dict(self.original_data), "after": get_cert_dict(self.data), } -def format_cert_info(cert_info): +def format_cert_info(cert_info: str) -> list[str]: result = [] string = "" @@ -579,7 +610,7 @@ def format_cert_info(cert_info): return result -def get_cert_dict(data): +def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]: if data is None: return {} @@ -590,7 +621,7 @@ def get_cert_dict(data): return result -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( force=dict(type="bool", default=False), diff --git a/plugins/modules/openssh_keypair.py b/plugins/modules/openssh_keypair.py index 28e30fdc..6161be44 100644 --- a/plugins/modules/openssh_keypair.py +++ b/plugins/modules/openssh_keypair.py @@ -198,13 +198,15 @@ comment: sample: test@comment """ +import typing as t + from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import ( select_backend, ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( diff --git a/plugins/modules/openssl_csr.py b/plugins/modules/openssl_csr.py index b40583ab..5d0dbee5 100644 --- a/plugins/modules/openssl_csr.py +++ b/plugins/modules/openssl_csr.py @@ -239,6 +239,7 @@ csr: """ import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, @@ -256,9 +257,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import ( + CertificateSigningRequestBackend, + ) + + class CertificateSigningRequestModule(OpenSSLObject): - def __init__(self, module, module_backend): + def __init__( + self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend + ) -> None: super(CertificateSigningRequestModule, self).__init__( module.params["path"], module.params["state"], @@ -269,11 +279,11 @@ class CertificateSigningRequestModule(OpenSSLObject): self.return_content = module.params["return_content"] self.backup = module.params["backup"] - self.backup_file = None + self.backup_file: str | None = None self.module_backend.set_existing(load_file_if_exists(self.path, module)) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Generate the certificate signing request.""" if self.force or self.module_backend.needs_regeneration(): if not self.check_mode: @@ -292,13 +302,13 @@ class CertificateSigningRequestModule(OpenSSLObject): file_args, self.changed ) - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: self.module_backend.set_existing(None) if self.backup and not self.check_mode: self.backup_file = module.backup_local(self.path) super(CertificateSigningRequestModule, self).remove(module) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" result = self.module_backend.dump(include_csr=self.return_content) result.update( @@ -312,7 +322,7 @@ class CertificateSigningRequestModule(OpenSSLObject): return result -def main(): +def main() -> t.NoReturn: argument_spec = get_csr_argument_spec() argument_spec.argument_spec.update( dict( diff --git a/plugins/modules/openssl_csr_info.py b/plugins/modules/openssl_csr_info.py index 52e33c61..a7abe449 100644 --- a/plugins/modules/openssl_csr_info.py +++ b/plugins/modules/openssl_csr_info.py @@ -308,6 +308,7 @@ authority_cert_serial_number: sample: 12345 """ +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -318,7 +319,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( path=dict(type="path"), @@ -335,11 +336,15 @@ def main(): supports_check_mode=True, ) - if module.params["content"] is not None: - data = module.params["content"].encode("utf-8") + content: str | None = module.params["content"] + path: str | None = module.params["path"] + if content is not None: + data = content.encode("utf-8") else: + if path is None: + module.fail_json(msg="One of content and path must be provided") try: - with open(module.params["path"], "rb") as f: + with open(path, "rb") as f: data = f.read() except (IOError, OSError) as e: module.fail_json(msg=f"Error while reading CSR file from disk: {e}") diff --git a/plugins/modules/openssl_csr_pipe.py b/plugins/modules/openssl_csr_pipe.py index f3afc05d..35d8eb0a 100644 --- a/plugins/modules/openssl_csr_pipe.py +++ b/plugins/modules/openssl_csr_pipe.py @@ -127,6 +127,8 @@ csr: type: str """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, ) @@ -136,8 +138,17 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import ( + CertificateSigningRequestBackend, + ) + + class CertificateSigningRequestModule: - def __init__(self, module, module_backend): + def __init__( + self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend + ) -> None: self.check_mode = module.check_mode self.module = module self.module_backend = module_backend @@ -145,13 +156,13 @@ class CertificateSigningRequestModule: if module.params["content"] is not None: self.module_backend.set_existing(module.params["content"].encode("utf-8")) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Generate the certificate signing request.""" if self.module_backend.needs_regeneration(): self.module_backend.generate_csr() self.changed = True - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" result = self.module_backend.dump(include_csr=True) result.update( @@ -162,7 +173,7 @@ class CertificateSigningRequestModule: return result -def main(): +def main() -> t.NoReturn: argument_spec = get_csr_argument_spec() argument_spec.argument_spec.update( dict( diff --git a/plugins/modules/openssl_dhparam.py b/plugins/modules/openssl_dhparam.py index 8b3fb028..d2ab17f1 100644 --- a/plugins/modules/openssl_dhparam.py +++ b/plugins/modules/openssl_dhparam.py @@ -132,6 +132,7 @@ import abc import os import re import tempfile +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_native @@ -173,23 +174,23 @@ class DHParameterError(Exception): class DHParameterBase: - def __init__(self, module): - self.state = module.params["state"] - self.path = module.params["path"] - self.size = module.params["size"] - self.force = module.params["force"] + def __init__(self, module: AnsibleModule) -> None: + self.state: t.Literal["absent", "present"] = module.params["state"] + self.path: str = module.params["path"] + self.size: int = module.params["size"] + self.force: bool = module.params["force"] self.changed = False - self.return_content = module.params["return_content"] + self.return_content: bool = module.params["return_content"] - self.backup = module.params["backup"] - self.backup_file = None + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None @abc.abstractmethod - def _do_generate(self, module): + def _do_generate(self, module: AnsibleModule) -> None: """Actually generate the DH params.""" pass - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Generate DH params.""" changed = False @@ -206,7 +207,7 @@ class DHParameterBase: self.changed = changed - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: if self.backup: self.backup_file = module.backup_local(self.path) try: @@ -215,28 +216,27 @@ class DHParameterBase: except OSError as exc: module.fail_json(msg=str(exc)) - def check(self, module): + def check(self, module: AnsibleModule) -> bool: """Ensure the resource is in its desired state.""" if self.force: return False return self._check_params_valid(module) and self._check_fs_attributes(module) @abc.abstractmethod - def _check_params_valid(self, module): + def _check_params_valid(self, module: AnsibleModule) -> bool: """Check if the params are in the correct state""" - pass - def _check_fs_attributes(self, module): + def _check_fs_attributes(self, module: AnsibleModule) -> bool: """Checks (and changes if not in check mode!) fs attributes""" file_args = module.load_file_common_arguments(module.params) if module.check_file_absent_if_check_mode(file_args["path"]): return False return not module.set_fs_attributes_if_different(file_args, False) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - result = { + result: dict[str, t.Any] = { "size": self.size, "filename": self.path, "changed": self.changed, @@ -252,25 +252,24 @@ class DHParameterBase: class DHParameterAbsent(DHParameterBase): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(DHParameterAbsent, self).__init__(module) - def _do_generate(self, module): + def _do_generate(self, module: AnsibleModule) -> None: """Actually generate the DH params.""" - pass - def _check_params_valid(self, module): + def _check_params_valid(self, module: AnsibleModule) -> bool: """Check if the params are in the correct state""" - pass + return False class DHParameterOpenSSL(DHParameterBase): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(DHParameterOpenSSL, self).__init__(module) self.openssl_bin = module.get_bin_path("openssl", True) - def _do_generate(self, module): + def _do_generate(self, module: AnsibleModule) -> None: """Actually generate the DH params.""" # create a tempfile fd, tmpsrc = tempfile.mkstemp() @@ -288,7 +287,7 @@ class DHParameterOpenSSL(DHParameterBase): except Exception as e: module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}") - def _check_params_valid(self, module): + def _check_params_valid(self, module: AnsibleModule) -> bool: """Check if the params are in the correct state""" command = [ self.openssl_bin, @@ -321,10 +320,10 @@ class DHParameterOpenSSL(DHParameterBase): class DHParameterCryptography(DHParameterBase): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(DHParameterCryptography, self).__init__(module) - def _do_generate(self, module): + def _do_generate(self, module: AnsibleModule) -> None: """Actually generate the DH params.""" # Generate parameters params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters( @@ -341,7 +340,7 @@ class DHParameterCryptography(DHParameterBase): self.backup_file = module.backup_local(self.path) write_file(module, result) - def _check_params_valid(self, module): + def _check_params_valid(self, module: AnsibleModule) -> bool: """Check if the params are in the correct state""" # Load parameters try: @@ -357,7 +356,7 @@ class DHParameterCryptography(DHParameterBase): return bits == self.size -def main(): +def main() -> t.NoReturn: """Main function""" module = AnsibleModule( @@ -383,6 +382,7 @@ def main(): msg=f"The directory '{base_dir}' does not exist or the file is not a directory", ) + dhparam: DHParameterOpenSSL | DHParameterCryptography | DHParameterAbsent if module.params["state"] == "present": backend = module.params["select_crypto_backend"] if backend == "auto": diff --git a/plugins/modules/openssl_pkcs12.py b/plugins/modules/openssl_pkcs12.py index 3640636c..12671a9e 100644 --- a/plugins/modules/openssl_pkcs12.py +++ b/plugins/modules/openssl_pkcs12.py @@ -280,6 +280,7 @@ import itertools import os import stat import traceback +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes, to_native @@ -296,7 +297,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( OpenSSLObject, load_certificate, - load_privatekey, + load_certificate_issuer_privatekey, ) from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, @@ -320,6 +321,7 @@ except ImportError: CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None try: + import cryptography.x509 from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.serialization.pkcs12 import PBES @@ -333,8 +335,22 @@ except Exception: else: CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True +if t.TYPE_CHECKING: + from ..module_utils.crypto.cryptography_support import ( + CertificateIssuerPrivateKeyTypes, + ) -def load_certificate_set(filename): + PKCS12 = tuple[ + t.Union[CertificateIssuerPrivateKeyTypes, None], + t.Union[cryptography.x509.Certificate, None], + list[cryptography.x509.Certificate], + t.Union[bytes, None], + ] + + +def load_certificate_set( + filename: str | os.PathLike, +) -> list[cryptography.x509.Certificate]: """ Load list of concatenated PEM files, and return a list of parsed certificates. """ @@ -351,70 +367,80 @@ class PkcsError(OpenSSLObjectError): class Pkcs(OpenSSLObject): - def __init__(self, module, iter_size_default=2048): + path: str + + def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None: super(Pkcs, self).__init__( module.params["path"], module.params["state"], module.params["force"], module.check_mode, ) - self.action = module.params["action"] - self.other_certificates = module.params["other_certificates"] - self.other_certificates_parse_all = module.params[ + self.action: t.Literal["export", "parse"] = module.params["action"] + self.other_certificates: list[cryptography.x509.Certificate] = [] + self.other_certificates_str: list[str] | None = module.params[ + "other_certificates" + ] + self.other_certificates_parse_all: bool = module.params[ "other_certificates_parse_all" ] - self.other_certificates_content = module.params["other_certificates_content"] - self.certificate_path = module.params["certificate_path"] - self.certificate_content = module.params["certificate_content"] - self.friendly_name = module.params["friendly_name"] - self.iter_size = module.params["iter_size"] or iter_size_default - self.maciter_size = module.params["maciter_size"] or 1 - self.encryption_level = module.params["encryption_level"] + self.other_certificates_content: list[str] | None = module.params[ + "other_certificates_content" + ] + self.certificate_path: str | None = module.params["certificate_path"] + certificate_content: str | None = module.params["certificate_content"] + self.friendly_name: str | None = module.params["friendly_name"] + self.iter_size: int = module.params["iter_size"] or iter_size_default + self.maciter_size: int = module.params["maciter_size"] or 1 + self.encryption_level: t.Literal["auto", "compatibility2022"] = module.params[ + "encryption_level" + ] self.passphrase = module.params["passphrase"] - self.pkcs12 = None - self.privatekey_passphrase = module.params["privatekey_passphrase"] - self.privatekey_path = module.params["privatekey_path"] - self.privatekey_content = module.params["privatekey_content"] - self.pkcs12_bytes = None - self.return_content = module.params["return_content"] - self.src = module.params["src"] + self.pkcs12: PKCS12 | None = None + self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"] + self.privatekey_path: str | None = module.params["privatekey_path"] + privatekey_content: str | None = module.params["privatekey_content"] + self.pkcs12_bytes: bytes | None = None + self.return_content: bool = module.params["return_content"] + self.src: str | None = module.params["src"] if module.params["mode"] is None: module.params["mode"] = "0400" - self.backup = module.params["backup"] - self.backup_file = None + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None + self.certificate_content: bytes | None = None if self.certificate_path is not None: try: with open(self.certificate_path, "rb") as fh: self.certificate_content = fh.read() except (IOError, OSError) as exc: raise PkcsError(exc) - elif self.certificate_content is not None: - self.certificate_content = to_bytes(self.certificate_content) + elif certificate_content is not None: + self.certificate_content = to_bytes(certificate_content) + self.privatekey_content: bytes | None = None if self.privatekey_path is not None: try: with open(self.privatekey_path, "rb") as fh: self.privatekey_content = fh.read() except (IOError, OSError) as exc: raise PkcsError(exc) - elif self.privatekey_content is not None: - self.privatekey_content = to_bytes(self.privatekey_content) + elif privatekey_content is not None: + self.privatekey_content = to_bytes(privatekey_content) - if self.other_certificates: + if self.other_certificates_str: if self.other_certificates_parse_all: - filenames = list(self.other_certificates) self.other_certificates = [] - for other_cert_bundle in filenames: + for other_cert_bundle in self.other_certificates_str: self.other_certificates.extend( load_certificate_set(other_cert_bundle) ) else: self.other_certificates = [ load_certificate(other_cert) - for other_cert in self.other_certificates + for other_cert in self.other_certificates_str ] elif self.other_certificates_content: certs = self.other_certificates_content @@ -430,40 +456,42 @@ class Pkcs(OpenSSLObject): ] @abc.abstractmethod - def generate_bytes(self, module): + def generate_bytes(self, module: AnsibleModule) -> bytes: """Generate PKCS#12 file archive.""" + + @abc.abstractmethod + def parse_bytes(self, pkcs12_content: bytes) -> tuple[ + bytes | None, + bytes | None, + list[bytes], + bytes | None, + ]: pass @abc.abstractmethod - def parse_bytes(self, pkcs12_content): + def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None: pass @abc.abstractmethod - def _dump_privatekey(self, pkcs12): + def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None: pass @abc.abstractmethod - def _dump_certificate(self, pkcs12): + def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]: pass @abc.abstractmethod - def _dump_other_certificates(self, pkcs12): + def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None: pass - @abc.abstractmethod - def _get_friendly_name(self, pkcs12): - pass - - def check(self, module, perms_required=True): + def check(self, module: AnsibleModule, perms_required: bool = True) -> bool: """Ensure the resource is in its desired state.""" - state_and_perms = super(Pkcs, self).check(module, perms_required) - def _check_pkey_passphrase(): + def _check_pkey_passphrase() -> bool: if self.privatekey_passphrase: try: - load_privatekey( - None, + load_certificate_issuer_privatekey( content=self.privatekey_content, passphrase=self.privatekey_passphrase, ) @@ -476,6 +504,7 @@ class Pkcs(OpenSSLObject): if os.path.exists(self.path) and module.params["action"] == "export": self.generate_bytes(module) # ignore result + assert self.pkcs12 is not None self.src = self.path try: ( @@ -524,7 +553,7 @@ class Pkcs(OpenSSLObject): return False elif ( module.params["action"] == "parse" - and os.path.exists(self.src) + and os.path.exists(self.src or "") and os.path.exists(self.path) ): try: @@ -548,10 +577,10 @@ class Pkcs(OpenSSLObject): return _check_pkey_passphrase() - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - result = { + result: dict[str, t.Any] = { "filename": self.path, } if self.privatekey_path: @@ -567,13 +596,20 @@ class Pkcs(OpenSSLObject): return result - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: if self.backup: self.backup_file = module.backup_local(self.path) super(Pkcs, self).remove(module) - def parse(self): + def parse(self) -> tuple[ + bytes | None, + bytes | None, + list[bytes], + bytes | None, + ]: """Read PKCS#12 file.""" + if self.src is None: + raise AssertionError("Contract violation: src is None") try: with open(self.src, "rb") as pkcs12_fh: @@ -582,10 +618,13 @@ class Pkcs(OpenSSLObject): except IOError as exc: raise PkcsError(exc) - def generate(self): + def generate(self, module: AnsibleModule) -> None: + # Empty method because OpenSSLObject wants this pass - def write(self, module, content, mode=None): + def write( + self, module: AnsibleModule, content: bytes, mode: int | str | None = None + ) -> None: """Write the PKCS#12 file.""" if self.backup: self.backup_file = module.backup_local(self.path) @@ -595,7 +634,7 @@ class Pkcs(OpenSSLObject): class PkcsCryptography(Pkcs): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(PkcsCryptography, self).__init__(module, iter_size_default=50000) if ( self.encryption_level == "compatibility2022" @@ -607,13 +646,12 @@ class PkcsCryptography(Pkcs): exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR, ) - def generate_bytes(self, module): + def generate_bytes(self, module: AnsibleModule) -> bytes: """Generate PKCS#12 file archive.""" pkey = None if self.privatekey_content: try: - pkey = load_privatekey( - None, + pkey = load_certificate_issuer_privatekey( content=self.privatekey_content, passphrase=self.privatekey_passphrase, ) @@ -631,6 +669,7 @@ class PkcsCryptography(Pkcs): # Store fake object which can be used to retrieve the components back self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name) + encryption: serialization.KeySerializationEncryption if not self.passphrase: encryption = serialization.NoEncryption() elif self.encryption_level == "compatibility2022": @@ -654,7 +693,12 @@ class PkcsCryptography(Pkcs): encryption, ) - def parse_bytes(self, pkcs12_content): + def parse_bytes(self, pkcs12_content: bytes) -> tuple[ + bytes | None, + bytes | None, + list[bytes], + bytes | None, + ]: try: private_key, certificate, additional_certificates, friendly_name = ( parse_pkcs12(pkcs12_content, self.passphrase) @@ -683,11 +727,7 @@ class PkcsCryptography(Pkcs): except ValueError as exc: raise PkcsError(exc) - # The following methods will get self.pkcs12 passed, which is computed as: - # - # self.pkcs12 = (pkey, cert, self.other_certificates, self.friendly_name) - - def _dump_privatekey(self, pkcs12): + def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None: return ( pkcs12[0].private_bytes( encoding=serialization.Encoding.PEM, @@ -698,27 +738,27 @@ class PkcsCryptography(Pkcs): else None ) - def _dump_certificate(self, pkcs12): + def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None: return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None - def _dump_other_certificates(self, pkcs12): + def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]: return [ other_cert.public_bytes(serialization.Encoding.PEM) for other_cert in pkcs12[2] ] - def _get_friendly_name(self, pkcs12): + def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None: return pkcs12[3] -def select_backend(module): +def select_backend(module: AnsibleModule) -> Pkcs: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) return PkcsCryptography(module) -def main(): +def main() -> t.NoReturn: argument_spec = dict( action=dict(type="str", default="export", choices=["export", "parse"]), other_certificates=dict( diff --git a/plugins/modules/openssl_privatekey.py b/plugins/modules/openssl_privatekey.py index 73d87f69..6b997479 100644 --- a/plugins/modules/openssl_privatekey.py +++ b/plugins/modules/openssl_privatekey.py @@ -155,6 +155,7 @@ privatekey: """ import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, @@ -172,9 +173,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import ( + PrivateKeyBackend, + ) + + class PrivateKeyModule(OpenSSLObject): - def __init__(self, module, module_backend): + def __init__( + self, module: AnsibleModule, module_backend: PrivateKeyBackend + ) -> None: super(PrivateKeyModule, self).__init__( module.params["path"], module.params["state"], @@ -182,19 +192,19 @@ class PrivateKeyModule(OpenSSLObject): module.check_mode, ) self.module_backend = module_backend - self.return_content = module.params["return_content"] + self.return_content: bool = module.params["return_content"] if self.force: module_backend.regenerate = "always" - self.backup = module.params["backup"] - self.backup_file = None + self.backup: str | None = module.params["backup"] + self.backup_file: str | None = None if module.params["mode"] is None: module.params["mode"] = "0600" module_backend.set_existing(load_file_if_exists(self.path, module)) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Generate a keypair.""" if self.module_backend.needs_regeneration(): @@ -228,13 +238,13 @@ class PrivateKeyModule(OpenSSLObject): file_args, self.changed ) - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: self.module_backend.set_existing(None) if self.backup and not self.check_mode: self.backup_file = module.backup_local(self.path) super(PrivateKeyModule, self).remove(module) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" result = self.module_backend.dump(include_key=self.return_content) @@ -246,7 +256,7 @@ class PrivateKeyModule(OpenSSLObject): return result -def main(): +def main() -> t.NoReturn: argument_spec = get_privatekey_argument_spec() argument_spec.argument_spec.update( diff --git a/plugins/modules/openssl_privatekey_convert.py b/plugins/modules/openssl_privatekey_convert.py index 8bf43747..92179d30 100644 --- a/plugins/modules/openssl_privatekey_convert.py +++ b/plugins/modules/openssl_privatekey_convert.py @@ -60,6 +60,7 @@ backup_file: """ import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, @@ -77,8 +78,17 @@ from ansible_collections.community.crypto.plugins.module_utils.io import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey_convert import ( + PrivateKeyConvertBackend, + ) + + class PrivateKeyConvertModule(OpenSSLObject): - def __init__(self, module, module_backend): + def __init__( + self, module: AnsibleModule, module_backend: PrivateKeyConvertBackend + ) -> None: super(PrivateKeyConvertModule, self).__init__( module.params["dest_path"], "present", @@ -87,8 +97,8 @@ class PrivateKeyConvertModule(OpenSSLObject): ) self.module_backend = module_backend - self.backup = module.params["backup"] - self.backup_file = None + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None module.params["path"] = module.params["dest_path"] if module.params["mode"] is None: @@ -96,12 +106,14 @@ class PrivateKeyConvertModule(OpenSSLObject): module_backend.set_existing_destination(load_file_if_exists(self.path, module)) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Do conversion.""" if self.module_backend.needs_conversion(): # Convert privatekey_data = self.module_backend.get_private_key_data() + if privatekey_data is None: + raise AssertionError("Contract violation: privatekey_data is None") if not self.check_mode: if self.backup: self.backup_file = module.backup_local(self.path) @@ -116,7 +128,7 @@ class PrivateKeyConvertModule(OpenSSLObject): file_args, self.changed ) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" result = self.module_backend.dump() @@ -127,7 +139,7 @@ class PrivateKeyConvertModule(OpenSSLObject): return result -def main(): +def main() -> t.NoReturn: argument_spec = get_privatekey_argument_spec() argument_spec.argument_spec.update( diff --git a/plugins/modules/openssl_privatekey_info.py b/plugins/modules/openssl_privatekey_info.py index ec75177a..fd74178b 100644 --- a/plugins/modules/openssl_privatekey_info.py +++ b/plugins/modules/openssl_privatekey_info.py @@ -200,6 +200,7 @@ private_data: type: dict """ +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -212,7 +213,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( path=dict(type="path"), @@ -243,7 +244,7 @@ def main(): data = f.read() except (IOError, OSError) as e: module.fail_json( - msg=f"Error while reading private key file from disk: {e}", **result + msg=f"Error while reading private key file from disk: {e}", **result # type: ignore ) result["can_load_key"] = True @@ -261,10 +262,10 @@ def main(): module.exit_json(**result) except PrivateKeyParseError as exc: result.update(exc.result) - module.fail_json(msg=exc.error_message, **result) + module.fail_json(msg=exc.error_message, **result) # type: ignore except PrivateKeyConsistencyError as exc: result.update(exc.result) - module.fail_json(msg=exc.error_message, **result) + module.fail_json(msg=exc.error_message, **result) # type: ignore except OpenSSLObjectError as exc: module.fail_json(msg=str(exc)) diff --git a/plugins/modules/openssl_publickey.py b/plugins/modules/openssl_publickey.py index 37375b24..096f1692 100644 --- a/plugins/modules/openssl_publickey.py +++ b/plugins/modules/openssl_publickey.py @@ -186,6 +186,7 @@ publickey: """ import os +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -218,6 +219,12 @@ try: except ImportError: pass +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, + PublicKeyTypes, + ) + class PublicKeyError(OpenSSLObjectError): pass @@ -225,7 +232,7 @@ class PublicKeyError(OpenSSLObjectError): class PublicKey(OpenSSLObject): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(PublicKey, self).__init__( module.params["path"], module.params["state"], @@ -233,27 +240,29 @@ class PublicKey(OpenSSLObject): module.check_mode, ) self.module = module - self.format = module.params["format"] - self.privatekey_path = module.params["privatekey_path"] - self.privatekey_content = module.params["privatekey_content"] - if self.privatekey_content is not None: - self.privatekey_content = self.privatekey_content.encode("utf-8") - self.privatekey_passphrase = module.params["privatekey_passphrase"] - self.privatekey = None - self.publickey_bytes = None - self.return_content = module.params["return_content"] - self.fingerprint = {} + self.format: t.Literal["OpenSSH", "PEM"] = module.params["format"] + self.privatekey_path: str | None = module.params["privatekey_path"] + privatekey_content: str | None = module.params["privatekey_content"] + if privatekey_content is not None: + self.privatekey_content: bytes | None = privatekey_content.encode("utf-8") + else: + self.privatekey_content = None + self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"] + self.privatekey: PrivateKeyTypes | None = None + self.publickey_bytes: bytes | None = None + self.return_content: bool = module.params["return_content"] + self.fingerprint: dict[str, str] = {} - self.backup = module.params["backup"] - self.backup_file = None + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None self.diff_before = self._get_info(None) self.diff_after = self._get_info(None) - def _get_info(self, data): + def _get_info(self, data: bytes | None) -> dict[str, t.Any]: if data is None: - return dict() - result = dict(can_parse_key=False) + return {} + result = {"can_parse_key": False} try: result.update( get_publickey_info( @@ -267,7 +276,7 @@ class PublicKey(OpenSSLObject): pass return result - def _create_publickey(self, module): + def _create_publickey(self, module: AnsibleModule) -> bytes: self.privatekey = load_privatekey( path=self.privatekey_path, content=self.privatekey_content, @@ -284,10 +293,12 @@ class PublicKey(OpenSSLObject): crypto_serialization.PublicFormat.SubjectPublicKeyInfo, ) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Generate the public key.""" - if self.privatekey_content is None and not os.path.exists(self.privatekey_path): + if self.privatekey_path is not None and not os.path.exists( + self.privatekey_path + ): raise PublicKeyError( f"The private key {self.privatekey_path} does not exist" ) @@ -320,17 +331,18 @@ class PublicKey(OpenSSLObject): elif module.set_fs_attributes_if_different(file_args, False): self.changed = True - def check(self, module, perms_required=True): + def check(self, module: AnsibleModule, perms_required: bool = True) -> bool: """Ensure the resource is in its desired state.""" state_and_perms = super(PublicKey, self).check(module, perms_required) - def _check_privatekey(): - if self.privatekey_content is None and not os.path.exists( + def _check_privatekey() -> bool: + if self.privatekey_path is not None and not os.path.exists( self.privatekey_path ): return False + current_publickey: PublicKeyTypes try: with open(self.path, "rb") as public_key_fh: publickey_content = public_key_fh.read() @@ -369,15 +381,15 @@ class PublicKey(OpenSSLObject): return _check_privatekey() - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: if self.backup: self.backup_file = module.backup_local(self.path) super(PublicKey, self).remove(module) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - result = { + result: dict[str, t.Any] = { "privatekey": self.privatekey_path, "filename": self.path, "format": self.format, @@ -403,7 +415,7 @@ class PublicKey(OpenSSLObject): return result -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( diff --git a/plugins/modules/openssl_publickey_info.py b/plugins/modules/openssl_publickey_info.py index 9642e0e8..e472b980 100644 --- a/plugins/modules/openssl_publickey_info.py +++ b/plugins/modules/openssl_publickey_info.py @@ -152,6 +152,7 @@ public_data: returned: When RV(type=DSA) or RV(type=ECC) """ +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -163,7 +164,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( path=dict(type="path"), @@ -191,7 +192,7 @@ def main(): data = f.read() except (IOError, OSError) as e: module.fail_json( - msg=f"Error while reading public key file from disk: {e}", **result + msg=f"Error while reading public key file from disk: {e}", **result # type: ignore ) module_backend = select_backend(module, data) @@ -201,7 +202,7 @@ def main(): module.exit_json(**result) except PublicKeyParseError as exc: result.update(exc.result) - module.fail_json(msg=exc.error_message, **result) + module.fail_json(msg=exc.error_message, **result) # type: ignore except OpenSSLObjectError as exc: module.fail_json(msg=str(exc)) diff --git a/plugins/modules/openssl_signature.py b/plugins/modules/openssl_signature.py index 6f0b0974..df182ef2 100644 --- a/plugins/modules/openssl_signature.py +++ b/plugins/modules/openssl_signature.py @@ -99,6 +99,7 @@ signature: import base64 import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, @@ -132,7 +133,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im class SignatureBase(OpenSSLObject): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(SignatureBase, self).__init__( path=module.params["path"], state="present", @@ -140,32 +141,35 @@ class SignatureBase(OpenSSLObject): check_mode=module.check_mode, ) - self.privatekey_path = module.params["privatekey_path"] - self.privatekey_content = module.params["privatekey_content"] - if self.privatekey_content is not None: - self.privatekey_content = self.privatekey_content.encode("utf-8") - self.privatekey_passphrase = module.params["privatekey_passphrase"] + self.module = module + self.privatekey_path: str | None = module.params["privatekey_path"] + privatekey_content: str | None = module.params["privatekey_content"] + if privatekey_content is not None: + self.privatekey_content: bytes | None = privatekey_content.encode("utf-8") + else: + self.privatekey_content = None + self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"] - def generate(self): + def generate(self, module: AnsibleModule) -> None: # Empty method because OpenSSLObject wants this pass - def dump(self): + def dump(self) -> dict[str, t.Any]: # Empty method because OpenSSLObject wants this - pass + return {} # Implementation with using cryptography class SignatureCryptography(SignatureBase): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(SignatureCryptography, self).__init__(module) - def run(self): + def run(self) -> dict[str, t.Any]: _padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15() _hash = cryptography.hazmat.primitives.hashes.SHA256() - result = dict() + result: dict[str, t.Any] = {} try: with open(self.path, "rb") as f: @@ -223,7 +227,7 @@ class SignatureCryptography(SignatureBase): raise OpenSSLObjectError(e) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( privatekey_path=dict(type="path"), diff --git a/plugins/modules/openssl_signature_info.py b/plugins/modules/openssl_signature_info.py index 1ac66899..1a0ccf4d 100644 --- a/plugins/modules/openssl_signature_info.py +++ b/plugins/modules/openssl_signature_info.py @@ -88,6 +88,7 @@ valid: import base64 import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, @@ -121,7 +122,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im class SignatureInfoBase(OpenSSLObject): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(SignatureInfoBase, self).__init__( path=module.params["path"], state="present", @@ -129,32 +130,35 @@ class SignatureInfoBase(OpenSSLObject): check_mode=module.check_mode, ) - self.signature = module.params["signature"] - self.certificate_path = module.params["certificate_path"] - self.certificate_content = module.params["certificate_content"] - if self.certificate_content is not None: - self.certificate_content = self.certificate_content.encode("utf-8") + self.module = module + self.signature: str = module.params["signature"] + self.certificate_path: str | None = module.params["certificate_path"] + certificate_content: str | None = module.params["certificate_content"] + if certificate_content is not None: + self.certificate_content: bytes | None = certificate_content.encode("utf-8") + else: + self.certificate_content = None - def generate(self): + def generate(self, module: AnsibleModule) -> None: # Empty method because OpenSSLObject wants this pass - def dump(self): + def dump(self) -> dict[str, t.Any]: # Empty method because OpenSSLObject wants this - pass + return {} # Implementation with using cryptography class SignatureInfoCryptography(SignatureInfoBase): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(SignatureInfoCryptography, self).__init__(module) - def run(self): + def run(self) -> dict[str, t.Any]: _padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15() _hash = cryptography.hazmat.primitives.hashes.SHA256() - result = dict() + result: dict[str, t.Any] = {} try: with open(self.path, "rb") as f: @@ -228,7 +232,7 @@ class SignatureInfoCryptography(SignatureInfoBase): raise OpenSSLObjectError(e) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( certificate_path=dict(type="path"), diff --git a/plugins/modules/x509_certificate.py b/plugins/modules/x509_certificate.py index 7c228c09..019a5010 100644 --- a/plugins/modules/x509_certificate.py +++ b/plugins/modules/x509_certificate.py @@ -224,6 +224,7 @@ certificate: import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, @@ -257,8 +258,15 @@ from ansible_collections.community.crypto.plugins.module_utils.io import ( ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import ( + CertificateBackend, + ) + + class CertificateAbsent(OpenSSLObject): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(CertificateAbsent, self).__init__( module.params["path"], module.params["state"], @@ -266,19 +274,19 @@ class CertificateAbsent(OpenSSLObject): module.check_mode, ) self.module = module - self.return_content = module.params["return_content"] - self.backup = module.params["backup"] - self.backup_file = None + self.return_content: bool = module.params["return_content"] + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: pass - def remove(self, module): + def remove(self, module: AnsibleModule) -> None: if self.backup: self.backup_file = module.backup_local(self.path) super(CertificateAbsent, self).remove(module) - def dump(self, check_mode=False): + def dump(self, check_mode: bool = False) -> dict[str, t.Any]: result = { "changed": self.changed, "filename": self.path, @@ -296,7 +304,7 @@ class CertificateAbsent(OpenSSLObject): class GenericCertificate(OpenSSLObject): """Retrieve a certificate using the given module backend.""" - def __init__(self, module, module_backend): + def __init__(self, module: AnsibleModule, module_backend: CertificateBackend): super(GenericCertificate, self).__init__( module.params["path"], module.params["state"], @@ -311,7 +319,7 @@ class GenericCertificate(OpenSSLObject): self.module_backend = module_backend self.module_backend.set_existing(load_file_if_exists(self.path, module)) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: if self.module_backend.needs_regeneration(): if not self.check_mode: self.module_backend.generate_certificate() @@ -329,14 +337,14 @@ class GenericCertificate(OpenSSLObject): file_args, self.changed ) - def check(self, module, perms_required=True): + def check(self, module: AnsibleModule, perms_required: bool = True) -> bool: """Ensure the resource is in its desired state.""" return ( super(GenericCertificate, self).check(module, perms_required) and not self.module_backend.needs_regeneration() ) - def dump(self, check_mode=False): + def dump(self, check_mode: bool = False) -> dict[str, t.Any]: result = self.module_backend.dump(include_certificate=self.return_content) result.update( { @@ -349,7 +357,7 @@ class GenericCertificate(OpenSSLObject): return result -def main(): +def main() -> t.NoReturn: argument_spec = get_certificate_argument_spec() add_acme_provider_to_argument_spec(argument_spec) add_entrust_provider_to_argument_spec(argument_spec) @@ -363,13 +371,14 @@ def main(): return_content=dict(type="bool", default=False), ) ) - argument_spec.required_if.append(["state", "present", ["provider"]]) + argument_spec.required_if.append(("state", "present", ["provider"])) module = argument_spec.create_ansible_module( add_file_common_args=True, supports_check_mode=True, ) try: + certificate: GenericCertificate | CertificateAbsent if module.params["state"] == "absent": certificate = CertificateAbsent(module) @@ -389,7 +398,13 @@ def main(): ) provider = module.params["provider"] - provider_map = { + provider_map: dict[ + str, + type[AcmeCertificateProvider] + | type[EntrustCertificateProvider] + | type[OwnCACertificateProvider] + | type[SelfSignedCertificateProvider], + ] = { "acme": AcmeCertificateProvider, "entrust": EntrustCertificateProvider, "ownca": OwnCACertificateProvider, diff --git a/plugins/modules/x509_certificate_convert.py b/plugins/modules/x509_certificate_convert.py index e030e4cc..9c8d137e 100644 --- a/plugins/modules/x509_certificate_convert.py +++ b/plugins/modules/x509_certificate_convert.py @@ -106,6 +106,7 @@ backup_file: import base64 import os +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes, to_text @@ -142,8 +143,12 @@ except ImportError: pass -def parse_certificate(input, strict=False): - input_format = "pem" if identify_pem_format(input) else "der" +def parse_certificate( + input: bytes, strict: bool = False +) -> tuple[bytes, t.Literal["pem", "der"], str | None]: + input_format: t.Literal["pem", "der"] = ( + "pem" if identify_pem_format(input) else "der" + ) if input_format == "pem": pems = split_pem_list(to_text(input)) if len(pems) > 1 and strict: @@ -162,7 +167,7 @@ def parse_certificate(input, strict=False): class X509CertificateConvertModule(OpenSSLObject): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(X509CertificateConvertModule, self).__init__( module.params["dest_path"], "present", @@ -170,9 +175,9 @@ class X509CertificateConvertModule(OpenSSLObject): module.check_mode, ) - self.src_path = module.params["src_path"] - self.src_content = module.params["src_content"] - self.src_content_base64 = module.params["src_content_base64"] + self.src_path: str | None = module.params["src_path"] + self.src_content: str | None = module.params["src_content"] + self.src_content_base64: bool = module.params["src_content_base64"] if self.src_content is not None: self.input = to_bytes(self.src_content) if self.src_content_base64: @@ -181,6 +186,8 @@ class X509CertificateConvertModule(OpenSSLObject): except Exception as exc: module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}") else: + if self.src_path is None: + module.fail_json(msg="One of src_path and src_content must be provided") try: with open(self.src_path, "rb") as f: self.input = f.read() @@ -189,8 +196,8 @@ class X509CertificateConvertModule(OpenSSLObject): msg=f"Failure while reading file {self.src_path}: {exc}" ) - self.format = module.params["format"] - self.strict = module.params["strict"] + self.format: t.Literal["pem", "der"] = module.params["format"] + self.strict: bool = module.params["strict"] self.wanted_pem_type = "CERTIFICATE" try: @@ -203,8 +210,8 @@ class X509CertificateConvertModule(OpenSSLObject): if module.params["verify_cert_parsable"]: self.verify_cert_parsable(module) - self.backup = module.params["backup"] - self.backup_file = None + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None module.params["path"] = self.path @@ -221,7 +228,7 @@ class X509CertificateConvertModule(OpenSSLObject): except Exception: pass - def verify_cert_parsable(self, module): + def verify_cert_parsable(self, module: AnsibleModule) -> None: assert_required_cryptography_version( module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION ) @@ -230,7 +237,7 @@ class X509CertificateConvertModule(OpenSSLObject): except Exception as exc: module.fail_json(msg=f"Error while parsing certificate: {exc}") - def needs_conversion(self): + def needs_conversion(self) -> bool: if self.dest_content is None or self.dest_content_format is None: return True if self.dest_content_format != self.format: @@ -241,7 +248,7 @@ class X509CertificateConvertModule(OpenSSLObject): return True return False - def get_dest_certificate(self): + def get_dest_certificate(self) -> bytes: if self.format == "der": return self.input data = to_bytes(base64.b64encode(self.input)) @@ -250,7 +257,7 @@ class X509CertificateConvertModule(OpenSSLObject): lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n")) return b"\n".join(lines) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: """Do conversion.""" if self.needs_conversion(): # Convert @@ -269,18 +276,18 @@ class X509CertificateConvertModule(OpenSSLObject): file_args, self.changed ) - def dump(self): + def dump(self) -> dict[str, t.Any]: """Serialize the object into a dictionary.""" - result = dict( - changed=self.changed, - ) + result: dict[str, t.Any] = { + "changed": self.changed, + } if self.backup_file: result["backup_file"] = self.backup_file return result -def main(): +def main() -> t.NoReturn: argument_spec = dict( src_path=dict(type="path"), src_content=dict(type="str"), diff --git a/plugins/modules/x509_certificate_info.py b/plugins/modules/x509_certificate_info.py index 6b417e8c..0ed1b442 100644 --- a/plugins/modules/x509_certificate_info.py +++ b/plugins/modules/x509_certificate_info.py @@ -390,8 +390,10 @@ issuer_uri: version_added: 2.9.0 """ +import typing as t from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.common.text.converters import to_text from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, ) @@ -406,7 +408,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( path=dict(type="path"), @@ -424,18 +426,22 @@ def main(): supports_check_mode=True, ) - if module.params["content"] is not None: - data = module.params["content"].encode("utf-8") + content: str | None = module.params["content"] + path: str | None = module.params["path"] + if content is not None: + data = content.encode("utf-8") else: + if path is None: + module.fail_json(msg="One of path and content must be provided") try: - with open(module.params["path"], "rb") as f: + with open(path, "rb") as f: data = f.read() except (IOError, OSError) as e: module.fail_json(msg=f"Error while reading certificate file from disk: {e}") module_backend = select_backend(module, data) - valid_at = module.params["valid_at"] + valid_at: dict[str, t.Any] = module.params["valid_at"] if valid_at: for k, v in valid_at.items(): if not isinstance(v, (str, bytes)): @@ -443,13 +449,11 @@ def main(): msg=f"The value for valid_at.{k} must be of type string (got {type(v)})" ) valid_at[k] = get_relative_time_option( - v, f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE + to_text(v), f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE ) try: - result = module_backend.get_info( - der_support_enabled=module.params["content"] is None - ) + result = module_backend.get_info(der_support_enabled=content is None) not_before = module_backend.get_not_before() not_after = module_backend.get_not_after() diff --git a/plugins/modules/x509_certificate_pipe.py b/plugins/modules/x509_certificate_pipe.py index 07bad567..f47951c3 100644 --- a/plugins/modules/x509_certificate_pipe.py +++ b/plugins/modules/x509_certificate_pipe.py @@ -118,6 +118,8 @@ certificate: type: str """ +import typing as t + from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( OpenSSLObjectError, ) @@ -139,23 +141,31 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac ) +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import ( + CertificateBackend, + ) + + class GenericCertificate: """Retrieve a certificate using the given module backend.""" - def __init__(self, module, module_backend): + def __init__(self, module: AnsibleModule, module_backend: CertificateBackend): self.check_mode = module.check_mode self.module = module self.module_backend = module_backend self.changed = False - if module.params["content"] is not None: - self.module_backend.set_existing(module.params["content"].encode("utf-8")) + content: str | None = module.params["content"] + if content is not None: + self.module_backend.set_existing(content.encode("utf-8")) - def generate(self, module): + def generate(self, module: AnsibleModule) -> None: if self.module_backend.needs_regeneration(): self.module_backend.generate_certificate() self.changed = True - def dump(self, check_mode=False): + def dump(self, check_mode: bool = False) -> dict[str, t.Any]: result = self.module_backend.dump(include_certificate=True) result.update( { @@ -165,7 +175,7 @@ class GenericCertificate: return result -def main(): +def main() -> t.NoReturn: argument_spec = get_certificate_argument_spec() argument_spec.argument_spec["provider"]["required"] = True add_entrust_provider_to_argument_spec(argument_spec) @@ -182,7 +192,12 @@ def main(): try: provider = module.params["provider"] - provider_map = { + provider_map: dict[ + str, + type[EntrustCertificateProvider] + | type[OwnCACertificateProvider] + | type[SelfSignedCertificateProvider], + ] = { "entrust": EntrustCertificateProvider, "ownca": OwnCACertificateProvider, "selfsigned": SelfSignedCertificateProvider, diff --git a/plugins/modules/x509_crl.py b/plugins/modules/x509_crl.py index a2b542d0..8b64c961 100644 --- a/plugins/modules/x509_crl.py +++ b/plugins/modules/x509_crl.py @@ -425,6 +425,7 @@ crl: import base64 import os +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_text @@ -463,7 +464,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( OpenSSLObject, load_certificate, - load_privatekey, + load_certificate_issuer_privatekey, parse_name_field, parse_ordered_name_field, select_message_digest, @@ -495,6 +496,9 @@ try: except ImportError: pass +if t.TYPE_CHECKING: + import datetime + class CRLError(OpenSSLObjectError): pass @@ -502,7 +506,7 @@ class CRLError(OpenSSLObjectError): class CRL(OpenSSLObject): - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: super(CRL, self).__init__( module.params["path"], module.params["state"], @@ -510,53 +514,69 @@ class CRL(OpenSSLObject): module.check_mode, ) - self.format = module.params["format"] + self.format: t.Literal["pem", "der"] = module.params["format"] - self.update = module.params["crl_mode"] == "update" - self.ignore_timestamps = module.params["ignore_timestamps"] - self.return_content = module.params["return_content"] - self.name_encoding = module.params["name_encoding"] - self.serial_numbers_format = module.params["serial_numbers"] - self.crl_content = None + self.update: bool = module.params["crl_mode"] == "update" + self.ignore_timestamps: bool = module.params["ignore_timestamps"] + self.return_content: bool = module.params["return_content"] + self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params[ + "name_encoding" + ] + self.serial_numbers_format: t.Literal["integer", "hex-octets"] = module.params[ + "serial_numbers" + ] + self.crl_content: bytes | None = None - self.privatekey_path = module.params["privatekey_path"] - self.privatekey_content = module.params["privatekey_content"] - if self.privatekey_content is not None: - self.privatekey_content = self.privatekey_content.encode("utf-8") - self.privatekey_passphrase = module.params["privatekey_passphrase"] + self.privatekey_path: str | None = module.params["privatekey_path"] + privatekey_content: str | None = module.params["privatekey_content"] + if privatekey_content is not None: + self.privatekey_content: bytes | None = privatekey_content.encode("utf-8") + else: + self.privatekey_content = None + self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"] try: - if module.params["issuer_ordered"]: + issuer_ordered: list[dict[str, t.Any]] | None = module.params[ + "issuer_ordered" + ] + issuer: dict[str, list[str | bytes] | bytes | str] | None = module.params[ + "issuer" + ] + if issuer_ordered: self.issuer_ordered = True - self.issuer = parse_ordered_name_field( - module.params["issuer_ordered"], "issuer_ordered" - ) + self.issuer = parse_ordered_name_field(issuer_ordered, "issuer_ordered") else: self.issuer_ordered = False - self.issuer = parse_name_field(module.params["issuer"], "issuer") + self.issuer = ( + parse_name_field(issuer, "issuer") if issuer is not None else [] + ) except (TypeError, ValueError) as exc: module.fail_json(msg=str(exc)) - self.last_update = get_relative_time_option( + self.last_update: datetime.datetime = get_relative_time_option( module.params["last_update"], "last_update", with_timezone=CRYPTOGRAPHY_TIMEZONE, ) - self.next_update = get_relative_time_option( + self.next_update: datetime.datetime | None = get_relative_time_option( module.params["next_update"], "next_update", with_timezone=CRYPTOGRAPHY_TIMEZONE, ) - self.digest = select_message_digest(module.params["digest"]) - if self.digest is None: + digest = select_message_digest(module.params["digest"]) + if digest is None: raise CRLError(f'The digest "{module.params["digest"]}" is not supported') + self.digest = digest self.module = module self.revoked_certificates = [] - for i, rc in enumerate(module.params["revoked_certificates"]): - result = { + revoked_certificates: list[dict[str, t.Any]] = module.params[ + "revoked_certificates" + ] + for i, rc in enumerate(revoked_certificates): + result: dict[str, t.Any] = { "serial_number": None, "revocation_date": None, "issuer": None, @@ -567,21 +587,25 @@ class CRL(OpenSSLObject): "invalidity_date_critical": False, } path_prefix = f"revoked_certificates[{i}]." - if rc["path"] is not None or rc["content"] is not None: + path: str | None = rc["path"] + content_str: str | None = rc["content"] + if path is not None or content_str is not None: # Load certificate from file or content try: - if rc["content"] is not None: - rc["content"] = rc["content"].encode("utf-8") - cert = load_certificate(rc["path"], content=rc["content"]) + content: bytes | None = None + if content_str is not None: + content = content_str.encode("utf-8") + rc["content"] = content + cert = load_certificate(path, content=content) result["serial_number"] = cert.serial_number except OpenSSLObjectError as e: - if rc["content"] is not None: + if content_str is not None: module.fail_json( msg=f"Cannot parse certificate from {path_prefix}content: {e}" ) else: module.fail_json( - msg=f'Cannot read certificate "{rc["path"]}" from {path_prefix}path: {e}' + msg=f'Cannot read certificate "{path}" from {path_prefix}path: {e}' ) else: # Specify serial_number (and potentially issuer) directly @@ -611,11 +635,11 @@ class CRL(OpenSSLObject): result["invalidity_date_critical"] = rc["invalidity_date_critical"] self.revoked_certificates.append(result) - self.backup = module.params["backup"] - self.backup_file = None + self.backup: bool = module.params["backup"] + self.backup_file: str | None = None try: - self.privatekey = load_privatekey( + self.privatekey = load_certificate_issuer_privatekey( path=self.privatekey_path, content=self.privatekey_content, passphrase=self.privatekey_passphrase, @@ -643,7 +667,7 @@ class CRL(OpenSSLObject): self.diff_after = self.diff_before = self._get_info(data) - def _parse_serial_number(self, value, index): + def _parse_serial_number(self, value: t.Any, index: int) -> int: if self.serial_numbers_format == "integer": try: return check_type_int(value) @@ -662,22 +686,42 @@ class CRL(OpenSSLObject): f"Unexpected value {self.serial_numbers_format} of serial_numbers" ) - def _get_info(self, data): + def _get_info(self, data: bytes | None) -> dict[str, t.Any]: if data is None: - return dict() + return {} try: result = get_crl_info(self.module, data) result["can_parse_crl"] = True return result except Exception: - return dict(can_parse_crl=False) + return {"can_parse_crl": False} - def remove(self): + def remove(self, module: AnsibleModule) -> None: if self.backup: self.backup_file = self.module.backup_local(self.path) super(CRL, self).remove(self.module) - def _compress_entry(self, entry): + def _compress_entry(self, entry: dict[str, t.Any]) -> ( + tuple[ + int | None, + tuple[str, ...] | None, + bool, + int | None, + bool, + datetime.datetime | None, + bool, + ] + | tuple[ + int | None, + datetime.datetime | None, + tuple[str, ...] | None, + bool, + int | None, + bool, + datetime.datetime | None, + bool, + ] + ): issuer = None if entry["issuer"] is not None: # Normalize to IDNA. If this is used-provided, it was already converted to @@ -713,7 +757,12 @@ class CRL(OpenSSLObject): entry["invalidity_date_critical"], ) - def check(self, module, perms_required=True, ignore_conversion=True): + def check( + self, + module: AnsibleModule, + perms_required: bool = True, + ignore_conversion: bool = True, + ) -> bool: """Ensure the resource is in its desired state.""" state_and_perms = super(CRL, self).check(self.module, perms_required) @@ -743,10 +792,13 @@ class CRL(OpenSSLObject): ] is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer] if not self.issuer_ordered: - want_issuer = set(want_issuer) - is_issuer = set(is_issuer) - if want_issuer != is_issuer: - return False + want_issuer_set = set(want_issuer) + is_issuer_set = set(is_issuer) + if want_issuer_set != is_issuer_set: + return False + else: + if want_issuer != is_issuer: + return False old_entries = [ self._compress_entry(cryptography_decode_revoked_certificate(cert)) @@ -769,7 +821,7 @@ class CRL(OpenSSLObject): return True - def _generate_crl(self): + def _generate_crl(self) -> bytes: crl = CertificateRevocationListBuilder() try: @@ -787,7 +839,8 @@ class CRL(OpenSSLObject): raise CRLError(e) crl = set_last_update(crl, self.last_update) - crl = set_next_update(crl, self.next_update) + if self.next_update is not None: + crl = set_next_update(crl, self.next_update) if self.update and self.crl: new_entries = set( @@ -799,22 +852,26 @@ class CRL(OpenSSLObject): ) if decoded_entry not in new_entries: crl = crl.add_revoked_certificate(entry) - for entry in self.revoked_certificates: + for revoked_entry in self.revoked_certificates: revoked_cert = RevokedCertificateBuilder() - revoked_cert = revoked_cert.serial_number(entry["serial_number"]) - revoked_cert = set_revocation_date(revoked_cert, entry["revocation_date"]) - if entry["issuer"] is not None: + revoked_cert = revoked_cert.serial_number(revoked_entry["serial_number"]) + revoked_cert = set_revocation_date( + revoked_cert, revoked_entry["revocation_date"] + ) + if revoked_entry["issuer"] is not None: revoked_cert = revoked_cert.add_extension( - x509.CertificateIssuer(entry["issuer"]), entry["issuer_critical"] + x509.CertificateIssuer(revoked_entry["issuer"]), + revoked_entry["issuer_critical"], ) - if entry["reason"] is not None: + if revoked_entry["reason"] is not None: revoked_cert = revoked_cert.add_extension( - x509.CRLReason(entry["reason"]), entry["reason_critical"] + x509.CRLReason(revoked_entry["reason"]), + revoked_entry["reason_critical"], ) - if entry["invalidity_date"] is not None: + if revoked_entry["invalidity_date"] is not None: revoked_cert = revoked_cert.add_extension( - x509.InvalidityDate(entry["invalidity_date"]), - entry["invalidity_date_critical"], + x509.InvalidityDate(revoked_entry["invalidity_date"]), + revoked_entry["invalidity_date_critical"], ) crl = crl.add_revoked_certificate(revoked_cert.build()) @@ -827,7 +884,7 @@ class CRL(OpenSSLObject): else: return self.crl.public_bytes(Encoding.DER) - def generate(self): + def generate(self, module: AnsibleModule) -> None: result = None if ( not self.check(self.module, perms_required=False, ignore_conversion=True) @@ -861,7 +918,7 @@ class CRL(OpenSSLObject): elif self.module.set_fs_attributes_if_different(file_args, False): self.changed = True - def dump(self, check_mode=False): + def dump(self, check_mode: bool = False) -> dict[str, t.Any]: result = { "changed": self.changed, "filename": self.path, @@ -879,37 +936,53 @@ class CRL(OpenSSLObject): if check_mode: result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT) - result["next_update"] = self.next_update.strftime(TIMESTAMP_FORMAT) + result["next_update"] = ( + self.next_update.strftime(TIMESTAMP_FORMAT) + if self.next_update is not None + else None + ) # result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid) result["digest"] = self.module.params["digest"] result["issuer_ordered"] = self.issuer - result["issuer"] = {} + issuer: dict[str, str | bytes] = {} + result["issuer"] = issuer for k, v in self.issuer: - result["issuer"][k] = v - result["revoked_certificates"] = [] + issuer[k] = v + revoked_certificates: list[dict[str, t.Any]] = [] + result["revoked_certificates"] = revoked_certificates for entry in self.revoked_certificates: - result["revoked_certificates"].append( + revoked_certificates.append( cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding) ) elif self.crl: result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT) - result["next_update"] = get_next_update(self.crl).strftime(TIMESTAMP_FORMAT) + next_update = get_next_update(self.crl) + result["next_update"] = ( + next_update.strftime(TIMESTAMP_FORMAT) + if next_update is not None + else None + ) result["digest"] = cryptography_oid_to_name( cryptography_get_signature_algorithm_oid_from_crl(self.crl) ) - issuer = [] + issuer_list: list[list[str]] = [] for attribute in self.crl.issuer: - issuer.append( - [cryptography_oid_to_name(attribute.oid), attribute.value] + issuer_list.append( + [ + cryptography_oid_to_name(attribute.oid), + to_text(attribute.value), + ] ) - result["issuer_ordered"] = issuer - result["issuer"] = {} - for k, v in issuer: - result["issuer"][k] = v - result["revoked_certificates"] = [] + result["issuer_ordered"] = issuer_list + issuer = {} + result["issuer"] = issuer + for k, v in issuer_list: + issuer[k] = v + revoked_certificates = [] + result["revoked_certificates"] = revoked_certificates for cert in self.crl: entry = cryptography_decode_revoked_certificate(cert) - result["revoked_certificates"].append( + revoked_certificates.append( cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding) ) @@ -923,7 +996,7 @@ class CRL(OpenSSLObject): return result -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( state=dict(type="str", default="present", choices=["present", "absent"]), @@ -1015,14 +1088,14 @@ def main(): ) module.exit_json(**result) - crl.generate() + crl.generate(module) else: if module.check_mode: result = crl.dump(check_mode=True) result["changed"] = os.path.exists(module.params["path"]) module.exit_json(**result) - crl.remove() + crl.remove(module) result = crl.dump() module.exit_json(**result) diff --git a/plugins/modules/x509_crl_info.py b/plugins/modules/x509_crl_info.py index a9a43fde..f58918c1 100644 --- a/plugins/modules/x509_crl_info.py +++ b/plugins/modules/x509_crl_info.py @@ -100,7 +100,9 @@ last_update: type: str sample: '20190413202428Z' next_update: - description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. + description: + - The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. + - Will be C(none) if no such timestamp is present. returned: success type: str sample: '20190413202428Z' @@ -172,6 +174,7 @@ revoked_certificates: import base64 import binascii +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( @@ -185,7 +188,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ) -def main(): +def main() -> t.NoReturn: module = AnsibleModule( argument_spec=dict( path=dict(type="path"), @@ -200,25 +203,30 @@ def main(): supports_check_mode=True, ) - if module.params["content"] is None: + content: str | None = module.params["content"] + path: str | None = module.params["path"] + if content is None: + if path is None: + module.fail_json(msg="One of content and path must be provided") try: - with open(module.params["path"], "rb") as f: + with open(path, "rb") as f: data = f.read() except (IOError, OSError) as e: module.fail_json(msg=f"Error while reading CRL file from disk: {e}") else: - data = module.params["content"].encode("utf-8") + data = content.encode("utf-8") if not identify_pem_format(data): try: - data = base64.b64decode(module.params["content"]) + data = base64.b64decode(content) except (binascii.Error, TypeError) as e: module.fail_json(msg=f"Error while Base64 decoding content: {e}") + list_revoked_certificates: bool = module.params["list_revoked_certificates"] try: result = get_crl_info( module, data, - list_revoked_certificates=module.params["list_revoked_certificates"], + list_revoked_certificates=list_revoked_certificates, ) module.exit_json(**result) except OpenSSLObjectError as e: diff --git a/plugins/plugin_utils/action_module.py b/plugins/plugin_utils/action_module.py index cc246fe4..9e2b3dd5 100644 --- a/plugins/plugin_utils/action_module.py +++ b/plugins/plugin_utils/action_module.py @@ -15,20 +15,24 @@ from __future__ import annotations import abc import copy import traceback +import typing as t from ansible.errors import AnsibleError from ansible.module_utils.basic import SEQUENCETYPE, remove_values from ansible.module_utils.common._collections_compat import Mapping from ansible.module_utils.common.arg_spec import ArgumentSpecValidator -from ansible.module_utils.common.validation import ( - safe_eval, -) from ansible.module_utils.errors import UnsupportedError from ansible.plugins.action import ActionBase +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.argspec import ( + ArgumentSpec, + ) + + class _ModuleExitException(Exception): - def __init__(self, result): + def __init__(self, result: dict[str, t.Any]) -> None: super(_ModuleExitException, self).__init__() self.result = result @@ -36,20 +40,21 @@ class _ModuleExitException(Exception): class AnsibleActionModule: def __init__( self, - action_plugin, - argument_spec, - bypass_checks=False, - mutually_exclusive=None, - required_together=None, - required_one_of=None, - supports_check_mode=False, - required_if=None, - required_by=None, - ): + action_plugin: ActionModuleBase, + argument_spec: dict[str, t.Any], + *, + bypass_checks: bool = False, + supports_check_mode: bool = False, + mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None, + required_together: list[list[str] | tuple[str, ...]] | None = None, + required_one_of: list[list[str] | tuple[str, ...]] | None = None, + required_if: list[tuple[str, t.Any, list[str] | tuple[str, ...]]] | None = None, + required_by: dict[str, tuple[str, ...] | list[str]] | None = None, + ) -> None: # Internal data self.__action_plugin = action_plugin - self.__warnings = [] - self.__deprecations = [] + self.__warnings: list[str] = [] + self.__deprecations: list[dict[str, str | None]] = [] # AnsibleModule data self._name = self.__action_plugin._task.action @@ -67,10 +72,6 @@ class AnsibleActionModule: self._diff = self.__action_plugin._play_context.diff self._verbosity = self.__action_plugin._display.verbosity - self.aliases = {} - self._legal_inputs = [] - self._options_context = list() - self.params = copy.deepcopy(self.__action_plugin._task.args) self.no_log_values = set() self._validator = ArgumentSpecValidator( @@ -122,38 +123,41 @@ class AnsibleActionModule: self.fail_json(msg=msg) - def safe_eval(self, value, locals=None, include_exceptions=False): - return safe_eval(value, locals, include_exceptions) - - def warn(self, warning): + def warn(self, warning: str) -> None: # Copied from ansible.module_utils.common.warnings: - if isinstance(warning, (str, bytes)): + if isinstance(warning, str): self.__warnings.append(warning) else: raise TypeError(f"warn requires a string not a {type(warning)}") - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: if version is not None and date is not None: raise AssertionError( "implementation error -- version and date must not both be set" ) # Copied from ansible.module_utils.common.warnings: - if isinstance(msg, (str, bytes)): - # For compatibility, we accept that neither version nor date is set, - # and treat that the same as if version would haven been set - if date is not None: - self.__deprecations.append( - {"msg": msg, "date": date, "collection_name": collection_name} - ) - else: - self.__deprecations.append( - {"msg": msg, "version": version, "collection_name": collection_name} - ) - else: + if not isinstance(msg, str): raise TypeError(f"deprecate requires a string not a {type(msg)}") - def _return_formatted(self, kwargs): + # For compatibility, we accept that neither version nor date is set, + # and treat that the same as if version would haven been set + if date is not None: + self.__deprecations.append( + {"msg": msg, "date": date, "collection_name": collection_name} + ) + else: + self.__deprecations.append( + {"msg": msg, "version": version, "collection_name": collection_name} + ) + + def _return_formatted(self, kwargs: dict[str, t.Any]) -> t.NoReturn: if "invocation" not in kwargs: kwargs["invocation"] = {"module_args": self.params} @@ -194,13 +198,13 @@ class AnsibleActionModule: kwargs = remove_values(kwargs, self.no_log_values) raise _ModuleExitException(kwargs) - def exit_json(self, **kwargs): + def exit_json(self, **kwargs) -> t.NoReturn: result = dict(kwargs) if "failed" not in result: result["failed"] = False self._return_formatted(result) - def fail_json(self, msg, **kwargs): + def fail_json(self, msg: str, **kwargs) -> t.NoReturn: result = dict(kwargs) result["failed"] = True result["msg"] = msg @@ -209,16 +213,15 @@ class AnsibleActionModule: class ActionModuleBase(ActionBase, metaclass=abc.ABCMeta): @abc.abstractmethod - def setup_module(self): + def setup_module(self) -> tuple[ArgumentSpec, dict[str, t.Any]]: """Return pair (ArgumentSpec, kwargs).""" - pass @abc.abstractmethod - def run_module(self, module): + def run_module(self, module: AnsibleActionModule) -> None: """Run module code""" module.fail_json(msg="Not implemented.") - def run(self, tmp=None, task_vars=None): + def run(self, tmp=None, task_vars=None) -> dict[str, t.Any]: if task_vars is None: task_vars = dict() diff --git a/plugins/plugin_utils/filter_module.py b/plugins/plugin_utils/filter_module.py index 9501590b..2b040e91 100644 --- a/plugins/plugin_utils/filter_module.py +++ b/plugins/plugin_utils/filter_module.py @@ -6,14 +6,23 @@ from __future__ import annotations +import typing as t + from ansible.errors import AnsibleFilterError +from ansible.utils.display import Display + + +_display = Display() class FilterModuleMock: - def __init__(self, params): + def __init__(self, params: dict[str, t.Any]) -> None: self.check_mode = True self.params = params self._diff = False - def fail_json(self, msg, **kwargs): + def fail_json(self, msg: str, **kwargs) -> t.NoReturn: raise AnsibleFilterError(msg) + + def warn(self, warning: str) -> None: + _display.warning(warning) diff --git a/plugins/plugin_utils/gnupg.py b/plugins/plugin_utils/gnupg.py index 1ac4cd86..a7d76f60 100644 --- a/plugins/plugin_utils/gnupg.py +++ b/plugins/plugin_utils/gnupg.py @@ -4,6 +4,7 @@ from __future__ import annotations +import typing as t from subprocess import PIPE, Popen from ansible.module_utils.common.process import get_bin_path @@ -15,7 +16,7 @@ from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import class PluginGPGRunner(GPGRunner): - def __init__(self, executable=None, cwd=None): + def __init__(self, executable: str | None = None, cwd: str | None = None) -> None: if executable is None: try: executable = get_bin_path("gpg") @@ -24,7 +25,9 @@ class PluginGPGRunner(GPGRunner): self.executable = executable self.cwd = cwd - def run_command(self, command, check_rc=True, data=None): + def run_command( + self, command: list[str], check_rc: bool = True, data: bytes | None = None + ) -> tuple[int, str, str]: """ Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``. @@ -41,12 +44,10 @@ class PluginGPGRunner(GPGRunner): command, shell=False, cwd=self.cwd, stdin=PIPE, stdout=PIPE, stderr=PIPE ) stdout, stderr = p.communicate(input=data) - stdout = to_native(stdout, errors="surrogate_or_replace") - stderr = to_native(stderr, errors="surrogate_or_replace") + stdout_n = to_native(stdout, errors="surrogate_or_replace") + stderr_n = to_native(stderr, errors="surrogate_or_replace") if check_rc and p.returncode != 0: - stdout_n = (to_native(stdout, errors="surrogate_or_replace"),) - stderr_n = (to_native(stderr, errors="surrogate_or_replace"),) raise GPGError( f'Running {" ".join(command)} yielded return code {p.returncode} with stdout: "{stdout_n}" and stderr: "{stderr_n}")' ) - return p.returncode, stdout, stderr + return t.cast(int, p.returncode), stdout_n, stderr_n diff --git a/tests/nox-config-flake8.ini b/tests/nox-config-flake8.ini index fa6ddf30..f483c0e6 100644 --- a/tests/nox-config-flake8.ini +++ b/tests/nox-config-flake8.ini @@ -6,7 +6,7 @@ extend-ignore = E203, E402, F401 count = true # TODO: decrease this to ~10 -max-complexity = 48 +max-complexity = 60 # black's max-line-length is 89, but it doesn't touch long string literals. # Since ansible-test's limit is 160, let's use that here. max-line-length = 160 diff --git a/tests/nox-config-mypy.ini b/tests/nox-config-mypy.ini new file mode 100644 index 00000000..7bcf5f72 --- /dev/null +++ b/tests/nox-config-mypy.ini @@ -0,0 +1,19 @@ +# Copyright (c) Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +[mypy] +# check_untyped_defs = True +# disallow_untyped_defs = True -- not yet feasible + +# strict = True -- only try to enable once everything is typed +strict_equality = True + +[mypy-ansible.*] +# ansible-core has no typing information +# ignore_missing_imports = True +follow_untyped_imports = True + +[mypy-ansible_collections.community.internal_test_tools.*] +# community.internal_test_tools has no typing information +ignore_missing_imports = True diff --git a/tests/sanity/ignore-2.17.txt b/tests/sanity/ignore-2.17.txt index 9ffe1e99..6a8af3b7 100644 --- a/tests/sanity/ignore-2.17.txt +++ b/tests/sanity/ignore-2.17.txt @@ -1,2 +1,27 @@ +plugins/module_utils/acme/account.py pep8:E704 +plugins/module_utils/acme/acme.py pep8:E704 +plugins/module_utils/acme/acme.py pylint:unpacking-non-sequence +plugins/module_utils/acme/backend_openssl_cli.py pep8:E704 +plugins/module_utils/acme/certificate.py pep8:E704 +plugins/module_utils/crypto/cryptography_support.py pep8:E704 +plugins/module_utils/crypto/module_backends/certificate.py no-assert +plugins/module_utils/crypto/module_backends/certificate_entrust.py no-assert +plugins/module_utils/crypto/module_backends/certificate_ownca.py no-assert +plugins/module_utils/crypto/module_backends/certificate_selfsigned.py no-assert +plugins/module_utils/crypto/module_backends/csr.py no-assert +plugins/module_utils/crypto/module_backends/privatekey_convert.py no-assert +plugins/module_utils/crypto/support.py pep8:E704 +plugins/module_utils/openssh/backends/keypair_backend.py no-assert +plugins/module_utils/openssh/certificate.py pep8:E704 +plugins/modules/acme_account.py pylint:unpacking-non-sequence +plugins/modules/acme_account_info.py pylint:unpacking-non-sequence +plugins/modules/acme_certificate.py pylint:unpacking-non-sequence +plugins/modules/acme_certificate.py no-assert +plugins/modules/acme_certificate_deactivate_authz.py pylint:unpacking-non-sequence +plugins/modules/acme_certificate_order_finalize.py pylint:unpacking-non-sequence +plugins/modules/acme_certificate_revoke.py pylint:unpacking-non-sequence +plugins/modules/acme_inspect.py pylint:unpacking-non-sequence +plugins/modules/luks_device.py no-assert +plugins/modules/openssl_pkcs12.py no-assert tests/ee/roles/smoke/library/smoke_ipaddress.py shebang tests/ee/roles/smoke/library/smoke_pyyaml.py shebang diff --git a/tests/sanity/ignore-2.18.txt b/tests/sanity/ignore-2.18.txt index 9ffe1e99..0cc09ff5 100644 --- a/tests/sanity/ignore-2.18.txt +++ b/tests/sanity/ignore-2.18.txt @@ -1,2 +1,19 @@ +plugins/module_utils/acme/account.py pep8:E704 +plugins/module_utils/acme/acme.py pep8:E704 +plugins/module_utils/acme/backend_openssl_cli.py pep8:E704 +plugins/module_utils/acme/certificate.py pep8:E704 +plugins/module_utils/crypto/cryptography_support.py pep8:E704 +plugins/module_utils/crypto/module_backends/certificate.py no-assert +plugins/module_utils/crypto/module_backends/certificate_entrust.py no-assert +plugins/module_utils/crypto/module_backends/certificate_ownca.py no-assert +plugins/module_utils/crypto/module_backends/certificate_selfsigned.py no-assert +plugins/module_utils/crypto/module_backends/csr.py no-assert +plugins/module_utils/crypto/module_backends/privatekey_convert.py no-assert +plugins/module_utils/crypto/support.py pep8:E704 +plugins/module_utils/openssh/backends/keypair_backend.py no-assert +plugins/module_utils/openssh/certificate.py pep8:E704 +plugins/modules/acme_certificate.py no-assert +plugins/modules/luks_device.py no-assert +plugins/modules/openssl_pkcs12.py no-assert tests/ee/roles/smoke/library/smoke_ipaddress.py shebang tests/ee/roles/smoke/library/smoke_pyyaml.py shebang diff --git a/tests/sanity/ignore-2.19.txt b/tests/sanity/ignore-2.19.txt index 9ffe1e99..c3ea832d 100644 --- a/tests/sanity/ignore-2.19.txt +++ b/tests/sanity/ignore-2.19.txt @@ -1,2 +1,12 @@ +plugins/module_utils/crypto/module_backends/certificate.py no-assert +plugins/module_utils/crypto/module_backends/certificate_entrust.py no-assert +plugins/module_utils/crypto/module_backends/certificate_ownca.py no-assert +plugins/module_utils/crypto/module_backends/certificate_selfsigned.py no-assert +plugins/module_utils/crypto/module_backends/csr.py no-assert +plugins/module_utils/crypto/module_backends/privatekey_convert.py no-assert +plugins/module_utils/openssh/backends/keypair_backend.py no-assert +plugins/modules/acme_certificate.py no-assert +plugins/modules/luks_device.py no-assert +plugins/modules/openssl_pkcs12.py no-assert tests/ee/roles/smoke/library/smoke_ipaddress.py shebang tests/ee/roles/smoke/library/smoke_pyyaml.py shebang diff --git a/tests/unit/plugins/module_utils/acme/backend_data.py b/tests/unit/plugins/module_utils/acme/backend_data.py index 4acad56b..de58022a 100644 --- a/tests/unit/plugins/module_utils/acme/backend_data.py +++ b/tests/unit/plugins/module_utils/acme/backend_data.py @@ -7,6 +7,7 @@ from __future__ import annotations import base64 import datetime import os +import typing as t from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( CertificateInformation, @@ -19,12 +20,20 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor from ..test_time import TIMEZONES, cartesian_product -def load_fixture(name): - with open(os.path.join(os.path.dirname(__file__), "fixtures", name)) as f: +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( + Criterium, + ) + + +def load_fixture(name: str) -> str: + with open( + os.path.join(os.path.dirname(__file__), "fixtures", name), encoding="utf-8" + ) as f: return f.read() -TEST_PEM_DERS = [ +TEST_PEM_DERS: list[tuple[str, bytes]] = [ ( load_fixture("privatekey_1.pem"), base64.b64decode( @@ -36,7 +45,7 @@ TEST_PEM_DERS = [ ] -TEST_KEYS = [ +TEST_KEYS: list[tuple[str, dict[str, t.Any], str]] = [ ( load_fixture("privatekey_1.pem"), { @@ -56,7 +65,7 @@ TEST_KEYS = [ ] -TEST_CSRS = [ +TEST_CSRS: list[tuple[str, set[tuple[str, str]], str]] = [ ( load_fixture("csr_1.pem"), set([("dns", "ansible.com"), ("dns", "example.com"), ("dns", "example.org")]), @@ -87,17 +96,19 @@ TEST_CERT_OPENSSL_OUTPUT_2 = load_fixture("cert_2.txt") # OpenSSL 3.3.0 output TEST_CERT_OPENSSL_OUTPUT_2B = load_fixture("cert_2-b.txt") # OpenSSL 1.1.1f output -TEST_CERT_DAYS = cartesian_product( - TIMEZONES, - [ - (datetime.datetime(2018, 11, 15, 1, 2, 3), 11), - (datetime.datetime(2018, 11, 25, 15, 20, 0), 1), - (datetime.datetime(2018, 11, 25, 15, 30, 0), 0), - ], +TEST_CERT_DAYS: list[tuple[datetime.timedelta, datetime.datetime, int]] = ( + cartesian_product( + TIMEZONES, + [ + (datetime.datetime(2018, 11, 15, 1, 2, 3), 11), + (datetime.datetime(2018, 11, 25, 15, 20, 0), 1), + (datetime.datetime(2018, 11, 25, 15, 30, 0), 0), + ], + ) ) -TEST_CERT_INFO = CertificateInformation( +TEST_CERT_INFO_1 = CertificateInformation( not_valid_after=datetime.datetime(2018, 11, 26, 15, 28, 24), not_valid_before=datetime.datetime(2018, 11, 25, 15, 28, 23), serial_number=1, @@ -115,65 +126,69 @@ TEST_CERT_INFO_2 = CertificateInformation( ) -TEST_CERT_INFO = [ - (TEST_CERT, TEST_CERT_INFO, TEST_CERT_OPENSSL_OUTPUT), +TEST_CERT_INFO: list[tuple[str, CertificateInformation, str]] = [ + (TEST_CERT, TEST_CERT_INFO_1, TEST_CERT_OPENSSL_OUTPUT), (TEST_CERT_2, TEST_CERT_INFO_2, TEST_CERT_OPENSSL_OUTPUT_2), (TEST_CERT_2, TEST_CERT_INFO_2, TEST_CERT_OPENSSL_OUTPUT_2B), ] -TEST_PARSE_ACME_TIMESTAMP = cartesian_product( - TIMEZONES, - [ - ( - "2024-01-01T00:11:22Z", - dict(year=2024, month=1, day=1, hour=0, minute=11, second=22), - ), - ( - "2024-01-01T00:11:22.123Z", - dict( - year=2024, - month=1, - day=1, - hour=0, - minute=11, - second=22, - microsecond=123000, +TEST_PARSE_ACME_TIMESTAMP: list[tuple[datetime.timedelta, str, dict[str, int]]] = ( + cartesian_product( + TIMEZONES, + [ + ( + "2024-01-01T00:11:22Z", + dict(year=2024, month=1, day=1, hour=0, minute=11, second=22), ), - ), - ( - "2024-04-17T06:54:13.333333334Z", - dict( - year=2024, - month=4, - day=17, - hour=6, - minute=54, - second=13, - microsecond=333333, + ( + "2024-01-01T00:11:22.123Z", + dict( + year=2024, + month=1, + day=1, + hour=0, + minute=11, + second=22, + microsecond=123000, + ), ), - ), - ( - "2024-01-01T00:11:22+0100", - dict(year=2023, month=12, day=31, hour=23, minute=11, second=22), - ), - ( - "2024-01-01T00:11:22.123+0100", - dict( - year=2023, - month=12, - day=31, - hour=23, - minute=11, - second=22, - microsecond=123000, + ( + "2024-04-17T06:54:13.333333334Z", + dict( + year=2024, + month=4, + day=17, + hour=6, + minute=54, + second=13, + microsecond=333333, + ), ), - ), - ], + ( + "2024-01-01T00:11:22+0100", + dict(year=2023, month=12, day=31, hour=23, minute=11, second=22), + ), + ( + "2024-01-01T00:11:22.123+0100", + dict( + year=2023, + month=12, + day=31, + hour=23, + minute=11, + second=22, + microsecond=123000, + ), + ), + ], + ) ) -TEST_INTERPOLATE_TIMESTAMP = cartesian_product( +TEST_INTERPOLATE_TIMESTAMP: list[ + tuple[datetime.timedelta, dict[str, int], dict[str, int], float, dict[str, int]] +] = cartesian_product( TIMEZONES, [ ( @@ -199,26 +214,50 @@ TEST_INTERPOLATE_TIMESTAMP = cartesian_product( class FakeBackend(CryptoBackend): - def parse_key(self, key_file=None, key_content=None, passphrase=None): + def parse_key( + self, + key_file: str | os.PathLike | None = None, + key_content: str | None = None, + passphrase=None, + ) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def sign(self, payload64, protected64, key_data): + def sign( + self, payload64: str, protected64: str, key_data: dict[str, t.Any] | None + ) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def create_mac_key(self, alg, key): + def create_mac_key(self, alg: str, key: str) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_ordered_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def get_csr_identifiers(self, csr_filename=None, csr_content=None): + def get_csr_identifiers( + self, + csr_filename: str | os.PathLike | None = None, + csr_content: str | bytes | None = None, + ) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def get_cert_days(self, cert_filename=None, cert_content=None, now=None): + def get_cert_days( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + now: datetime.datetime | None = None, + ) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def create_chain_matcher(self, criterium): + def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn: raise BackendException("Not implemented in fake backend") - def get_cert_information(self, cert_filename=None, cert_content=None): + def get_cert_information( + self, + cert_filename: str | os.PathLike | None = None, + cert_content: str | bytes | None = None, + ) -> t.NoReturn: raise BackendException("Not implemented in fake backend") diff --git a/tests/unit/plugins/module_utils/acme/test_backend_cryptography.py b/tests/unit/plugins/module_utils/acme/test_backend_cryptography.py index 37c273fe..14bc9cb1 100644 --- a/tests/unit/plugins/module_utils/acme/test_backend_cryptography.py +++ b/tests/unit/plugins/module_utils/acme/test_backend_cryptography.py @@ -5,6 +5,7 @@ from __future__ import annotations import datetime +import typing as t from unittest.mock import ( MagicMock, ) @@ -35,12 +36,20 @@ from .backend_data import ( ) +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( + CertificateInformation, + ) + + if not HAS_CURRENT_CRYPTOGRAPHY: pytest.skip("cryptography not found") @pytest.mark.parametrize("pem, result, dummy", TEST_KEYS) -def test_eckeyparse_cryptography(pem, result, dummy, tmpdir): +def test_eckeyparse_cryptography( + pem: str, result: dict[str, t.Any], dummy: str, tmpdir +) -> None: fn = tmpdir / "test.pem" fn.write(pem) module = MagicMock() @@ -54,7 +63,9 @@ def test_eckeyparse_cryptography(pem, result, dummy, tmpdir): @pytest.mark.parametrize("csr, result, openssl_output", TEST_CSRS) -def test_csridentifiers_cryptography(csr, result, openssl_output, tmpdir): +def test_csridentifiers_cryptography( + csr: str, result: set[tuple[str, str]], openssl_output: str, tmpdir +) -> None: fn = tmpdir / "test.csr" fn.write(csr) module = MagicMock() @@ -66,7 +77,9 @@ def test_csridentifiers_cryptography(csr, result, openssl_output, tmpdir): @pytest.mark.parametrize("timezone, now, expected_days", TEST_CERT_DAYS) -def test_certdays_cryptography(timezone, now, expected_days, tmpdir): +def test_certdays_cryptography( + timezone: datetime.timedelta, now: datetime.datetime, expected_days: int, tmpdir +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): fn = tmpdir / "test-cert.pem" fn.write(TEST_CERT) @@ -81,7 +94,12 @@ def test_certdays_cryptography(timezone, now, expected_days, tmpdir): @pytest.mark.parametrize( "cert_content, expected_cert_info, openssl_output", TEST_CERT_INFO ) -def test_get_cert_information(cert_content, expected_cert_info, openssl_output, tmpdir): +def test_get_cert_information( + cert_content: str, + expected_cert_info: CertificateInformation, + openssl_output: str, + tmpdir, +) -> None: fn = tmpdir / "test-cert.pem" fn.write(cert_content) module = MagicMock() @@ -105,7 +123,7 @@ def test_get_cert_information(cert_content, expected_cert_info, openssl_output, @pytest.mark.parametrize( "timezone", [datetime.timedelta(hours=0)] if CRYPTOGRAPHY_TIMEZONE else TIMEZONES ) -def test_now(timezone): +def test_now(timezone: datetime.timedelta) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): module = MagicMock() backend = CryptographyBackend(module) @@ -119,7 +137,9 @@ def test_now(timezone): @pytest.mark.parametrize("timezone, input, expected", TEST_PARSE_ACME_TIMESTAMP) -def test_parse_acme_timestamp(timezone, input, expected): +def test_parse_acme_timestamp( + timezone: datetime.timedelta, input: str, expected: dict[str, int] +) -> None: with freeze_time("2024-02-03 04:05:06 +00:00", tz_offset=timezone): module = MagicMock() backend = CryptographyBackend(module) @@ -131,7 +151,13 @@ def test_parse_acme_timestamp(timezone, input, expected): @pytest.mark.parametrize( "timezone, start, end, percentage, expected", TEST_INTERPOLATE_TIMESTAMP ) -def test_interpolate_timestamp(timezone, start, end, percentage, expected): +def test_interpolate_timestamp( + timezone: datetime.timedelta, + start: dict[str, int], + end: dict[str, int], + percentage: float, + expected: dict[str, int], +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): module = MagicMock() backend = CryptographyBackend(module) diff --git a/tests/unit/plugins/module_utils/acme/test_backend_openssl_cli.py b/tests/unit/plugins/module_utils/acme/test_backend_openssl_cli.py index 690a0a8e..9de4bb09 100644 --- a/tests/unit/plugins/module_utils/acme/test_backend_openssl_cli.py +++ b/tests/unit/plugins/module_utils/acme/test_backend_openssl_cli.py @@ -5,6 +5,7 @@ from __future__ import annotations import datetime +import typing as t from unittest.mock import ( MagicMock, ) @@ -31,6 +32,12 @@ from .backend_data import ( ) +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( + CertificateInformation, + ) + + # from ..test_time import TIMEZONES @@ -47,7 +54,9 @@ TEST_IPS = [ @pytest.mark.parametrize("pem, result, openssl_output", TEST_KEYS) -def test_eckeyparse_openssl(pem, result, openssl_output, tmpdir): +def test_eckeyparse_openssl( + pem: str, result: dict[str, t.Any], openssl_output: str, tmpdir +) -> None: fn = tmpdir / "test.key" fn.write(pem) module = MagicMock() @@ -59,7 +68,9 @@ def test_eckeyparse_openssl(pem, result, openssl_output, tmpdir): @pytest.mark.parametrize("csr, result, openssl_output", TEST_CSRS) -def test_csridentifiers_openssl(csr, result, openssl_output, tmpdir): +def test_csridentifiers_openssl( + csr: str, result: set[tuple[str, str]], openssl_output: str, tmpdir +) -> None: fn = tmpdir / "test.csr" fn.write(csr) module = MagicMock() @@ -70,14 +81,16 @@ def test_csridentifiers_openssl(csr, result, openssl_output, tmpdir): @pytest.mark.parametrize("ip, result", TEST_IPS) -def test_normalize_ip(ip, result): +def test_normalize_ip(ip: str, result: str) -> None: module = MagicMock() backend = OpenSSLCLIBackend(module, openssl_binary="openssl") assert backend._normalize_ip(ip) == result @pytest.mark.parametrize("timezone, now, expected_days", TEST_CERT_DAYS) -def test_certdays_cryptography(timezone, now, expected_days, tmpdir): +def test_certdays_cryptography( + timezone: datetime.timedelta, now: datetime.datetime, expected_days: int, tmpdir +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): fn = tmpdir / "test-cert.pem" fn.write(TEST_CERT) @@ -93,7 +106,12 @@ def test_certdays_cryptography(timezone, now, expected_days, tmpdir): @pytest.mark.parametrize( "cert_content, expected_cert_info, openssl_output", TEST_CERT_INFO ) -def test_get_cert_information(cert_content, expected_cert_info, openssl_output, tmpdir): +def test_get_cert_information( + cert_content: str, + expected_cert_info: CertificateInformation, + openssl_output: str, + tmpdir, +) -> None: fn = tmpdir / "test-cert.pem" fn.write(cert_content) module = MagicMock() @@ -115,7 +133,7 @@ def test_get_cert_information(cert_content, expected_cert_info, openssl_output, # Due to a bug in freezegun (https://github.com/spulec/freezegun/issues/348, https://github.com/spulec/freezegun/issues/553) # this only works with timezone = UTC if CRYPTOGRAPHY_TIMEZONE is truish @pytest.mark.parametrize("timezone", [datetime.timedelta(hours=0)]) -def test_now(timezone): +def test_now(timezone: datetime.timedelta) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): module = MagicMock() backend = OpenSSLCLIBackend(module, openssl_binary="openssl") @@ -125,7 +143,9 @@ def test_now(timezone): @pytest.mark.parametrize("timezone, input, expected", TEST_PARSE_ACME_TIMESTAMP) -def test_parse_acme_timestamp(timezone, input, expected): +def test_parse_acme_timestamp( + timezone: datetime.timedelta, input: str, expected: dict[str, int] +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): module = MagicMock() backend = OpenSSLCLIBackend(module, openssl_binary="openssl") @@ -137,7 +157,13 @@ def test_parse_acme_timestamp(timezone, input, expected): @pytest.mark.parametrize( "timezone, start, end, percentage, expected", TEST_INTERPOLATE_TIMESTAMP ) -def test_interpolate_timestamp(timezone, start, end, percentage, expected): +def test_interpolate_timestamp( + timezone: datetime.timedelta, + start: dict[str, int], + end: dict[str, int], + percentage: float, + expected: dict[str, int], +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): module = MagicMock() backend = OpenSSLCLIBackend(module, openssl_binary="openssl") diff --git a/tests/unit/plugins/module_utils/acme/test_challenges.py b/tests/unit/plugins/module_utils/acme/test_challenges.py index 6d1acd6e..42ed706e 100644 --- a/tests/unit/plugins/module_utils/acme/test_challenges.py +++ b/tests/unit/plugins/module_utils/acme/test_challenges.py @@ -4,6 +4,7 @@ from __future__ import annotations +import typing as t from unittest.mock import ( MagicMock, ) @@ -21,21 +22,21 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -def test_combine_identifier(): +def test_combine_identifier() -> None: assert combine_identifier("", "") == ":" assert combine_identifier("a", "b") == "a:b" -def test_split_identifier(): - assert split_identifier(":") == ["", ""] - assert split_identifier("a:b") == ["a", "b"] - assert split_identifier("a:b:c") == ["a", "b:c"] +def test_split_identifier() -> None: + assert split_identifier(":") == ("", "") + assert split_identifier("a:b") == ("a", "b") + assert split_identifier("a:b:c") == ("a", "b:c") with pytest.raises(ModuleFailException) as exc: split_identifier("a") assert exc.value.msg == 'Identifier "a" is not of the form :' -def test_challenge_from_to_json(): +def test_challenge_from_to_json() -> None: client = MagicMock() data = { @@ -57,7 +58,7 @@ def test_challenge_from_to_json(): "status": "valid", "token": "foo", } - challenge = Challenge.from_json(None, data, url="xxx") + challenge = Challenge.from_json(None, data, url="xxx") # type: ignore assert challenge.data == data assert challenge.type == "type" assert challenge.url == "xxx" @@ -66,10 +67,12 @@ def test_challenge_from_to_json(): assert challenge.to_json() == data -def test_authorization_from_to_json(): +def test_authorization_from_to_json() -> None: client = MagicMock() client.version = 2 + data: dict[str, t.Any] + data = { "challenges": [], "status": "valid", @@ -138,7 +141,7 @@ def test_authorization_from_to_json(): } -def test_authorization_create_error(): +def test_authorization_create_error() -> None: client = MagicMock() client.version = 2 client.directory.directory = {} @@ -148,7 +151,7 @@ def test_authorization_create_error(): assert exc.value.msg == "ACME endpoint does not support pre-authorization." -def test_wait_for_validation_error(): +def test_wait_for_validation_error() -> None: client = MagicMock() client.version = 2 data = { diff --git a/tests/unit/plugins/module_utils/acme/test_errors.py b/tests/unit/plugins/module_utils/acme/test_errors.py index be496073..bcc84e7e 100644 --- a/tests/unit/plugins/module_utils/acme/test_errors.py +++ b/tests/unit/plugins/module_utils/acme/test_errors.py @@ -4,6 +4,7 @@ from __future__ import annotations +import typing as t from unittest.mock import ( MagicMock, ) @@ -15,7 +16,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor ) -TEST_FORMAT_ERROR_PROBLEM = [ +TEST_FORMAT_ERROR_PROBLEM: list[tuple[dict[str, t.Any], str, str]] = [ ( { "type": "foo", @@ -90,33 +91,37 @@ TEST_FORMAT_ERROR_PROBLEM = [ @pytest.mark.parametrize( "problem, subproblem_prefix, result", TEST_FORMAT_ERROR_PROBLEM ) -def test_format_error_problem(problem, subproblem_prefix, result): +def test_format_error_problem( + problem: dict[str, t.Any], subproblem_prefix: str, result: str +) -> None: res = format_error_problem(problem, subproblem_prefix) assert res == result -def create_regular_response(response_text): +def create_regular_response(response_text: str) -> MagicMock: response = MagicMock() response.read = MagicMock(return_value=response_text.encode("utf-8")) response.closed = False return response -def create_error_response(): +def create_error_response() -> MagicMock: response = MagicMock() response.read = MagicMock(side_effect=AttributeError("read")) response.closed = True return response -def create_decode_error(msg): - def f(content): +def create_decode_error(msg: str) -> t.Callable[[t.Any], t.Any]: + def f(content: t.Any) -> t.NoReturn: raise Exception(msg) return f -TEST_ACME_PROTOCOL_EXCEPTION = [ +TEST_ACME_PROTOCOL_EXCEPTION: list[ + tuple[dict[str, t.Any], t.Callable[[t.Any], t.Any] | None, str, dict[str, t.Any]] +] = [ ( {}, None, @@ -341,14 +346,19 @@ TEST_ACME_PROTOCOL_EXCEPTION = [ @pytest.mark.parametrize("input, from_json, msg, args", TEST_ACME_PROTOCOL_EXCEPTION) -def test_acme_protocol_exception(input, from_json, msg, args): +def test_acme_protocol_exception( + input: dict[str, t.Any], + from_json: t.Callable[[t.Any], t.NoReturn] | None, + msg: str, + args: dict[str, t.Any], +) -> None: if from_json is None: module = None else: module = MagicMock() module.from_json = from_json with pytest.raises(ACMEProtocolException) as exc: - raise ACMEProtocolException(module, **input) + raise ACMEProtocolException(module, **input) # type: ignore print(exc.value.msg) print(exc.value.module_fail_args) diff --git a/tests/unit/plugins/module_utils/acme/test_io.py b/tests/unit/plugins/module_utils/acme/test_io.py index 7cace31f..522814de 100644 --- a/tests/unit/plugins/module_utils/acme/test_io.py +++ b/tests/unit/plugins/module_utils/acme/test_io.py @@ -18,14 +18,13 @@ TEST_TEXT = r"""1234 5678""" -def test_read_file(tmpdir): +def test_read_file(tmpdir) -> None: fn = tmpdir / "test.txt" fn.write(TEST_TEXT) - assert read_file(str(fn), "t") == TEST_TEXT - assert read_file(str(fn), "b") == TEST_TEXT.encode("utf-8") + assert read_file(str(fn)) == TEST_TEXT.encode("utf-8") -def test_write_file(tmpdir): +def test_write_file(tmpdir) -> None: fn = tmpdir / "test.txt" module = MagicMock() write_file(module, str(fn), TEST_TEXT.encode("utf-8")) diff --git a/tests/unit/plugins/module_utils/acme/test_orders.py b/tests/unit/plugins/module_utils/acme/test_orders.py index 8f6a72a7..eb9a3978 100644 --- a/tests/unit/plugins/module_utils/acme/test_orders.py +++ b/tests/unit/plugins/module_utils/acme/test_orders.py @@ -15,7 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order -def test_order_from_json(): +def test_order_from_json() -> None: client = MagicMock() data = { @@ -35,7 +35,7 @@ def test_order_from_json(): assert order.authorizations == {} -def test_wait_for_finalization_error(): +def test_wait_for_finalization_error() -> None: client = MagicMock() client.version = 2 diff --git a/tests/unit/plugins/module_utils/acme/test_utils.py b/tests/unit/plugins/module_utils/acme/test_utils.py index 58dc1a9c..361bb6b7 100644 --- a/tests/unit/plugins/module_utils/acme/test_utils.py +++ b/tests/unit/plugins/module_utils/acme/test_utils.py @@ -5,10 +5,12 @@ from __future__ import annotations import datetime +import typing as t import pytest from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( CertificateInformation, + CryptoBackend, ) from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ( compute_cert_id, @@ -21,7 +23,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import from .backend_data import TEST_PEM_DERS -NOPAD_B64 = [ +NOPAD_B64: list[tuple[str, str]] = [ ("", ""), ("\n", "Cg"), ("123", "MTIz"), @@ -29,7 +31,7 @@ NOPAD_B64 = [ ] -TEST_LINKS_HEADER = [ +TEST_LINKS_HEADER: list[tuple[dict[str, t.Any], list[tuple[str, str]]]] = [ ( {}, [], @@ -60,13 +62,13 @@ TEST_LINKS_HEADER = [ ] -TEST_RETRY_AFTER_HEADER = [ +TEST_RETRY_AFTER_HEADER: list[tuple[str, datetime.datetime]] = [ ("120", datetime.datetime(2024, 4, 29, 0, 2, 0)), ("Wed, 21 Oct 2015 07:28:00 GMT", datetime.datetime(2015, 10, 21, 7, 28, 0)), ] -TEST_COMPUTE_CERT_ID = [ +TEST_COMPUTE_CERT_ID: list[tuple[CertificateInformation, str]] = [ ( CertificateInformation( not_valid_after=datetime.datetime(2018, 11, 26, 15, 28, 24), @@ -93,19 +95,21 @@ TEST_COMPUTE_CERT_ID = [ @pytest.mark.parametrize("value, result", NOPAD_B64) -def test_nopad_b64(value, result): +def test_nopad_b64(value: str, result: str) -> None: assert nopad_b64(value.encode("utf-8")) == result @pytest.mark.parametrize("pem, der", TEST_PEM_DERS) -def test_pem_to_der(pem, der, tmpdir): +def test_pem_to_der(pem: str, der: bytes, tmpdir): fn = tmpdir / "test.pem" fn.write(pem) assert pem_to_der(str(fn)) == der @pytest.mark.parametrize("value, expected_result", TEST_LINKS_HEADER) -def test_process_links(value, expected_result): +def test_process_links( + value: dict[str, t.Any], expected_result: list[tuple[str, str]] +) -> None: data = [] def callback(url, rel): @@ -117,12 +121,15 @@ def test_process_links(value, expected_result): @pytest.mark.parametrize("value, expected_result", TEST_RETRY_AFTER_HEADER) -def test_parse_retry_after(value, expected_result): +def test_parse_retry_after(value: str, expected_result: datetime.datetime) -> None: assert expected_result == parse_retry_after( value, now=datetime.datetime(2024, 4, 29, 0, 0, 0) ) @pytest.mark.parametrize("cert_info, expected_result", TEST_COMPUTE_CERT_ID) -def test_compute_cert_id(cert_info, expected_result): - assert expected_result == compute_cert_id(backend=None, cert_info=cert_info) +def test_compute_cert_id( + cert_info: CertificateInformation, expected_result: str +) -> None: + backend: CryptoBackend = None # type: ignore + assert expected_result == compute_cert_id(backend=backend, cert_info=cert_info) diff --git a/tests/unit/plugins/module_utils/crypto/test_asn1.py b/tests/unit/plugins/module_utils/crypto/test_asn1.py index 74e4fc04..2093cfaf 100644 --- a/tests/unit/plugins/module_utils/crypto/test_asn1.py +++ b/tests/unit/plugins/module_utils/crypto/test_asn1.py @@ -10,12 +10,11 @@ import subprocess import pytest from ansible_collections.community.crypto.plugins.module_utils.crypto._asn1 import ( - pack_asn1, serialize_asn1_string_as_der, ) -TEST_CASES = [ +TEST_CASES: list[tuple[str, bytes]] = [ ("UTF8:Hello World", b"\x0c\x0b\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64"), ( "EXPLICIT:10,UTF8:Hello World", @@ -76,7 +75,7 @@ TEST_CASES = [ @pytest.mark.parametrize("value, expected", TEST_CASES) -def test_serialize_asn1_string_as_der(value, expected): +def test_serialize_asn1_string_as_der(value: str, expected: bytes) -> None: actual = serialize_asn1_string_as_der(value) print(f"{value} | {base64.b16encode(actual).decode()}") assert actual == expected @@ -89,7 +88,7 @@ def test_serialize_asn1_string_as_der(value, expected): "EXPLICIT,UTF:value", ], ) -def test_serialize_asn1_string_as_der_invalid_format(value): +def test_serialize_asn1_string_as_der_invalid_format(value: str) -> None: expected = ( "The ASN.1 serialized string must be in the format [modifier,]type[:value]" ) @@ -97,20 +96,15 @@ def test_serialize_asn1_string_as_der_invalid_format(value): serialize_asn1_string_as_der(value) -def test_serialize_asn1_string_as_der_invalid_type(): +def test_serialize_asn1_string_as_der_invalid_type() -> None: expected = 'The ASN.1 serialized string is not a known type "OID", only UTF8 types are supported' with pytest.raises(ValueError, match=re.escape(expected)): serialize_asn1_string_as_der("OID:1.2.3.4") -def test_pack_asn_invalid_class(): - with pytest.raises(ValueError, match="tag_class must be between 0 and 3 not 4"): - pack_asn1(4, True, 0, b"") - - @pytest.mark.skip() # This is to just to build the test case assertions and shouldn't run normally. @pytest.mark.parametrize("value, expected", TEST_CASES) -def test_test_cases(value, expected, tmp_path): +def test_test_cases(value: str, expected: bytes, tmp_path) -> None: test_file = tmp_path / "test.der" subprocess.run( ["openssl", "asn1parse", "-genstr", value, "-noout", "-out", test_file], diff --git a/tests/unit/plugins/module_utils/crypto/test_cryptography_support.py b/tests/unit/plugins/module_utils/crypto/test_cryptography_support.py index 6e31e0cc..ef9df834 100644 --- a/tests/unit/plugins/module_utils/crypto/test_cryptography_support.py +++ b/tests/unit/plugins/module_utils/crypto/test_cryptography_support.py @@ -5,6 +5,7 @@ from __future__ import annotations import re +import typing as t import cryptography import pytest @@ -20,7 +21,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp from ansible_collections.community.crypto.plugins.module_utils.version import ( LooseVersion, ) -from cryptography.x509 import NameAttribute, oid +from cryptography.x509 import NameAttribute, OtherName, oid @pytest.mark.parametrize( @@ -35,7 +36,7 @@ from cryptography.x509 import NameAttribute, oid ("*.☺.", "*.xn--74h.", None), ], ) -def test_adjust_idn(unicode, idna, cycled_unicode): +def test_adjust_idn(unicode: str, idna: str, cycled_unicode: str | None) -> None: if cycled_unicode is None: cycled_unicode = unicode @@ -70,9 +71,10 @@ def test_adjust_idn(unicode, idna, cycled_unicode): ("bar", "foo", re.escape('Invalid value for idn_rewrite: "foo"')), ], ) -def test_adjust_idn_fail_valueerror(value, idn_rewrite, message): +def test_adjust_idn_fail_valueerror(value: str, idn_rewrite: str, message: str) -> None: with pytest.raises(ValueError, match=message): - _adjust_idn(value, idn_rewrite) + idn_rewrite_: t.Literal["ignore", "idna", "unicode"] = idn_rewrite # type: ignore + _adjust_idn(value, idn_rewrite_) @pytest.mark.parametrize( @@ -88,27 +90,29 @@ def test_adjust_idn_fail_valueerror(value, idn_rewrite, message): ), ], ) -def test_adjust_idn_fail_user_error(value, idn_rewrite, message): +def test_adjust_idn_fail_user_error(value: str, idn_rewrite: str, message: str) -> None: with pytest.raises(OpenSSLObjectError, match=message): - _adjust_idn(value, idn_rewrite) + idn_rewrite_: t.Literal["ignore", "idna", "unicode"] = idn_rewrite # type: ignore + _adjust_idn(value, idn_rewrite_) -def test_cryptography_get_name_invalid_prefix(): +def test_cryptography_get_name_invalid_prefix() -> None: with pytest.raises( OpenSSLObjectError, match="^Cannot parse Subject Alternative Name" ): cryptography_get_name("fake:value") -def test_cryptography_get_name_other_name_no_oid(): +def test_cryptography_get_name_other_name_no_oid() -> None: with pytest.raises( OpenSSLObjectError, match="Cannot parse Subject Alternative Name otherName" ): cryptography_get_name("otherName:value") -def test_cryptography_get_name_other_name_utfstring(): +def test_cryptography_get_name_other_name_utfstring() -> None: actual = cryptography_get_name("otherName:1.3.6.1.4.1.311.20.2.3;UTF8:Hello World") + assert isinstance(actual, OtherName) assert actual.type_id.dotted_string == "1.3.6.1.4.1.311.20.2.3" assert actual.value == b"\x0c\x0bHello World" @@ -164,7 +168,9 @@ def test_cryptography_get_name_other_name_utfstring(): ), ], ) -def test_parse_dn_component(name, options, expected): +def test_parse_dn_component( + name: bytes, options: dict[str, t.Any], expected: tuple[NameAttribute, bytes] +) -> None: result = _parse_dn_component(name, **options) print(result, expected) assert result == expected @@ -186,7 +192,9 @@ if ( (b"CN= ", {}, (NameAttribute(oid.NameOID.COMMON_NAME, ""), b"")), ], ) - def test_parse_dn_component_not_py26(name, options, expected): + def test_parse_dn_component_not_py26( + name: bytes, options: dict[str, t.Any], expected: tuple[NameAttribute, bytes] + ) -> None: result = _parse_dn_component(name, **options) print(result, expected) assert result == expected @@ -200,7 +208,9 @@ if ( (b"CN=#0,", {}, 'Invalid hex sequence entry "0,"'), ], ) -def test_parse_dn_component_failure(name, options, message): +def test_parse_dn_component_failure( + name: bytes, options: dict[str, t.Any], message: str +) -> None: with pytest.raises(OpenSSLObjectError, match=f"^{re.escape(message)}$"): _parse_dn_component(name, **options) @@ -225,7 +235,7 @@ def test_parse_dn_component_failure(name, options, message): ), ], ) -def test_parse_dn(name, expected): +def test_parse_dn(name: bytes, expected: list[NameAttribute]) -> None: result = _parse_dn(name) print(result, expected) assert result == expected @@ -236,14 +246,14 @@ def test_parse_dn(name, expected): [ ( b"CN=\\0", - 'Error while parsing distinguished name "CN=\\0": Hex escape sequence "\\0" incomplete at end of string', + "Error while parsing distinguished name 'CN=\\\\0': Hex escape sequence \"\\0\" incomplete at end of string", ), ( b"CN=x,", - 'Error while parsing distinguished name "CN=x,": unexpected end of string', + "Error while parsing distinguished name 'CN=x,': unexpected end of string", ), ], ) -def test_parse_dn_failure(name, message): +def test_parse_dn_failure(name: bytes, message: str): with pytest.raises(OpenSSLObjectError, match=f"^{re.escape(message)}$"): _parse_dn(name) diff --git a/tests/unit/plugins/module_utils/crypto/test_math.py b/tests/unit/plugins/module_utils/crypto/test_math.py index 35c86a28..6ee34513 100644 --- a/tests/unit/plugins/module_utils/crypto/test_math.py +++ b/tests/unit/plugins/module_utils/crypto/test_math.py @@ -26,7 +26,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor (2, 10, 5, 4), ], ) -def test_binary_exp_mod(f, e, m, result): +def test_binary_exp_mod(f: int, e: int, m: int, result: int) -> None: value = binary_exp_mod(f, e, m) print(value) assert value == result @@ -46,7 +46,7 @@ def test_binary_exp_mod(f, e, m, result): (1024, 10, 2), ], ) -def test_simple_gcd(a, b, result): +def test_simple_gcd(a: int, b: int, result: int) -> None: value = simple_gcd(a, b) print(value) assert value == result @@ -70,7 +70,7 @@ def test_simple_gcd(a, b, result): (211, False), # the smallest prime number >= 200 ], ) -def test_quick_is_not_prime(n, result): +def test_quick_is_not_prime(n: int, result: bool) -> None: value = quick_is_not_prime(n) print(value) assert value == result @@ -88,7 +88,7 @@ def test_quick_is_not_prime(n, result): (256, None, b"\x01\x00"), ], ) -def test_convert_int_to_bytes(no, count, result): +def test_convert_int_to_bytes(no: int, count: int | None, result: bytes) -> None: value = convert_int_to_bytes(no, count=count) print(value) assert value == result @@ -108,7 +108,7 @@ def test_convert_int_to_bytes(no, count, result): (256, 4, "0100"), ], ) -def test_convert_int_to_hex(no, digits, result): +def test_convert_int_to_hex(no: int, digits: int | None, result: str) -> None: value = convert_int_to_hex(no, digits=digits) print(value) assert value == result @@ -125,7 +125,7 @@ def test_convert_int_to_hex(no, digits, result): (b"\x01\x00", 256), ], ) -def test_convert_bytes_to_int(data, result): +def test_convert_bytes_to_int(data: bytes, result: int) -> None: value = convert_bytes_to_int(data) print(value) assert value == result diff --git a/tests/unit/plugins/module_utils/crypto/test_pem.py b/tests/unit/plugins/module_utils/crypto/test_pem.py index eb3ca6cc..15ebe6c1 100644 --- a/tests/unit/plugins/module_utils/crypto/test_pem.py +++ b/tests/unit/plugins/module_utils/crypto/test_pem.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ( extract_first_pem, @@ -13,7 +15,9 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ) -PEM_TEST_CASES = [ +PEM_TEST_CASES: list[ + tuple[bytes, list[str], bool, t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]] +] = [ (b"", [], False, "raw"), (b"random stuff\nblabla", [], False, "raw"), (b"-----BEGIN PRIVATE KEY-----", [], False, "raw"), @@ -51,7 +55,12 @@ PEM_TEST_CASES = [ @pytest.mark.parametrize("data, pems, is_pem, private_key_type", PEM_TEST_CASES) -def test_pem_handling(data, pems, is_pem, private_key_type): +def test_pem_handling( + data: bytes, + pems: list[str], + is_pem: bool, + private_key_type: t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"], +): assert identify_pem_format(data) == is_pem assert identify_private_key_format(data) == private_key_type try: diff --git a/tests/unit/plugins/module_utils/openssh/test_certificate.py b/tests/unit/plugins/module_utils/openssh/test_certificate.py index cf6b1255..3887e358 100644 --- a/tests/unit/plugins/module_utils/openssh/test_certificate.py +++ b/tests/unit/plugins/module_utils/openssh/test_certificate.py @@ -132,7 +132,9 @@ VALID_EXTENSIONS = [ ] INVALID_EXTENSIONS = [OpensshCertificateOption("extension", "test", "")] -VALID_TIME_PARAMETERS = [ +VALID_TIME_PARAMETERS: list[ + tuple[int | str, int | str, str, int, int | str, str, str, int, str] +] = [ ( 0, "always", @@ -223,28 +225,28 @@ VALID_TIME_PARAMETERS = [ ), ] -INVALID_TIME_PARAMETERS = [ +INVALID_TIME_PARAMETERS: list[tuple[int | str, int | str]] = [ (-1, 0xFFFFFFFFFFFFFFFFFF), ("never", "ever"), ("01-01-1980", "01-01-1990"), (1, 0), ] -VALID_VALIDITY_TEST = [ +VALID_VALIDITY_TEST: list[tuple[str, str, str]] = [ ("always", "forever", "2000-01-01"), ("1999-12-31", "2000-01-02", "2000-01-01"), ("1999-12-31 23:59:00", "2000-01-01 00:01:00", "2000-01-01 00:00:00"), ("1999-12-31 23:59:59", "2000-01-01 00:00:01", "2000-01-01 00:00:00"), ] -INVALID_VALIDITY_TEST = [ +INVALID_VALIDITY_TEST: list[tuple[str, str, str]] = [ ("always", "forever", "1969-12-31"), ("always", "2000-01-01", "2000-01-02"), ("2000-01-01", "forever", "1999-12-31"), ("2000-01-01 00:00:00", "2000-01-01 00:00:01", "2000-01-01 00:00:02"), ] -VALID_OPTIONS = [ +VALID_OPTIONS: list[tuple[str, OpensshCertificateOption]] = [ ( "force-command=/usr/bin/csh", OpensshCertificateOption("critical", "force-command", "/usr/bin/csh"), @@ -265,7 +267,7 @@ VALID_OPTIONS = [ ("extension:foo", OpensshCertificateOption("extension", "foo", "")), ] -INVALID_OPTIONS = [ +INVALID_OPTIONS: list[str | list] = [ "foobar", "foo=bar", "foo:bar=baz", @@ -273,7 +275,7 @@ INVALID_OPTIONS = [ ] -def test_rsa_certificate(tmpdir): +def test_rsa_certificate(tmpdir) -> None: cert_file = tmpdir / "id_rsa-cert.pub" cert_file.write(RSA_CERT_SIGNED_BY_DSA, mode="wb") @@ -285,7 +287,7 @@ def test_rsa_certificate(tmpdir): assert cert.signing_key == DSA_FINGERPRINT -def test_dsa_certificate(tmpdir): +def test_dsa_certificate(tmpdir) -> None: cert_file = tmpdir / "id_dsa-cert.pub" cert_file.write(DSA_CERT_SIGNED_BY_ECDSA_NO_OPTS) @@ -298,7 +300,7 @@ def test_dsa_certificate(tmpdir): assert cert.extensions == [] -def test_ecdsa_certificate(tmpdir): +def test_ecdsa_certificate(tmpdir) -> None: cert_file = tmpdir / "id_ecdsa-cert.pub" cert_file.write(ECDSA_CERT_SIGNED_BY_ED25519_VALID_OPTS) @@ -310,7 +312,7 @@ def test_ecdsa_certificate(tmpdir): assert cert.extensions == VALID_EXTENSIONS -def test_ed25519_certificate(tmpdir): +def test_ed25519_certificate(tmpdir) -> None: cert_file = tmpdir / "id_ed25519-cert.pub" cert_file.write(ED25519_CERT_SIGNED_BY_RSA_INVALID_OPTS) @@ -322,7 +324,7 @@ def test_ed25519_certificate(tmpdir): assert cert.extensions == INVALID_EXTENSIONS -def test_invalid_data(tmpdir): +def test_invalid_data(tmpdir) -> None: result = False cert_file = tmpdir / "invalid-cert.pub" cert_file.write(INVALID_DATA) @@ -341,16 +343,16 @@ def test_invalid_data(tmpdir): VALID_TIME_PARAMETERS, ) def test_valid_time_parameters( - valid_from, - valid_from_hr, - valid_from_openssh, - valid_from_timestamp, - valid_to, - valid_to_hr, - valid_to_openssh, - valid_to_timestamp, - validity_string, -): + valid_from: int | str, + valid_from_hr: int | str, + valid_from_openssh: str, + valid_from_timestamp: int, + valid_to: int | str, + valid_to_hr: str, + valid_to_openssh: str, + valid_to_timestamp: int, + validity_string: str, +) -> None: time_parameters = OpensshCertificateTimeParameters( valid_from=valid_from, valid_to=valid_to ) @@ -364,35 +366,37 @@ def test_valid_time_parameters( @pytest.mark.parametrize("valid_from,valid_to", INVALID_TIME_PARAMETERS) -def test_invalid_time_parameters(valid_from, valid_to): +def test_invalid_time_parameters(valid_from: int | str, valid_to: int | str) -> None: with pytest.raises(ValueError): OpensshCertificateTimeParameters(valid_from, valid_to) @pytest.mark.parametrize("valid_from,valid_to,valid_at", VALID_VALIDITY_TEST) -def test_valid_validity_test(valid_from, valid_to, valid_at): +def test_valid_validity_test(valid_from: str, valid_to: str, valid_at: str) -> None: assert OpensshCertificateTimeParameters(valid_from, valid_to).within_range(valid_at) @pytest.mark.parametrize("valid_from,valid_to,valid_at", INVALID_VALIDITY_TEST) -def test_invalid_validity_test(valid_from, valid_to, valid_at): +def test_invalid_validity_test(valid_from: str, valid_to: str, valid_at: str) -> None: assert not OpensshCertificateTimeParameters(valid_from, valid_to).within_range( valid_at ) @pytest.mark.parametrize("option_string,option_object", VALID_OPTIONS) -def test_valid_options(option_string, option_object): +def test_valid_options( + option_string: str, option_object: OpensshCertificateOption +) -> None: assert OpensshCertificateOption.from_string(option_string) == option_object @pytest.mark.parametrize("option_string", INVALID_OPTIONS) -def test_invalid_options(option_string): +def test_invalid_options(option_string: str) -> None: with pytest.raises(ValueError): OpensshCertificateOption.from_string(option_string) -def test_parse_option_list(): +def test_parse_option_list() -> None: critical_options, extensions = parse_option_list(["force-command=/usr/bin/csh"]) critical_option_objects = [ @@ -411,7 +415,7 @@ def test_parse_option_list(): assert set(extensions) == set(extension_objects) -def test_parse_option_list_with_directives(): +def test_parse_option_list_with_directives() -> None: critical_options, extensions = parse_option_list( ["clear", "no-pty", "permit-pty", "permit-user-rc"] ) @@ -425,7 +429,7 @@ def test_parse_option_list_with_directives(): assert set(extensions) == set(extension_objects) -def test_parse_option_list_case_sensitivity(): +def test_parse_option_list_case_sensitivity() -> None: critical_options, extensions = parse_option_list( ["CLEAR", "no-X11-forwarding", "permit-X11-forwarding"] ) diff --git a/tests/unit/plugins/module_utils/openssh/test_cryptography.py b/tests/unit/plugins/module_utils/openssh/test_cryptography.py index ecf358ae..2f3fef1c 100644 --- a/tests/unit/plugins/module_utils/openssh/test_cryptography.py +++ b/tests/unit/plugins/module_utils/openssh/test_cryptography.py @@ -5,6 +5,7 @@ from __future__ import annotations import os.path +import typing as t from getpass import getuser from os import remove, rmdir from socket import gethostname @@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptogra ) -DEFAULT_KEY_PARAMS = [ +if t.TYPE_CHECKING: + from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptography import ( + KeyType, + ) + + +DEFAULT_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "rsa", None, @@ -50,7 +57,7 @@ DEFAULT_KEY_PARAMS = [ ), ] -VALID_USER_KEY_PARAMS = [ +VALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "rsa", 8192, @@ -77,9 +84,9 @@ VALID_USER_KEY_PARAMS = [ ), ] -INVALID_USER_KEY_PARAMS = [ +INVALID_USER_KEY_PARAMS: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( - "dne", + "dne", # type: ignore None, None, None, @@ -87,18 +94,18 @@ INVALID_USER_KEY_PARAMS = [ ( "rsa", None, - [1, 2, 3], + [1, 2, 3], # type: ignore "comment", ), ( "ecdsa", None, None, - [1, 2, 3], + [1, 2, 3], # type: ignore ), ] -INVALID_KEY_SIZES = [ +INVALID_KEY_SIZES: list[tuple[KeyType, int | None, bytes | None, str | None]] = [ ( "rsa", 1023, @@ -134,7 +141,9 @@ INVALID_KEY_SIZES = [ @pytest.mark.parametrize("keytype,size,passphrase,comment", DEFAULT_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_default_key_params(keytype, size, passphrase, comment): +def test_default_key_params( + keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None +) -> None: result = True default_sizes = { @@ -163,7 +172,9 @@ def test_default_key_params(keytype, size, passphrase, comment): @pytest.mark.parametrize("keytype,size,passphrase,comment", VALID_USER_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_valid_user_key_params(keytype, size, passphrase, comment): +def test_valid_user_key_params( + keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None +) -> None: result = True try: @@ -181,7 +192,9 @@ def test_valid_user_key_params(keytype, size, passphrase, comment): @pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_USER_KEY_PARAMS) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_invalid_user_key_params(keytype, size, passphrase, comment): +def test_invalid_user_key_params( + keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None +) -> None: result = False try: @@ -199,7 +212,9 @@ def test_invalid_user_key_params(keytype, size, passphrase, comment): @pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_KEY_SIZES) @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_invalid_key_sizes(keytype, size, passphrase, comment): +def test_invalid_key_sizes( + keytype: KeyType, size: int | None, passphrase: bytes | None, comment: str | None +) -> None: result = False try: @@ -216,7 +231,7 @@ def test_invalid_key_sizes(keytype, size, passphrase, comment): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_valid_comment_update(): +def test_valid_comment_update() -> None: pair = OpensshKeypair.generate() new_comment = "comment" @@ -233,13 +248,13 @@ def test_valid_comment_update(): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_invalid_comment_update(): +def test_invalid_comment_update() -> None: result = False pair = OpensshKeypair.generate() new_comment = [1, 2, 3] try: - pair.comment = new_comment + pair.comment = new_comment # type: ignore except InvalidCommentError: result = True @@ -247,7 +262,7 @@ def test_invalid_comment_update(): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_valid_passphrase_update(): +def test_valid_passphrase_update() -> None: result = False passphrase = "change_me".encode("UTF-8") @@ -281,13 +296,13 @@ def test_valid_passphrase_update(): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_invalid_passphrase_update(): +def test_invalid_passphrase_update() -> None: result = False passphrase = [1, 2, 3] pair = OpensshKeypair.generate() try: - pair.update_passphrase(passphrase) + pair.update_passphrase(passphrase) # type: ignore except InvalidPassphraseError: result = True @@ -295,7 +310,7 @@ def test_invalid_passphrase_update(): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_invalid_privatekey(): +def test_invalid_privatekey() -> None: result = False try: @@ -325,7 +340,7 @@ def test_invalid_privatekey(): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_mismatched_keypair(): +def test_mismatched_keypair() -> None: result = False try: @@ -356,7 +371,7 @@ def test_mismatched_keypair(): @pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") -def test_keypair_comparison(): +def test_keypair_comparison() -> None: assert OpensshKeypair.generate() != OpensshKeypair.generate() assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="dsa") assert OpensshKeypair.generate() != OpensshKeypair.generate(keytype="ed25519") @@ -366,7 +381,7 @@ def test_keypair_comparison(): try: tmpdir = mkdtemp() - keys = { + keys: dict[str, dict[str, t.Any]] = { "rsa": { "pair": OpensshKeypair.generate(), "filename": os.path.join(tmpdir, "id_rsa"), diff --git a/tests/unit/plugins/module_utils/openssh/test_utils.py b/tests/unit/plugins/module_utils/openssh/test_utils.py index 118de444..c17df7f0 100644 --- a/tests/unit/plugins/module_utils/openssh/test_utils.py +++ b/tests/unit/plugins/module_utils/openssh/test_utils.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( OpensshParser, @@ -15,36 +17,36 @@ from ansible_collections.community.crypto.plugins.module_utils.openssh.utils imp SSH_VERSION_STRING = "OpenSSH_7.9p1, OpenSSL 1.1.0i-fips 14 Aug 2018" SSH_VERSION_NUMBER = "7.9" -VALID_BOOLEAN = [True, False] -INVALID_BOOLEAN = [0x02] -VALID_UINT32 = [ +VALID_BOOLEAN: list[bool] = [True, False] +INVALID_BOOLEAN: list[t.Any] = [0x02] +VALID_UINT32: list[int] = [ 0x00, 0x01, 0x01234567, 0xFFFFFFFF, ] -INVALID_UINT32 = [ +INVALID_UINT32: list[int] = [ 0xFFFFFFFFF, -1, ] -VALID_UINT64 = [ +VALID_UINT64: list[int] = [ 0x00, 0x01, 0x0123456789ABCDEF, 0xFFFFFFFFFFFFFFFF, ] -INVALID_UINT64 = [ +INVALID_UINT64: list[int] = [ 0xFFFFFFFFFFFFFFFFF, -1, ] -VALID_STRING = [ +VALID_STRING: list[bytes] = [ b"test string", ] -INVALID_STRING = [ +INVALID_STRING: list[t.Any] = [ [], ] # See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for examples source -VALID_MPINT = [ +VALID_MPINT: list[int] = [ 0x00, 0x9A378F9B2E332A7, 0x80, @@ -53,50 +55,50 @@ VALID_MPINT = [ # Additional large int test 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, ] -INVALID_MPINT = [ +INVALID_MPINT: list[t.Any] = [ [], ] -def test_parse_openssh_version(): +def test_parse_openssh_version() -> None: assert parse_openssh_version(SSH_VERSION_STRING) == SSH_VERSION_NUMBER @pytest.mark.parametrize("boolean", VALID_BOOLEAN) -def test_valid_boolean(boolean): +def test_valid_boolean(boolean: bool) -> None: assert OpensshParser(_OpensshWriter().boolean(boolean).bytes()).boolean() == boolean @pytest.mark.parametrize("boolean", INVALID_BOOLEAN) -def test_invalid_boolean(boolean): +def test_invalid_boolean(boolean: t.Any) -> None: with pytest.raises(TypeError): _OpensshWriter().boolean(boolean) @pytest.mark.parametrize("uint32", VALID_UINT32) -def test_valid_uint32(uint32): +def test_valid_uint32(uint32: int) -> None: assert OpensshParser(_OpensshWriter().uint32(uint32).bytes()).uint32() == uint32 @pytest.mark.parametrize("uint32", INVALID_UINT32) -def test_invalid_uint32(uint32): +def test_invalid_uint32(uint32: int) -> None: with pytest.raises(ValueError): _OpensshWriter().uint32(uint32) @pytest.mark.parametrize("uint64", VALID_UINT64) -def test_valid_uint64(uint64): +def test_valid_uint64(uint64: int) -> None: assert OpensshParser(_OpensshWriter().uint64(uint64).bytes()).uint64() == uint64 @pytest.mark.parametrize("uint64", INVALID_UINT64) -def test_invalid_uint64(uint64): +def test_invalid_uint64(uint64: int) -> None: with pytest.raises(ValueError): _OpensshWriter().uint64(uint64) @pytest.mark.parametrize("ssh_string", VALID_STRING) -def test_valid_string(ssh_string): +def test_valid_string(ssh_string: bytes) -> None: assert ( OpensshParser(_OpensshWriter().string(ssh_string).bytes()).string() == ssh_string @@ -104,23 +106,23 @@ def test_valid_string(ssh_string): @pytest.mark.parametrize("ssh_string", INVALID_STRING) -def test_invalid_string(ssh_string): +def test_invalid_string(ssh_string: t.Any) -> None: with pytest.raises(TypeError): _OpensshWriter().string(ssh_string) @pytest.mark.parametrize("mpint", VALID_MPINT) -def test_valid_mpint(mpint): +def test_valid_mpint(mpint: int) -> None: assert OpensshParser(_OpensshWriter().mpint(mpint).bytes()).mpint() == mpint @pytest.mark.parametrize("mpint", INVALID_MPINT) -def test_invalid_mpint(mpint): +def test_invalid_mpint(mpint: t.Any) -> None: with pytest.raises(TypeError): _OpensshWriter().mpint(mpint) -def test_valid_seek(): +def test_valid_seek() -> None: buffer = bytearray(b"buffer") parser = OpensshParser(buffer) parser.seek(len(buffer)) @@ -129,7 +131,7 @@ def test_valid_seek(): assert parser.remaining_bytes() == len(buffer) -def test_invalid_seek(): +def test_invalid_seek() -> None: buffer = b"buffer" parser = OpensshParser(buffer) @@ -140,6 +142,6 @@ def test_invalid_seek(): parser.seek(-1) -def test_writer_bytes(): +def test_writer_bytes() -> None: buffer = bytearray(b"buffer") assert _OpensshWriter(buffer).bytes() == buffer diff --git a/tests/unit/plugins/module_utils/test_time.py b/tests/unit/plugins/module_utils/test_time.py index 97f239f9..2d7fb9ff 100644 --- a/tests/unit/plugins/module_utils/test_time.py +++ b/tests/unit/plugins/module_utils/test_time.py @@ -5,9 +5,9 @@ from __future__ import annotations import datetime +import typing as t import pytest -from ansible.module_utils.common.collections import is_sequence from ansible_collections.community.crypto.plugins.module_utils.time import ( UTC, add_or_remove_timezone, @@ -30,25 +30,27 @@ TIMEZONES = [ ] -def cartesian_product(list1, list2): - result = [] +if t.TYPE_CHECKING: + _S = t.TypeVar("_S") + _Ts = t.TypeVarTuple("_Ts") + + +def cartesian_product( + list1: list[_S], list2: "list[tuple[*_Ts]]" +) -> "list[tuple[_S, *_Ts]]": + result: "list[tuple[_S, *_Ts]]" = [] for item1 in list1: - if not is_sequence(item1): - item1 = (item1,) - elif not isinstance(item1, tuple): - item1 = tuple(item1) + item1_tuple = (item1,) for item2 in list2: - if not is_sequence(item2): - item2 = (item2,) - elif not isinstance(item2, tuple): - item2 = tuple(item2) - result.append(item1 + item2) + result.append(item1_tuple + item2) return result ONE_HOUR_PLUS = datetime.timezone(datetime.timedelta(hours=1)) -TEST_REMOVE_TIMEZONE = cartesian_product( +TEST_REMOVE_TIMEZONE: list[ + tuple[datetime.timedelta, datetime.datetime, datetime.datetime] +] = cartesian_product( TIMEZONES, [ ( @@ -66,7 +68,9 @@ TEST_REMOVE_TIMEZONE = cartesian_product( ], ) -TEST_UTC_TIMEZONE = cartesian_product( +TEST_UTC_TIMEZONE: list[ + tuple[datetime.timedelta, datetime.datetime, datetime.datetime] +] = cartesian_product( TIMEZONES, [ ( @@ -84,48 +88,67 @@ TEST_UTC_TIMEZONE = cartesian_product( ], ) -TEST_EPOCH_SECONDS = cartesian_product( - TIMEZONES, - [ - (0, dict(year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=0)), - ( - 1e-6, - dict(year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1), - ), - ( - 1e-3, - dict( - year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1000 +TEST_EPOCH_SECONDS: list[tuple[datetime.timedelta, float, dict[str, int]]] = ( + cartesian_product( + TIMEZONES, + [ + ( + 0, + dict( + year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=0 + ), ), - ), - ( - 3691.2, - dict( - year=1970, - day=1, - month=1, - hour=1, - minute=1, - second=31, - microsecond=200000, + ( + 1e-6, + dict( + year=1970, day=1, month=1, hour=0, minute=0, second=0, microsecond=1 + ), ), - ), - ], + ( + 1e-3, + dict( + year=1970, + day=1, + month=1, + hour=0, + minute=0, + second=0, + microsecond=1000, + ), + ), + ( + 3691.2, + dict( + year=1970, + day=1, + month=1, + hour=1, + minute=1, + second=31, + microsecond=200000, + ), + ), + ], + ) ) -TEST_EPOCH_TO_SECONDS = cartesian_product( - TIMEZONES, - [ - (datetime.datetime(1970, 1, 1, 0, 1, 2, 0), 62), - (datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=UTC), 62), - ( - datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=ONE_HOUR_PLUS), - 62 - 3600, - ), - ], +TEST_EPOCH_TO_SECONDS: list[tuple[datetime.timedelta, datetime.datetime, int]] = ( + cartesian_product( + TIMEZONES, + [ + (datetime.datetime(1970, 1, 1, 0, 1, 2, 0), 62), + (datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=UTC), 62), + ( + datetime.datetime(1970, 1, 1, 0, 1, 2, 0, tzinfo=ONE_HOUR_PLUS), + 62 - 3600, + ), + ], + ) ) -TEST_CONVERT_RELATIVE_TO_DATETIME = cartesian_product( +TEST_CONVERT_RELATIVE_TO_DATETIME: list[ + tuple[datetime.timedelta, str, bool, datetime.datetime, datetime.datetime] +] = cartesian_product( TIMEZONES, [ ( @@ -167,7 +190,9 @@ TEST_CONVERT_RELATIVE_TO_DATETIME = cartesian_product( ], ) -TEST_GET_RELATIVE_TIME_OPTION = cartesian_product( +TEST_GET_RELATIVE_TIME_OPTION: list[ + tuple[datetime.timedelta, str, str, bool, datetime.datetime, datetime.datetime] +] = cartesian_product( TIMEZONES, [ ( @@ -259,7 +284,9 @@ TEST_GET_RELATIVE_TIME_OPTION = cartesian_product( @pytest.mark.parametrize("timezone, input, expected", TEST_REMOVE_TIMEZONE) -def test_remove_timezone(timezone, input, expected): +def test_remove_timezone( + timezone: datetime.timedelta, input: datetime.datetime, expected: datetime.datetime +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): output_1 = remove_timezone(input) assert expected == output_1 @@ -268,7 +295,9 @@ def test_remove_timezone(timezone, input, expected): @pytest.mark.parametrize("timezone, input, expected", TEST_UTC_TIMEZONE) -def test_utc_timezone(timezone, input, expected): +def test_utc_timezone( + timezone: datetime.timedelta, input: datetime.datetime, expected: datetime.datetime +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): output_1 = ensure_utc_timezone(input) assert expected == output_1 @@ -280,7 +309,7 @@ def test_utc_timezone(timezone, input, expected): # Due to a bug in freezegun (https://github.com/spulec/freezegun/issues/348, https://github.com/spulec/freezegun/issues/553) # this only works with timezone = UTC @pytest.mark.parametrize("timezone", [datetime.timedelta(hours=0)]) -def test_get_now_datetime_w_timezone(timezone): +def test_get_now_datetime_w_timezone(timezone: datetime.timedelta) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): output_2 = get_now_datetime(with_timezone=True) assert output_2.tzinfo is not None @@ -289,7 +318,7 @@ def test_get_now_datetime_w_timezone(timezone): @pytest.mark.parametrize("timezone", TIMEZONES) -def test_get_now_datetime_wo_timezone(timezone): +def test_get_now_datetime_wo_timezone(timezone: datetime.timedelta) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): output_1 = get_now_datetime(with_timezone=False) assert output_1.tzinfo is None @@ -297,13 +326,15 @@ def test_get_now_datetime_wo_timezone(timezone): @pytest.mark.parametrize("timezone, seconds, timestamp", TEST_EPOCH_SECONDS) -def test_epoch_seconds(timezone, seconds, timestamp): +def test_epoch_seconds( + timezone: datetime.timedelta, seconds: float, timestamp: dict[str, int] +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): - ts_wo_tz = datetime.datetime(**timestamp) + ts_wo_tz: datetime.datetime = datetime.datetime(**timestamp) # type: ignore assert seconds == get_epoch_seconds(ts_wo_tz) - timestamp_w_tz = dict(timestamp) + timestamp_w_tz: dict[str, t.Any] = dict(timestamp) timestamp_w_tz["tzinfo"] = UTC - ts_w_tz = datetime.datetime(**timestamp_w_tz) + ts_w_tz: datetime.datetime = datetime.datetime(**timestamp_w_tz) # type: ignore assert seconds == get_epoch_seconds(ts_w_tz) output_1 = from_epoch_seconds(seconds, with_timezone=False) assert ts_wo_tz == output_1 @@ -312,7 +343,9 @@ def test_epoch_seconds(timezone, seconds, timestamp): @pytest.mark.parametrize("timezone, timestamp, expected_seconds", TEST_EPOCH_TO_SECONDS) -def test_epoch_to_seconds(timezone, timestamp, expected_seconds): +def test_epoch_to_seconds( + timezone: datetime.timedelta, timestamp: datetime.datetime, expected_seconds: int +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): assert expected_seconds == get_epoch_seconds(timestamp) @@ -322,8 +355,12 @@ def test_epoch_to_seconds(timezone, timestamp, expected_seconds): TEST_CONVERT_RELATIVE_TO_DATETIME, ) def test_convert_relative_to_datetime( - timezone, relative_time_string, with_timezone, now, expected -): + timezone: datetime.timedelta, + relative_time_string: str, + with_timezone: bool, + now: datetime.datetime, + expected: datetime.datetime, +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): output = convert_relative_to_datetime( relative_time_string, with_timezone=with_timezone, now=now @@ -336,8 +373,13 @@ def test_convert_relative_to_datetime( TEST_GET_RELATIVE_TIME_OPTION, ) def test_get_relative_time_option( - timezone, input_string, input_name, with_timezone, now, expected -): + timezone: datetime.timedelta, + input_string: str, + input_name: str, + with_timezone: bool, + now: datetime.datetime, + expected: datetime.datetime, +) -> None: with freeze_time("2024-02-03 04:05:06", tz_offset=timezone): output = get_relative_time_option( input_string, diff --git a/tests/unit/plugins/modules/test_luks_device.py b/tests/unit/plugins/modules/test_luks_device.py index 790ce2d4..7c92d15e 100644 --- a/tests/unit/plugins/modules/test_luks_device.py +++ b/tests/unit/plugins/modules/test_luks_device.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.crypto.plugins.modules import luks_device @@ -23,17 +25,17 @@ class DummyModule: # ===== Handler & CryptHandler methods tests ===== -def test_generate_luks_name(monkeypatch): +def test_generate_luks_name(monkeypatch) -> None: module = DummyModule() module.params["passphrase_encoding"] = "text" monkeypatch.setattr( luks_device.Handler, "_run_command", lambda x, y: [0, "UUID", ""] ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore assert crypt.generate_luks_name("/dev/dummy") == "luks-UUID" -def test_get_container_name_by_device(monkeypatch): +def test_get_container_name_by_device(monkeypatch) -> None: module = DummyModule() module.params["passphrase_encoding"] = "text" monkeypatch.setattr( @@ -41,11 +43,11 @@ def test_get_container_name_by_device(monkeypatch): "_run_command", lambda x, y: [0, "crypt container_name", ""], ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore assert crypt.get_container_name_by_device("/dev/dummy") == "container_name" -def test_get_container_device_by_name(monkeypatch): +def test_get_container_device_by_name(monkeypatch) -> None: module = DummyModule() module.params["passphrase_encoding"] = "text" monkeypatch.setattr( @@ -53,15 +55,15 @@ def test_get_container_device_by_name(monkeypatch): "_run_command", lambda x, y: [0, "device: /dev/luksdevice", ""], ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore assert crypt.get_container_device_by_name("dummy") == "/dev/luksdevice" -def test_run_luks_remove(monkeypatch): - def run_command_check(self, command): +def test_run_luks_remove(monkeypatch) -> None: + def run_command_check(self, command: list[str]) -> tuple[int, str, str]: # check that wipefs command is actually called assert command[0] == "wipefs" - return [0, "", ""] + return 0, "", "" module = DummyModule() module.params["passphrase_encoding"] = "text" @@ -70,14 +72,26 @@ def test_run_luks_remove(monkeypatch): ) monkeypatch.setattr(luks_device.Handler, "_run_command", run_command_check) monkeypatch.setattr(luks_device, "wipe_luks_headers", lambda device: True) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore crypt.run_luks_remove("dummy") # ===== ConditionsHandler methods data and tests ===== # device, key, passphrase, state, is_luks, label, cipher, hash, expected -LUKS_CREATE_DATA = ( +LUKS_CREATE_DATA: list[ + tuple[ + str | None, + str | None, + str | None, + t.Literal["present", "absent", "opened", "closed"], + bool, + str | None, + str | None, + str | None, + bool | t.Literal["exception"], + ] +] = [ ("dummy", "key", None, "present", False, None, "dummy", "dummy", True), (None, "key", None, "present", False, None, "dummy", "dummy", False), (None, "key", None, "present", False, "labelName", "dummy", "dummy", True), @@ -97,18 +111,35 @@ LUKS_CREATE_DATA = ( ("dummy", "key", None, "present", False, None, None, None, True), ("dummy", "key", None, "present", False, None, None, "dummy", True), ("dummy", "key", None, "present", False, None, "dummy", None, True), -) +] # device, state, is_luks, expected -LUKS_REMOVE_DATA = ( +LUKS_REMOVE_DATA: list[ + tuple[ + str | None, + t.Literal["present", "absent", "opened", "closed"], + bool, + bool | t.Literal["exception"], + ] +] = [ ("dummy", "absent", True, True), (None, "absent", True, False), ("dummy", "present", True, False), ("dummy", "absent", False, False), -) +] # device, key, passphrase, state, name, name_by_dev, expected -LUKS_OPEN_DATA = ( +LUKS_OPEN_DATA: list[ + tuple[ + str | None, + str | None, + str | None, + t.Literal["present", "absent", "opened", "closed"], + str | None, + str | None, + bool | t.Literal["exception"], + ] +] = [ ("dummy", "key", None, "present", "name", None, False), ("dummy", "key", None, "absent", "name", None, False), ("dummy", "key", None, "closed", "name", None, False), @@ -125,10 +156,20 @@ LUKS_OPEN_DATA = ( ("dummy", None, None, "opened", "name", None, False), ("dummy", None, "quuz", "opened", "name", "name", False), ("dummy", None, "corge", "opened", "beer", "name", "exception"), -) +] # device, dev_by_name, name, name_by_dev, state, label, expected -LUKS_CLOSE_DATA = ( +LUKS_CLOSE_DATA: list[ + tuple[ + str | None, + str | None, + str | None, + str | None, + t.Literal["present", "absent", "opened", "closed"], + str | None, + bool | t.Literal["exception"], + ] +] = [ ("dummy", "dummy", "name", "name", "present", None, False), ("dummy", "dummy", "name", "name", "absent", None, False), ("dummy", "dummy", "name", "name", "opened", None, False), @@ -136,10 +177,21 @@ LUKS_CLOSE_DATA = ( (None, "dummy", "name", "name", "closed", None, True), ("dummy", "dummy", None, "name", "closed", None, True), (None, "dummy", None, "name", "closed", None, False), -) +] # device, key, passphrase, new_key, new_passphrase, state, label, expected -LUKS_ADD_KEY_DATA = ( +LUKS_ADD_KEY_DATA: list[ + tuple[ + str | None, + str | None, + str | None, + str | None, + str | None, + t.Literal["present", "absent", "opened", "closed"], + str | None, + bool | t.Literal["exception"], + ] +] = [ ("dummy", "key", None, "new_key", None, "present", None, True), (None, "key", None, "new_key", None, "present", "labelName", True), (None, "key", None, "new_key", None, "present", None, False), @@ -156,10 +208,20 @@ LUKS_ADD_KEY_DATA = ( ("dummy", "key", None, None, "new_pass", "absent", None, "exception"), ("dummy", None, "pass", None, "new_pass", "present", None, True), (None, None, "pass", None, "new_pass", "present", "labelName", True), -) +] -# device, remove_key, remove_passphrase, state, label, expected -LUKS_REMOVE_KEY_DATA = ( +# device, remove_key, remove_passphrase, remove_keyslot, state, label, expected +LUKS_REMOVE_KEY_DATA: list[ + tuple[ + str | None, + str | None, + str | None, + str | None, + t.Literal["present", "absent", "opened", "closed"], + str | None, + bool | t.Literal["exception"], + ] +] = [ ("dummy", "key", None, None, "present", None, True), (None, "key", None, None, "present", None, False), (None, "key", None, None, "present", "labelName", True), @@ -170,7 +232,7 @@ LUKS_REMOVE_KEY_DATA = ( (None, None, "foo", None, "present", "labelName", True), ("dummy", None, None, None, "present", None, False), ("dummy", None, "foo", None, "absent", None, "exception"), -) +] @pytest.mark.parametrize( @@ -178,17 +240,17 @@ LUKS_REMOVE_KEY_DATA = ( ((d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in LUKS_CREATE_DATA), ) def test_luks_create( - device, - keyfile, - passphrase, - state, - is_luks, - label, - cipher, - hash_, - expected, + device: str | None, + keyfile: str | None, + passphrase: str | None, + state: t.Literal["present", "absent", "opened", "closed"], + is_luks: bool, + label: str | None, + cipher: str | None, + hash_: str | None, + expected: bool | t.Literal["exception"], monkeypatch, -): +) -> None: module = DummyModule() module.params["device"] = device @@ -201,7 +263,7 @@ def test_luks_create( module.params["hash"] = hash_ monkeypatch.setattr(luks_device.CryptHandler, "is_luks", lambda x, y: is_luks) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore if device is None: monkeypatch.setattr( luks_device.Handler, @@ -209,7 +271,7 @@ def test_luks_create( lambda x, y: [0, "/dev/dummy", ""], ) try: - conditions = luks_device.ConditionsHandler(module, crypt) + conditions = luks_device.ConditionsHandler(module, crypt) # type: ignore assert conditions.luks_create() == expected except ValueError: assert expected == "exception" @@ -219,7 +281,13 @@ def test_luks_create( "device, state, is_luks, expected", ((d[0], d[1], d[2], d[3]) for d in LUKS_REMOVE_DATA), ) -def test_luks_remove(device, state, is_luks, expected, monkeypatch): +def test_luks_remove( + device: str | None, + state: t.Literal["present", "absent", "opened", "closed"], + is_luks: bool, + expected: bool | t.Literal["exception"], + monkeypatch, +) -> None: module = DummyModule() module.params["device"] = device @@ -227,9 +295,9 @@ def test_luks_remove(device, state, is_luks, expected, monkeypatch): module.params["state"] = state monkeypatch.setattr(luks_device.CryptHandler, "is_luks", lambda x, y: is_luks) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore try: - conditions = luks_device.ConditionsHandler(module, crypt) + conditions = luks_device.ConditionsHandler(module, crypt) # type: ignore assert conditions.luks_remove() == expected except ValueError: assert expected == "exception" @@ -240,8 +308,15 @@ def test_luks_remove(device, state, is_luks, expected, monkeypatch): ((d[0], d[1], d[2], d[3], d[4], d[5], d[6]) for d in LUKS_OPEN_DATA), ) def test_luks_open( - device, keyfile, passphrase, state, name, name_by_dev, expected, monkeypatch -): + device: str | None, + keyfile: str | None, + passphrase: str | None, + state: t.Literal["present", "absent", "opened", "closed"], + name: str | None, + name_by_dev: str | None, + expected: bool | t.Literal["exception"], + monkeypatch, +) -> None: module = DummyModule() module.params["device"] = device module.params["keyfile"] = keyfile @@ -261,9 +336,9 @@ def test_luks_open( monkeypatch.setattr( luks_device.Handler, "_run_command", lambda x, y: [0, device, ""] ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore try: - conditions = luks_device.ConditionsHandler(module, crypt) + conditions = luks_device.ConditionsHandler(module, crypt) # type: ignore assert conditions.luks_open() == expected except ValueError: assert expected == "exception" @@ -274,8 +349,15 @@ def test_luks_open( ((d[0], d[1], d[2], d[3], d[4], d[5], d[6]) for d in LUKS_CLOSE_DATA), ) def test_luks_close( - device, dev_by_name, name, name_by_dev, state, label, expected, monkeypatch -): + device: str | None, + dev_by_name: str | None, + name: str | None, + name_by_dev: str | None, + state: t.Literal["present", "absent", "opened", "closed"], + label: str | None, + expected: bool | t.Literal["exception"], + monkeypatch, +) -> None: module = DummyModule() module.params["device"] = device module.params["name"] = name @@ -293,9 +375,9 @@ def test_luks_close( "get_container_device_by_name", lambda x, y: dev_by_name, ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore try: - conditions = luks_device.ConditionsHandler(module, crypt) + conditions = luks_device.ConditionsHandler(module, crypt) # type: ignore assert conditions.luks_close() == expected except ValueError: assert expected == "exception" @@ -307,16 +389,16 @@ def test_luks_close( ((d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) for d in LUKS_ADD_KEY_DATA), ) def test_luks_add_key( - device, - keyfile, - passphrase, - new_keyfile, - new_passphrase, - state, - label, - expected, + device: str | None, + keyfile: str | None, + passphrase: str | None, + new_keyfile: str | None, + new_passphrase: str | None, + state: t.Literal["present", "absent", "opened", "closed"], + label: str | None, + expected: bool | t.Literal["exception"], monkeypatch, -): +) -> None: module = DummyModule() module.params["device"] = device module.params["keyfile"] = keyfile @@ -335,9 +417,9 @@ def test_luks_add_key( luks_device.CryptHandler, "luks_test_key", lambda x, y, z, w: False ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore try: - conditions = luks_device.ConditionsHandler(module, crypt) + conditions = luks_device.ConditionsHandler(module, crypt) # type: ignore assert conditions.luks_add_key() == expected except ValueError: assert expected == "exception" @@ -349,15 +431,15 @@ def test_luks_add_key( ((d[0], d[1], d[2], d[3], d[4], d[5], d[6]) for d in LUKS_REMOVE_KEY_DATA), ) def test_luks_remove_key( - device, - remove_keyfile, - remove_passphrase, - remove_keyslot, - state, - label, - expected, + device: str | None, + remove_keyfile: str | None, + remove_passphrase: str | None, + remove_keyslot: str | None, + state: t.Literal["present", "absent", "opened", "closed"], + label: str | None, + expected: bool | t.Literal["exception"], monkeypatch, -): +) -> None: module = DummyModule() module.params["device"] = device @@ -378,9 +460,9 @@ def test_luks_remove_key( luks_device.CryptHandler, "luks_test_key", lambda x, y, z, w: True ) - crypt = luks_device.CryptHandler(module) + crypt = luks_device.CryptHandler(module) # type: ignore try: - conditions = luks_device.ConditionsHandler(module, crypt) + conditions = luks_device.ConditionsHandler(module, crypt) # type: ignore assert conditions.luks_remove_key() == expected except ValueError: assert expected == "exception"