Add type hints and type checking (#885)

* Enable basic type checking.

* Fix first errors.

* Add changelog fragment.

* Add types to module_utils and plugin_utils (without module backends).

* Add typing hints for acme_* modules.

* Add typing to X.509 certificate modules, and add more helpers.

* Add typing to remaining module backends.

* Add typing for action, filter, and lookup plugins.

* Bump ansible-core 2.19 beta requirement for typing.

* Add more typing definitions.

* Add typing to some unit tests.
This commit is contained in:
Felix Fontein
2025-05-11 18:00:11 +02:00
committed by GitHub
parent 82f0176773
commit f758d94fba
124 changed files with 4986 additions and 2662 deletions

View File

@@ -284,6 +284,7 @@ info:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
class Certificate(OpensshModule):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(Certificate, self).__init__(module)
self.ssh_keygen = KeygenCommand(self.module)
self.identifier = self.module.params["identifier"] or ""
self.options = self.module.params["options"] or []
self.path = self.module.params["path"]
self.pkcs11_provider = self.module.params["pkcs11_provider"]
self.principals = self.module.params["principals"] or []
self.public_key = self.module.params["public_key"]
self.regenerate = (
self.identifier: str = self.module.params["identifier"] or ""
self.options: list[str] = self.module.params["options"] or []
self.path: str = self.module.params["path"]
self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
self.principals: list[str] = self.module.params["principals"] or []
self.public_key: str | None = self.module.params["public_key"]
self.regenerate: t.Literal[
"never",
"fail",
"partial_idempotence",
"full_idempotence",
"always",
] = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.serial_number = self.module.params["serial_number"]
self.signature_algorithm = self.module.params["signature_algorithm"]
self.signing_key = self.module.params["signing_key"]
self.state = self.module.params["state"]
self.type = self.module.params["type"]
self.use_agent = self.module.params["use_agent"]
self.valid_at = self.module.params["valid_at"]
self.ignore_timestamps = self.module.params["ignore_timestamps"]
self.serial_number: int | None = self.module.params["serial_number"]
self.signature_algorithm: (
t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
) = self.module.params["signature_algorithm"]
self.signing_key: str | None = self.module.params["signing_key"]
self.state: t.Literal["absent", "present"] = self.module.params["state"]
self.type: t.Literal["host", "user"] | None = self.module.params["type"]
self.use_agent: bool = self.module.params["use_agent"]
self.valid_at: str | None = self.module.params["valid_at"]
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
self._check_if_base_dir(self.path)
if self.state == "present":
self._validate_parameters()
self.data = None
self.original_data = None
self.data: OpensshCertificate | None = None
self.original_data: OpensshCertificate | None = None
if self._exists():
self._load_certificate()
self.time_parameters = None
self.time_parameters: OpensshCertificateTimeParameters | None = None
if self.state == "present":
self._set_time_parameters()
def _validate_parameters(self):
def _validate_parameters(self) -> None:
for path in (self.public_key, self.signing_key):
self._check_if_base_dir(path)
if (
path is not None
): # should never be None, but the type checker doesn't know
self._check_if_base_dir(path)
if self.options and self.type == "host":
self.module.fail_json(
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
if self.use_agent:
self._use_agent_available()
def _use_agent_available(self):
def _use_agent_available(self) -> None:
ssh_version = self._get_ssh_version()
if not ssh_version:
self.module.fail_json(msg="Failed to determine ssh version")
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
+ f" Your version is: {ssh_version}"
)
def _exists(self):
def _exists(self) -> bool:
return os.path.exists(self.path)
def _load_certificate(self):
def _load_certificate(self) -> None:
try:
self.original_data = OpensshCertificate.load(self.path)
except (TypeError, ValueError) as e:
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
self.module.warn(f"Unable to read existing certificate: {e}")
def _set_time_parameters(self):
def _set_time_parameters(self) -> None:
try:
self.time_parameters = OpensshCertificateTimeParameters(
valid_from=self.module.params["valid_from"],
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
except ValueError as e:
self.module.fail_json(msg=str(e))
def _execute(self):
def _execute(self) -> None:
if self.state == "present":
if self._should_generate():
self._generate()
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
if self._exists():
self._remove()
def _should_generate(self):
def _should_generate(self) -> bool:
if self.regenerate == "never":
return self.original_data is None
elif self.regenerate == "fail":
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
else:
return True
def _is_fully_valid(self):
def _is_fully_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
return self._is_partially_valid() and all(
[
self._compare_options() if self.original_data.type == "user" else True,
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
]
)
def _is_partially_valid(self):
def _is_partially_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
return all(
[
set(self.original_data.principals) == set(self.principals),
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
]
)
def _compare_time_parameters(self):
def _compare_time_parameters(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try:
original_time_parameters = OpensshCertificateTimeParameters(
valid_from=self.original_data.valid_after,
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
]
)
def _compare_options(self):
def _compare_options(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try:
critical_options, extensions = parse_option_list(self.options)
except ValueError as e:
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
]
)
def _get_key_fingerprint(self, path):
def _get_key_fingerprint(self, path: str) -> str:
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
return PrivateKey.from_string(private_key_content).fingerprint
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _generate(self):
def _generate(self) -> None:
try:
temp_certificate = self._generate_temp_certificate()
self._safe_secure_move([(temp_certificate, self.path)])
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
except (TypeError, ValueError) as e:
self.module.fail_json(msg=f"Unable to read new certificate: {e}")
def _generate_temp_certificate(self):
def _generate_temp_certificate(self) -> str:
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
if self.time_parameters is None:
raise AssertionError("Contract violation time_parameters not provided")
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
try:
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _remove(self):
def _remove(self) -> None:
try:
os.remove(self.path)
except OSError as e:
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
@property
def _result(self):
def _result(self) -> dict[str, t.Any]:
if self.state != "present":
return {}
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
}
@property
def diff(self):
def diff(self) -> dict[str, t.Any]:
return {
"before": get_cert_dict(self.original_data),
"after": get_cert_dict(self.data),
}
def format_cert_info(cert_info):
def format_cert_info(cert_info: str) -> list[str]:
result = []
string = ""
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
return result
def get_cert_dict(data):
def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
if data is None:
return {}
@@ -590,7 +621,7 @@ def get_cert_dict(data):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
force=dict(type="bool", default=False),