mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-05-07 13:53:06 +00:00
Add type hints and type checking (#885)
* Enable basic type checking. * Fix first errors. * Add changelog fragment. * Add types to module_utils and plugin_utils (without module backends). * Add typing hints for acme_* modules. * Add typing to X.509 certificate modules, and add more helpers. * Add typing to remaining module backends. * Add typing for action, filter, and lookup plugins. * Bump ansible-core 2.19 beta requirement for typing. * Add more typing definitions. * Add typing to some unit tests.
This commit is contained in:
@@ -284,6 +284,7 @@ info:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
|
||||
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
|
||||
|
||||
class Certificate(OpensshModule):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(Certificate, self).__init__(module)
|
||||
self.ssh_keygen = KeygenCommand(self.module)
|
||||
|
||||
self.identifier = self.module.params["identifier"] or ""
|
||||
self.options = self.module.params["options"] or []
|
||||
self.path = self.module.params["path"]
|
||||
self.pkcs11_provider = self.module.params["pkcs11_provider"]
|
||||
self.principals = self.module.params["principals"] or []
|
||||
self.public_key = self.module.params["public_key"]
|
||||
self.regenerate = (
|
||||
self.identifier: str = self.module.params["identifier"] or ""
|
||||
self.options: list[str] = self.module.params["options"] or []
|
||||
self.path: str = self.module.params["path"]
|
||||
self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
|
||||
self.principals: list[str] = self.module.params["principals"] or []
|
||||
self.public_key: str | None = self.module.params["public_key"]
|
||||
self.regenerate: t.Literal[
|
||||
"never",
|
||||
"fail",
|
||||
"partial_idempotence",
|
||||
"full_idempotence",
|
||||
"always",
|
||||
] = (
|
||||
self.module.params["regenerate"]
|
||||
if not self.module.params["force"]
|
||||
else "always"
|
||||
)
|
||||
self.serial_number = self.module.params["serial_number"]
|
||||
self.signature_algorithm = self.module.params["signature_algorithm"]
|
||||
self.signing_key = self.module.params["signing_key"]
|
||||
self.state = self.module.params["state"]
|
||||
self.type = self.module.params["type"]
|
||||
self.use_agent = self.module.params["use_agent"]
|
||||
self.valid_at = self.module.params["valid_at"]
|
||||
self.ignore_timestamps = self.module.params["ignore_timestamps"]
|
||||
self.serial_number: int | None = self.module.params["serial_number"]
|
||||
self.signature_algorithm: (
|
||||
t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
|
||||
) = self.module.params["signature_algorithm"]
|
||||
self.signing_key: str | None = self.module.params["signing_key"]
|
||||
self.state: t.Literal["absent", "present"] = self.module.params["state"]
|
||||
self.type: t.Literal["host", "user"] | None = self.module.params["type"]
|
||||
self.use_agent: bool = self.module.params["use_agent"]
|
||||
self.valid_at: str | None = self.module.params["valid_at"]
|
||||
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
|
||||
|
||||
self._check_if_base_dir(self.path)
|
||||
|
||||
if self.state == "present":
|
||||
self._validate_parameters()
|
||||
|
||||
self.data = None
|
||||
self.original_data = None
|
||||
self.data: OpensshCertificate | None = None
|
||||
self.original_data: OpensshCertificate | None = None
|
||||
if self._exists():
|
||||
self._load_certificate()
|
||||
|
||||
self.time_parameters = None
|
||||
self.time_parameters: OpensshCertificateTimeParameters | None = None
|
||||
if self.state == "present":
|
||||
self._set_time_parameters()
|
||||
|
||||
def _validate_parameters(self):
|
||||
def _validate_parameters(self) -> None:
|
||||
for path in (self.public_key, self.signing_key):
|
||||
self._check_if_base_dir(path)
|
||||
if (
|
||||
path is not None
|
||||
): # should never be None, but the type checker doesn't know
|
||||
self._check_if_base_dir(path)
|
||||
|
||||
if self.options and self.type == "host":
|
||||
self.module.fail_json(
|
||||
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
|
||||
if self.use_agent:
|
||||
self._use_agent_available()
|
||||
|
||||
def _use_agent_available(self):
|
||||
def _use_agent_available(self) -> None:
|
||||
ssh_version = self._get_ssh_version()
|
||||
if not ssh_version:
|
||||
self.module.fail_json(msg="Failed to determine ssh version")
|
||||
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
|
||||
+ f" Your version is: {ssh_version}"
|
||||
)
|
||||
|
||||
def _exists(self):
|
||||
def _exists(self) -> bool:
|
||||
return os.path.exists(self.path)
|
||||
|
||||
def _load_certificate(self):
|
||||
def _load_certificate(self) -> None:
|
||||
try:
|
||||
self.original_data = OpensshCertificate.load(self.path)
|
||||
except (TypeError, ValueError) as e:
|
||||
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
|
||||
self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
|
||||
self.module.warn(f"Unable to read existing certificate: {e}")
|
||||
|
||||
def _set_time_parameters(self):
|
||||
def _set_time_parameters(self) -> None:
|
||||
try:
|
||||
self.time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.module.params["valid_from"],
|
||||
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
|
||||
except ValueError as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
if self.state == "present":
|
||||
if self._should_generate():
|
||||
self._generate()
|
||||
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
|
||||
if self._exists():
|
||||
self._remove()
|
||||
|
||||
def _should_generate(self):
|
||||
def _should_generate(self) -> bool:
|
||||
if self.regenerate == "never":
|
||||
return self.original_data is None
|
||||
elif self.regenerate == "fail":
|
||||
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _is_fully_valid(self):
|
||||
def _is_fully_valid(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
if self.public_key is None:
|
||||
raise AssertionError("Contract violation public_key not provided")
|
||||
if self.signing_key is None:
|
||||
raise AssertionError("Contract violation signing_key not provided")
|
||||
return self._is_partially_valid() and all(
|
||||
[
|
||||
self._compare_options() if self.original_data.type == "user" else True,
|
||||
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _is_partially_valid(self):
|
||||
def _is_partially_valid(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
return all(
|
||||
[
|
||||
set(self.original_data.principals) == set(self.principals),
|
||||
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _compare_time_parameters(self):
|
||||
def _compare_time_parameters(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
try:
|
||||
original_time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.original_data.valid_after,
|
||||
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _compare_options(self):
|
||||
def _compare_options(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
try:
|
||||
critical_options, extensions = parse_option_list(self.options)
|
||||
except ValueError as e:
|
||||
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _get_key_fingerprint(self, path):
|
||||
def _get_key_fingerprint(self, path: str) -> str:
|
||||
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
|
||||
return PrivateKey.from_string(private_key_content).fingerprint
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _generate(self):
|
||||
def _generate(self) -> None:
|
||||
try:
|
||||
temp_certificate = self._generate_temp_certificate()
|
||||
self._safe_secure_move([(temp_certificate, self.path)])
|
||||
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
|
||||
except (TypeError, ValueError) as e:
|
||||
self.module.fail_json(msg=f"Unable to read new certificate: {e}")
|
||||
|
||||
def _generate_temp_certificate(self):
|
||||
def _generate_temp_certificate(self) -> str:
|
||||
if self.public_key is None:
|
||||
raise AssertionError("Contract violation public_key not provided")
|
||||
if self.signing_key is None:
|
||||
raise AssertionError("Contract violation signing_key not provided")
|
||||
if self.time_parameters is None:
|
||||
raise AssertionError("Contract violation time_parameters not provided")
|
||||
|
||||
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
|
||||
|
||||
try:
|
||||
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _remove(self):
|
||||
def _remove(self) -> None:
|
||||
try:
|
||||
os.remove(self.path)
|
||||
except OSError as e:
|
||||
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
|
||||
|
||||
@property
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
if self.state != "present":
|
||||
return {}
|
||||
|
||||
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
|
||||
}
|
||||
|
||||
@property
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"before": get_cert_dict(self.original_data),
|
||||
"after": get_cert_dict(self.data),
|
||||
}
|
||||
|
||||
|
||||
def format_cert_info(cert_info):
|
||||
def format_cert_info(cert_info: str) -> list[str]:
|
||||
result = []
|
||||
string = ""
|
||||
|
||||
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
|
||||
return result
|
||||
|
||||
|
||||
def get_cert_dict(data):
|
||||
def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return {}
|
||||
|
||||
@@ -590,7 +621,7 @@ def get_cert_dict(data):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
force=dict(type="bool", default=False),
|
||||
|
||||
Reference in New Issue
Block a user