From 52b21b51774ccbd5fdd640eb55d2eae3047465ca Mon Sep 17 00:00:00 2001 From: Felix Fontein Date: Thu, 29 May 2025 23:10:35 +0200 Subject: [PATCH] Fix/improve typing. (#905) --- .../_crypto/module_backends/csr.py | 6 ++-- plugins/module_utils/_crypto/support.py | 36 +++++++++++++++++-- .../_openssh/backends/keypair_backend.py | 2 ++ plugins/module_utils/_openssh/certificate.py | 2 +- 4 files changed, 41 insertions(+), 5 deletions(-) diff --git a/plugins/module_utils/_crypto/module_backends/csr.py b/plugins/module_utils/_crypto/module_backends/csr.py index 57d70b3e..24be4fe0 100644 --- a/plugins/module_utils/_crypto/module_backends/csr.py +++ b/plugins/module_utils/_crypto/module_backends/csr.py @@ -144,7 +144,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): ) self.ordered_subject = False - self.subject = [ + subject = [ ("C", module.params["country_name"]), ("ST", module.params["state_or_province_name"]), ("L", module.params["locality_name"]), @@ -153,7 +153,9 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): ("CN", module.params["common_name"]), ("emailAddress", module.params["email_address"]), ] - self.subject = [(entry[0], entry[1]) for entry in self.subject if entry[1]] + self.subject: list[tuple[str, str]] = [ + (entry[0], entry[1]) for entry in subject if entry[1] + ] try: if module.params["subject"]: diff --git a/plugins/module_utils/_crypto/support.py b/plugins/module_utils/_crypto/support.py index d6525297..ae6397b5 100644 --- a/plugins/module_utils/_crypto/support.py +++ b/plugins/module_utils/_crypto/support.py @@ -292,11 +292,27 @@ def load_certificate_request( raise OpenSSLObjectError(exc) from exc +@t.overload +def parse_name_field( + input_dict: dict[str, list[str] | str], + *, + name_field_name: str | None = None, +) -> list[tuple[str, str]]: ... + + +@t.overload def parse_name_field( input_dict: dict[str, list[str | bytes] | str | bytes], *, name_field_name: str | None = None, -) -> list[tuple[str, str | bytes]]: +) -> list[tuple[str, str | bytes]]: ... + + +def parse_name_field( + input_dict: dict[str, t.Any], + *, + name_field_name: str | None = None, +) -> list: """Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" def error_str(key: str) -> str: @@ -328,11 +344,27 @@ def parse_name_field( return result +@t.overload +def parse_ordered_name_field( + input_list: list[dict[str, list[str] | str]], + *, + name_field_name: str, +) -> list[tuple[str, str]]: ... + + +@t.overload def parse_ordered_name_field( input_list: list[dict[str, list[str | bytes] | str | bytes]], *, name_field_name: str, -) -> list[tuple[str, str | bytes]]: +) -> list[tuple[str, str | bytes]]: ... + + +def parse_ordered_name_field( + input_list: list[dict[str, t.Any]], + *, + name_field_name: str, +) -> list: """Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" result = [] diff --git a/plugins/module_utils/_openssh/backends/keypair_backend.py b/plugins/module_utils/_openssh/backends/keypair_backend.py index 7c15de68..652e9b3e 100644 --- a/plugins/module_utils/_openssh/backends/keypair_backend.py +++ b/plugins/module_utils/_openssh/backends/keypair_backend.py @@ -403,6 +403,7 @@ class KeypairBackendOpensshBin(KeypairBackend): ) def _update_comment(self) -> None: + assert self.comment is not None try: ssh_version = self._get_ssh_version() or "7.8" force_new_format = ( @@ -527,6 +528,7 @@ class KeypairBackendCryptography(KeypairBackend): return True def _update_comment(self) -> None: + assert self.comment is not None keypair = OpensshKeypair.load( path=self.private_key_path, passphrase=self.passphrase, no_public_key=True ) diff --git a/plugins/module_utils/_openssh/certificate.py b/plugins/module_utils/_openssh/certificate.py index 35c08980..429e265b 100644 --- a/plugins/module_utils/_openssh/certificate.py +++ b/plugins/module_utils/_openssh/certificate.py @@ -468,7 +468,7 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo): self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs ): super().__init__(**kwargs) - self._curve = None + self._curve: bytes | None = None if curve is not None: self.curve = curve