mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-05-06 21:33:00 +00:00
Add type hints and type checking (#885)
* Enable basic type checking. * Fix first errors. * Add changelog fragment. * Add types to module_utils and plugin_utils (without module backends). * Add typing hints for acme_* modules. * Add typing to X.509 certificate modules, and add more helpers. * Add typing to remaining module backends. * Add typing for action, filter, and lookup plugins. * Bump ansible-core 2.19 beta requirement for typing. * Add more typing definitions. * Add typing to some unit tests.
This commit is contained in:
@@ -165,6 +165,7 @@ account_uri:
|
||||
"""
|
||||
|
||||
import base64
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
@@ -180,7 +181,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec()
|
||||
argument_spec.update_argspec(
|
||||
terms_agreed=dict(type="bool", default=False),
|
||||
@@ -204,24 +205,24 @@ def main():
|
||||
),
|
||||
)
|
||||
argument_spec.update(
|
||||
mutually_exclusive=(["new_account_key_src", "new_account_key_content"],),
|
||||
required_if=(
|
||||
mutually_exclusive=[("new_account_key_src", "new_account_key_content")],
|
||||
required_if=[
|
||||
# Make sure that for state == changed_key, one of
|
||||
# new_account_key_src and new_account_key_content are specified
|
||||
[
|
||||
(
|
||||
"state",
|
||||
"changed_key",
|
||||
["new_account_key_src", "new_account_key_content"],
|
||||
True,
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
module = argument_spec.create_ansible_module(supports_check_mode=True)
|
||||
backend = create_backend(module, True)
|
||||
|
||||
if module.params["external_account_binding"]:
|
||||
# Make sure padding is there
|
||||
key = module.params["external_account_binding"]["key"]
|
||||
key: str = module.params["external_account_binding"]["key"]
|
||||
if len(key) % 4 != 0:
|
||||
key = key + ("=" * (4 - (len(key) % 4)))
|
||||
# Make sure key is Base64 encoded
|
||||
@@ -237,24 +238,25 @@ def main():
|
||||
client = ACMEClient(module, backend)
|
||||
account = ACMEAccount(client)
|
||||
changed = False
|
||||
state = module.params.get("state")
|
||||
diff_before = {}
|
||||
diff_after = {}
|
||||
state: t.Literal["present", "absent", "changed_key"] = module.params["state"]
|
||||
diff_before: dict[str, t.Any] = {}
|
||||
diff_after: dict[str, t.Any] = {}
|
||||
if state == "absent":
|
||||
created, account_data = account.setup_account(allow_creation=False)
|
||||
if account_data:
|
||||
diff_before = dict(account_data)
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if created:
|
||||
raise AssertionError("Unwanted account creation")
|
||||
if account_data is not None:
|
||||
# Account is not yet deactivated
|
||||
if not module.check_mode:
|
||||
# Deactivate it
|
||||
payload = {"status": "deactivated"}
|
||||
deactivate_payload = {"status": "deactivated"}
|
||||
result, info = client.send_signed_request(
|
||||
client.account_uri,
|
||||
payload,
|
||||
t.cast(str, client.account_uri),
|
||||
deactivate_payload,
|
||||
error_msg="Failed to deactivate account",
|
||||
expected_status_codes=[200],
|
||||
)
|
||||
@@ -278,13 +280,15 @@ def main():
|
||||
diff_before = {}
|
||||
else:
|
||||
diff_before = dict(account_data)
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
updated = False
|
||||
if not created:
|
||||
updated, account_data = account.update_account(account_data, contact)
|
||||
changed = created or updated
|
||||
diff_after = dict(account_data)
|
||||
diff_after["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_after["public_account_key"] = client.account_key_data["jwk"]
|
||||
elif state == "changed_key":
|
||||
# Parse new account key
|
||||
try:
|
||||
@@ -306,7 +310,8 @@ def main():
|
||||
msg="Account does not exist or is deactivated."
|
||||
)
|
||||
diff_before = dict(account_data)
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
# Now we can start the account key rollover
|
||||
if not module.check_mode:
|
||||
# Compose inner signed message
|
||||
@@ -317,12 +322,12 @@ def main():
|
||||
"jwk": new_key_data["jwk"],
|
||||
"url": url,
|
||||
}
|
||||
payload = {
|
||||
change_key_payload = {
|
||||
"account": client.account_uri,
|
||||
"newKey": new_key_data["jwk"], # specified in draft 12 and older
|
||||
"oldKey": client.account_jwk, # specified in draft 13 and newer
|
||||
}
|
||||
data = client.sign_request(protected, payload, new_key_data)
|
||||
data = client.sign_request(protected, change_key_payload, new_key_data)
|
||||
# Send request and verify result
|
||||
result, info = client.send_signed_request(
|
||||
url,
|
||||
@@ -332,8 +337,9 @@ def main():
|
||||
)
|
||||
if module._diff:
|
||||
client.account_key_data = new_key_data
|
||||
client.account_jws_header["alg"] = new_key_data["alg"]
|
||||
diff_after = account.get_account_data()
|
||||
if client.account_jws_header:
|
||||
client.account_jws_header["alg"] = new_key_data["alg"]
|
||||
diff_after = account.get_account_data() or {}
|
||||
elif module._diff:
|
||||
# Kind of fake diff_after
|
||||
diff_after = dict(diff_before)
|
||||
|
||||
@@ -204,6 +204,8 @@ order_uris:
|
||||
version_added: 1.5.0
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
)
|
||||
@@ -220,48 +222,55 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def get_orders_list(module, client, orders_url):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def get_orders_list(
|
||||
module: AnsibleModule, client: ACMEClient, orders_url: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Retrieves orders list (handles pagination).
|
||||
Retrieves order URL list (handles pagination).
|
||||
"""
|
||||
orders = []
|
||||
while orders_url:
|
||||
orders: list[str] = []
|
||||
next_orders_url: str | None = orders_url
|
||||
while next_orders_url:
|
||||
# Get part of orders list
|
||||
res, info = client.get_request(
|
||||
orders_url, parse_json_result=True, fail_on_error=True
|
||||
next_orders_url, parse_json_result=True, fail_on_error=True
|
||||
)
|
||||
if not res.get("orders"):
|
||||
if orders:
|
||||
module.warn(
|
||||
f"When retrieving orders list part {orders_url}, got empty result list"
|
||||
f"When retrieving orders list part {next_orders_url}, got empty result list"
|
||||
)
|
||||
break
|
||||
# Add order URLs to result list
|
||||
orders.extend(res["orders"])
|
||||
# Extract URL of next part of results list
|
||||
new_orders_url = []
|
||||
new_orders_url: list[str | None] = []
|
||||
|
||||
def f(link, relation):
|
||||
def f(link: str, relation: str) -> None:
|
||||
if relation == "next":
|
||||
new_orders_url.append(link)
|
||||
|
||||
process_links(info, f)
|
||||
new_orders_url.append(None)
|
||||
previous_orders_url, orders_url = orders_url, new_orders_url.pop(0)
|
||||
if orders_url == previous_orders_url:
|
||||
previous_orders_url, next_orders_url = next_orders_url, new_orders_url.pop(0)
|
||||
if next_orders_url == previous_orders_url:
|
||||
# Prevent infinite loop
|
||||
orders_url = None
|
||||
next_orders_url = None
|
||||
return orders
|
||||
|
||||
|
||||
def get_order(client, order_url):
|
||||
def get_order(client: ACMEClient, order_url: str) -> dict[str, t.Any]:
|
||||
"""
|
||||
Retrieve order data.
|
||||
"""
|
||||
return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0]
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec()
|
||||
argument_spec.update_argspec(
|
||||
retrieve_orders=dict(
|
||||
@@ -282,16 +291,19 @@ def main():
|
||||
)
|
||||
if created:
|
||||
raise AssertionError("Unwanted account creation")
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": False,
|
||||
"exists": client.account_uri is not None,
|
||||
"account_uri": client.account_uri,
|
||||
"exists": False,
|
||||
"account_uri": None,
|
||||
}
|
||||
if client.account_uri is not None:
|
||||
if client.account_uri is not None and account_data:
|
||||
result["account_uri"] = client.account_uri
|
||||
result["exists"] = True
|
||||
# Make sure promised data is there
|
||||
if "contact" not in account_data:
|
||||
account_data["contact"] = []
|
||||
account_data["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
account_data["public_account_key"] = client.account_key_data["jwk"]
|
||||
result["account"] = account_data
|
||||
# Retrieve orders list
|
||||
if (
|
||||
|
||||
@@ -94,6 +94,8 @@ renewal_info:
|
||||
sample: '2024-04-29T01:17:10.236921+00:00'
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
ACMEClient,
|
||||
create_backend,
|
||||
@@ -104,15 +106,15 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_account=False)
|
||||
argument_spec.update_argspec(
|
||||
certificate_path=dict(type="path"),
|
||||
certificate_content=dict(type="str"),
|
||||
)
|
||||
argument_spec.update(
|
||||
required_one_of=(["certificate_path", "certificate_content"],),
|
||||
mutually_exclusive=(["certificate_path", "certificate_content"],),
|
||||
required_one_of=[("certificate_path", "certificate_content")],
|
||||
mutually_exclusive=[("certificate_path", "certificate_content")],
|
||||
)
|
||||
module = argument_spec.create_ansible_module(supports_check_mode=True)
|
||||
backend = create_backend(module, True)
|
||||
|
||||
@@ -562,6 +562,7 @@ all_chains:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
@@ -592,6 +593,17 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
CertificateInformation,
|
||||
CryptoBackend,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
|
||||
Authorization,
|
||||
)
|
||||
|
||||
|
||||
NO_CHALLENGE = "no challenge"
|
||||
|
||||
|
||||
@@ -602,7 +614,7 @@ class ACMECertificateClient:
|
||||
certificates.
|
||||
"""
|
||||
|
||||
def __init__(self, module, backend):
|
||||
def __init__(self, module: AnsibleModule, backend: CryptoBackend):
|
||||
self.module = module
|
||||
self.version = module.params["acme_version"]
|
||||
self.challenge = module.params["challenge"]
|
||||
@@ -618,9 +630,9 @@ class ACMECertificateClient:
|
||||
self.account = ACMEAccount(self.client)
|
||||
self.directory = self.client.directory
|
||||
self.data = module.params["data"]
|
||||
self.authorizations = None
|
||||
self.authorizations: dict[str, Authorization] | None = None
|
||||
self.cert_days = -1
|
||||
self.order = None
|
||||
self.order: Order | None = None
|
||||
self.order_uri = self.data.get("order_uri") if self.data else None
|
||||
self.all_chains = None
|
||||
self.select_chain_matcher = []
|
||||
@@ -662,7 +674,6 @@ class ACMECertificateClient:
|
||||
contact.append("mailto:" + module.params["account_email"])
|
||||
created, account_data = self.account.setup_account(
|
||||
contact,
|
||||
agreement=module.params.get("agreement"),
|
||||
terms_agreed=module.params.get("terms_agreed"),
|
||||
allow_creation=modify_account,
|
||||
)
|
||||
@@ -681,7 +692,7 @@ class ACMECertificateClient:
|
||||
csr_filename=self.csr, csr_content=self.csr_content
|
||||
)
|
||||
|
||||
def is_first_step(self):
|
||||
def is_first_step(self) -> bool:
|
||||
"""
|
||||
Return True if this is the first execution of this module, i.e. if a
|
||||
sufficient data object from a first run has not been provided.
|
||||
@@ -692,7 +703,7 @@ class ACMECertificateClient:
|
||||
# stored in self.order_uri by the constructor).
|
||||
return self.order_uri is None
|
||||
|
||||
def _get_cert_info_or_none(self):
|
||||
def _get_cert_info_or_none(self) -> CertificateInformation | None:
|
||||
if self.module.params.get("dest"):
|
||||
filename = self.module.params["dest"]
|
||||
else:
|
||||
@@ -701,7 +712,7 @@ class ACMECertificateClient:
|
||||
return None
|
||||
return self.client.backend.get_cert_information(cert_filename=filename)
|
||||
|
||||
def start_challenges(self):
|
||||
def start_challenges(self) -> None:
|
||||
"""
|
||||
Create new authorizations for all identifiers of the CSR,
|
||||
respectively start a new order for ACME v2.
|
||||
@@ -733,13 +744,16 @@ class ACMECertificateClient:
|
||||
self.authorizations.update(self.order.authorizations)
|
||||
self.changed = True
|
||||
|
||||
def get_challenges_data(self, first_step):
|
||||
def get_challenges_data(
|
||||
self, first_step: bool
|
||||
) -> tuple[dict[str, t.Any], dict[str, list[str]]]:
|
||||
"""
|
||||
Get challenge details for the chosen challenge type.
|
||||
Return a tuple of generic challenge details, and specialized DNS challenge details.
|
||||
"""
|
||||
# Get general challenge data
|
||||
data = {}
|
||||
assert self.authorizations is not None
|
||||
data: dict[str, t.Any] = {}
|
||||
data_dns: dict[str, list[str]] = {}
|
||||
for type_identifier, authz in self.authorizations.items():
|
||||
identifier_type, identifier = split_identifier(type_identifier)
|
||||
# Skip valid authentications: their challenges are already valid
|
||||
@@ -747,7 +761,9 @@ class ACMECertificateClient:
|
||||
if authz.status == "valid":
|
||||
continue
|
||||
# We drop the type from the key to preserve backwards compatibility
|
||||
data[authz.identifier] = authz.get_challenge_data(self.client)
|
||||
challenges = authz.get_challenge_data(self.client)
|
||||
assert authz.identifier is not None
|
||||
data[authz.identifier] = challenges
|
||||
if (
|
||||
first_step
|
||||
and self.challenge is not None
|
||||
@@ -756,10 +772,7 @@ class ACMECertificateClient:
|
||||
raise ModuleFailException(
|
||||
f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!"
|
||||
)
|
||||
# Get DNS challenge data
|
||||
data_dns = {}
|
||||
if self.challenge == "dns-01":
|
||||
for identifier, challenges in data.items():
|
||||
if self.challenge == "dns-01":
|
||||
if self.challenge in challenges:
|
||||
values = data_dns.get(challenges[self.challenge]["record"])
|
||||
if values is None:
|
||||
@@ -768,7 +781,7 @@ class ACMECertificateClient:
|
||||
values.append(challenges[self.challenge]["resource_value"])
|
||||
return data, data_dns
|
||||
|
||||
def finish_challenges(self):
|
||||
def finish_challenges(self) -> None:
|
||||
"""
|
||||
Verify challenges for all identifiers of the CSR.
|
||||
"""
|
||||
@@ -777,6 +790,7 @@ class ACMECertificateClient:
|
||||
# Step 1: obtain challenge information
|
||||
# For ACME v2, we obtain the order object by fetching the
|
||||
# order URI, and extract the information from there.
|
||||
assert self.order_uri is not None
|
||||
self.order = Order.from_url(self.client, self.order_uri)
|
||||
self.order.load_authorizations(self.client)
|
||||
self.authorizations.update(self.order.authorizations)
|
||||
@@ -799,7 +813,9 @@ class ACMECertificateClient:
|
||||
# Step 3: wait for authzs to validate
|
||||
wait_for_validation(authzs_to_wait_for, self.client)
|
||||
|
||||
def download_alternate_chains(self, cert):
|
||||
def download_alternate_chains(
|
||||
self, cert: CertificateChain
|
||||
) -> list[CertificateChain]:
|
||||
alternate_chains = []
|
||||
for alternate in cert.alternates:
|
||||
try:
|
||||
@@ -812,7 +828,9 @@ class ACMECertificateClient:
|
||||
alternate_chains.append(alt_cert)
|
||||
return alternate_chains
|
||||
|
||||
def find_matching_chain(self, chains):
|
||||
def find_matching_chain(
|
||||
self, chains: t.Iterable[CertificateChain]
|
||||
) -> CertificateChain | None:
|
||||
for criterium_idx, matcher in enumerate(self.select_chain_matcher):
|
||||
for chain in chains:
|
||||
if matcher.match(chain):
|
||||
@@ -822,12 +840,13 @@ class ACMECertificateClient:
|
||||
return chain
|
||||
return None
|
||||
|
||||
def get_certificate(self):
|
||||
def get_certificate(self) -> None:
|
||||
"""
|
||||
Request a new certificate and write it to the destination file.
|
||||
First verifies whether all authorizations are valid; if not, aborts
|
||||
with an error.
|
||||
"""
|
||||
assert self.authorizations is not None
|
||||
for identifier_type, identifier in self.identifiers:
|
||||
authz = self.authorizations.get(
|
||||
normalize_combined_identifier(
|
||||
@@ -844,7 +863,9 @@ class ACMECertificateClient:
|
||||
module=self.module,
|
||||
)
|
||||
|
||||
assert self.order is not None
|
||||
self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
|
||||
assert self.order.certificate_uri is not None
|
||||
cert = CertificateChain.download(self.client, self.order.certificate_uri)
|
||||
if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher:
|
||||
# Retrieve alternate chains
|
||||
@@ -887,12 +908,13 @@ class ACMECertificateClient:
|
||||
):
|
||||
self.changed = True
|
||||
|
||||
def deactivate_authzs(self):
|
||||
def deactivate_authzs(self) -> None:
|
||||
"""
|
||||
Deactivates all valid authz's. Does not raise exceptions.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
https://tools.ietf.org/html/rfc8555#section-7.5.2
|
||||
"""
|
||||
assert self.authorizations is not None
|
||||
for authz in self.authorizations.values():
|
||||
try:
|
||||
authz.deactivate(self.client)
|
||||
@@ -905,7 +927,7 @@ class ACMECertificateClient:
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=True)
|
||||
argument_spec.argument_spec["csr"]["aliases"] = ["src"]
|
||||
argument_spec.update_argspec(
|
||||
@@ -981,7 +1003,7 @@ def main():
|
||||
else:
|
||||
client = ACMECertificateClient(module, backend)
|
||||
client.cert_days = cert_days
|
||||
other = dict()
|
||||
other: dict[str, t.Any] = {}
|
||||
is_first_step = client.is_first_step()
|
||||
if is_first_step:
|
||||
# First run: start challenges / start new order
|
||||
@@ -998,6 +1020,7 @@ def main():
|
||||
client.deactivate_authzs()
|
||||
data, data_dns = client.get_challenges_data(first_step=is_first_step)
|
||||
auths = dict()
|
||||
assert client.authorizations is not None
|
||||
for k, v in client.authorizations.items():
|
||||
# Remove "type:" from key
|
||||
auths[v.identifier] = v.to_json()
|
||||
|
||||
@@ -51,6 +51,8 @@ EXAMPLES = r"""
|
||||
|
||||
RETURN = """#"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
)
|
||||
@@ -65,7 +67,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec()
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
|
||||
@@ -371,6 +371,8 @@ account_uri:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -383,7 +385,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=True)
|
||||
argument_spec.update_argspec(
|
||||
deactivate_authzs=dict(type="bool", default=True),
|
||||
|
||||
@@ -317,6 +317,8 @@ selected_chain:
|
||||
returned: always
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -329,7 +331,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.certificates import (
|
||||
CertificateChain,
|
||||
)
|
||||
|
||||
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=True)
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
@@ -375,6 +383,7 @@ def main():
|
||||
or module.params["retrieve_all_alternates"]
|
||||
)
|
||||
changed = False
|
||||
alternate_chains: list[CertificateChain] | None
|
||||
if order.status == "valid":
|
||||
# Step 2 and 3: download certificate(s) and chain(s)
|
||||
cert, alternate_chains = client.download_certificate(
|
||||
|
||||
@@ -357,6 +357,8 @@ authorizations_by_status:
|
||||
returned: always
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -369,7 +371,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=False)
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
@@ -381,8 +383,8 @@ def main():
|
||||
try:
|
||||
client = ACMECertificateClient(module, backend)
|
||||
order = client.load_order()
|
||||
authorizations_by_identifier = dict()
|
||||
authorizations_by_status = {
|
||||
authorizations_by_identifier: dict[str, dict[str, t.Any]] = {}
|
||||
authorizations_by_status: dict[str, list[str]] = {
|
||||
"pending": [],
|
||||
"invalid": [],
|
||||
"valid": [],
|
||||
@@ -392,7 +394,8 @@ def main():
|
||||
}
|
||||
for identifier, authz in order.authorizations.items():
|
||||
authorizations_by_identifier[identifier] = authz.to_json()
|
||||
authorizations_by_status[authz.status].append(identifier)
|
||||
if authz.status is not None:
|
||||
authorizations_by_status[authz.status].append(identifier)
|
||||
module.exit_json(
|
||||
changed=False,
|
||||
account_uri=client.client.account_uri,
|
||||
|
||||
@@ -229,6 +229,8 @@ validating_challenges:
|
||||
returned: always
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -241,7 +243,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
|
||||
Authorization,
|
||||
)
|
||||
|
||||
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=False)
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
@@ -271,10 +279,12 @@ def main():
|
||||
|
||||
missing_challenge_authzs = [k for k, v in challenges.items() if v is None]
|
||||
if missing_challenge_authzs:
|
||||
missing_challenge_authzs = ", ".join(sorted(missing_challenge_authzs))
|
||||
missing_challenge_authzs_str = ", ".join(
|
||||
sorted(missing_challenge_authzs)
|
||||
)
|
||||
raise ModuleFailException(
|
||||
"The challenge parameter must be supplied if there are pending authorizations."
|
||||
f" The following authorizations are pending: {missing_challenge_authzs}"
|
||||
f" The following authorizations are pending: {missing_challenge_authzs_str}"
|
||||
)
|
||||
|
||||
bad_challenge_authzs = [
|
||||
@@ -293,11 +303,13 @@ def main():
|
||||
f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}"
|
||||
)
|
||||
|
||||
def is_pending(authz: Authorization) -> bool:
|
||||
challenge_name = challenges[authz.combined_identifier]
|
||||
challenge_obj = authz.find_challenge(challenge_name)
|
||||
return challenge_obj is not None and challenge_obj.status == "pending"
|
||||
|
||||
really_pending_authzs = [
|
||||
authz
|
||||
for authz in pending_authzs
|
||||
if authz.find_challenge(challenges[authz.combined_identifier]).status
|
||||
== "pending"
|
||||
authz for authz in pending_authzs if is_pending(authz)
|
||||
]
|
||||
|
||||
# Step 4: validate pending authorizations
|
||||
@@ -320,7 +332,7 @@ def main():
|
||||
identifier_type=authz.identifier_type,
|
||||
authz_url=authz.url,
|
||||
challenge_type=challenge_type,
|
||||
challenge_url=challenge.url,
|
||||
challenge_url=challenge.url if challenge else None,
|
||||
)
|
||||
for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for
|
||||
],
|
||||
|
||||
@@ -160,6 +160,7 @@ cert_id:
|
||||
|
||||
import os
|
||||
import random
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
ACMEClient,
|
||||
@@ -175,7 +176,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_account=False)
|
||||
argument_spec.update_argspec(
|
||||
certificate_path=dict(type="path"),
|
||||
@@ -190,7 +191,7 @@ def main():
|
||||
treat_parsing_error_as_non_existing=dict(type="bool", default=False),
|
||||
)
|
||||
argument_spec.update(
|
||||
mutually_exclusive=(["certificate_path", "certificate_content"],),
|
||||
mutually_exclusive=[("certificate_path", "certificate_content")],
|
||||
)
|
||||
module = argument_spec.create_ansible_module(supports_check_mode=True)
|
||||
backend = create_backend(module, True)
|
||||
@@ -203,7 +204,7 @@ def main():
|
||||
supports_ari=False,
|
||||
)
|
||||
|
||||
def complete(should_renew, **kwargs):
|
||||
def complete(should_renew: bool, **kwargs) -> t.NoReturn:
|
||||
result["should_renew"] = should_renew
|
||||
result.update(kwargs)
|
||||
module.exit_json(**result)
|
||||
|
||||
@@ -110,6 +110,8 @@ EXAMPLES = r"""
|
||||
|
||||
RETURN = """#"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
)
|
||||
@@ -129,7 +131,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(require_account_key=False)
|
||||
argument_spec.update_argspec(
|
||||
private_key_src=dict(type="path"),
|
||||
@@ -139,22 +141,22 @@ def main():
|
||||
revoke_reason=dict(type="int"),
|
||||
)
|
||||
argument_spec.update(
|
||||
required_one_of=(
|
||||
[
|
||||
required_one_of=[
|
||||
(
|
||||
"account_key_src",
|
||||
"account_key_content",
|
||||
"private_key_src",
|
||||
"private_key_content",
|
||||
],
|
||||
),
|
||||
mutually_exclusive=(
|
||||
[
|
||||
),
|
||||
],
|
||||
mutually_exclusive=[
|
||||
(
|
||||
"account_key_src",
|
||||
"account_key_content",
|
||||
"private_key_src",
|
||||
"private_key_content",
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
module = argument_spec.create_ansible_module()
|
||||
backend = create_backend(module, False)
|
||||
@@ -164,9 +166,9 @@ def main():
|
||||
account = ACMEAccount(client)
|
||||
# Load certificate
|
||||
certificate = pem_to_der(module.params.get("certificate"))
|
||||
certificate = nopad_b64(certificate)
|
||||
certificate_b64 = nopad_b64(certificate)
|
||||
# Construct payload
|
||||
payload = {"certificate": certificate}
|
||||
payload = {"certificate": certificate_b64}
|
||||
if module.params.get("revoke_reason") is not None:
|
||||
payload["reason"] = module.params.get("revoke_reason")
|
||||
endpoint = client.directory["revokeCert"]
|
||||
|
||||
@@ -149,6 +149,7 @@ regular_certificate:
|
||||
import base64
|
||||
import datetime
|
||||
import ipaddress
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
@@ -173,10 +174,13 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
try:
|
||||
import cryptography
|
||||
import cryptography.hazmat.backends
|
||||
import cryptography.hazmat.primitives.asymmetric.dh
|
||||
import cryptography.hazmat.primitives.asymmetric.ec
|
||||
import cryptography.hazmat.primitives.asymmetric.padding
|
||||
import cryptography.hazmat.primitives.asymmetric.rsa
|
||||
import cryptography.hazmat.primitives.asymmetric.utils
|
||||
import cryptography.hazmat.primitives.asymmetric.x448
|
||||
import cryptography.hazmat.primitives.asymmetric.x25519
|
||||
import cryptography.hazmat.primitives.hashes
|
||||
import cryptography.hazmat.primitives.serialization
|
||||
import cryptography.x509
|
||||
@@ -186,7 +190,7 @@ except ImportError:
|
||||
|
||||
|
||||
# Convert byte string to ASN1 encoded octet string
|
||||
def encode_octet_string(octet_string):
|
||||
def encode_octet_string(octet_string: bytes) -> bytes:
|
||||
if len(octet_string) >= 128:
|
||||
raise ModuleFailException(
|
||||
"Cannot handle octet strings with more than 128 bytes"
|
||||
@@ -194,7 +198,7 @@ def encode_octet_string(octet_string):
|
||||
return bytes([0x4, len(octet_string)]) + octet_string
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
challenge=dict(type="str", required=True, choices=["tls-alpn-01"]),
|
||||
@@ -213,16 +217,16 @@ def main():
|
||||
|
||||
try:
|
||||
# Get parameters
|
||||
challenge = module.params["challenge"]
|
||||
challenge_data = module.params["challenge_data"]
|
||||
challenge: t.Literal["tls-alpn-01"] = module.params["challenge"]
|
||||
challenge_data: dict[str, t.Any] = module.params["challenge_data"]
|
||||
|
||||
# Get hold of private key
|
||||
private_key_content = module.params.get("private_key_content")
|
||||
private_key_passphrase = module.params.get("private_key_passphrase")
|
||||
if private_key_content is None:
|
||||
private_key_content_str: str | None = module.params["private_key_content"]
|
||||
private_key_passphrase: str | None = module.params["private_key_passphrase"]
|
||||
if private_key_content_str is None:
|
||||
private_key_content = read_file(module.params["private_key_src"])
|
||||
else:
|
||||
private_key_content = to_bytes(private_key_content)
|
||||
private_key_content = to_bytes(private_key_content_str)
|
||||
try:
|
||||
private_key = (
|
||||
cryptography.hazmat.primitives.serialization.load_pem_private_key(
|
||||
@@ -236,6 +240,17 @@ def main():
|
||||
)
|
||||
except Exception as e:
|
||||
raise ModuleFailException(f"Error while loading private key: {e}")
|
||||
if isinstance(
|
||||
private_key,
|
||||
(
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
|
||||
),
|
||||
):
|
||||
raise ModuleFailException(
|
||||
f"Cannot use private key type {type(private_key)}"
|
||||
)
|
||||
|
||||
# Some common attributes
|
||||
domain = to_text(challenge_data["resource"])
|
||||
@@ -246,6 +261,7 @@ def main():
|
||||
now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
|
||||
not_valid_before = now
|
||||
not_valid_after = now + datetime.timedelta(days=10)
|
||||
san: cryptography.x509.GeneralName
|
||||
if identifier_type == "dns":
|
||||
san = cryptography.x509.DNSName(identifier)
|
||||
elif identifier_type == "ip":
|
||||
|
||||
@@ -223,6 +223,8 @@ output_json:
|
||||
- '...'
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
ACMEClient,
|
||||
@@ -235,7 +237,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(require_account_key=False)
|
||||
argument_spec.update_argspec(
|
||||
url=dict(type="str"),
|
||||
@@ -246,17 +248,17 @@ def main():
|
||||
fail_on_acme_error=dict(type="bool", default=True),
|
||||
)
|
||||
argument_spec.update(
|
||||
required_if=(
|
||||
["method", "get", ["url"]],
|
||||
["method", "post", ["url", "content"]],
|
||||
["method", "get", ["account_key_src", "account_key_content"], True],
|
||||
["method", "post", ["account_key_src", "account_key_content"], True],
|
||||
),
|
||||
required_if=[
|
||||
("method", "get", ["url"]),
|
||||
("method", "post", ["url", "content"]),
|
||||
("method", "get", ["account_key_src", "account_key_content"], True),
|
||||
("method", "post", ["account_key_src", "account_key_content"], True),
|
||||
],
|
||||
)
|
||||
module = argument_spec.create_ansible_module()
|
||||
backend = create_backend(module, False)
|
||||
|
||||
result = dict()
|
||||
result: dict[str, t.Any] = {}
|
||||
changed = False
|
||||
try:
|
||||
# Get hold of ACMEClient and ACMEAccount objects (includes directory)
|
||||
|
||||
@@ -121,6 +121,7 @@ complete_chain:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -153,14 +154,18 @@ class Certificate:
|
||||
Stores PEM with parsed certificate.
|
||||
"""
|
||||
|
||||
def __init__(self, pem, cert):
|
||||
def __init__(self, pem: str, cert: cryptography.x509.Certificate) -> None:
|
||||
if not (pem.endswith("\n") or pem.endswith("\r")):
|
||||
pem = pem + "\n"
|
||||
self.pem = pem
|
||||
self.cert = cert
|
||||
|
||||
|
||||
def is_parent(module, cert, potential_parent):
|
||||
def is_parent(
|
||||
module: AnsibleModule,
|
||||
cert: Certificate,
|
||||
potential_parent: Certificate,
|
||||
) -> bool:
|
||||
"""
|
||||
Tests whether the given certificate has been issued by the potential parent certificate.
|
||||
"""
|
||||
@@ -173,6 +178,10 @@ def is_parent(module, cert, potential_parent):
|
||||
if isinstance(
|
||||
public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey
|
||||
):
|
||||
if cert.cert.signature_hash_algorithm is None:
|
||||
raise AssertionError(
|
||||
"signature_hash_algorithm should be present for RSA certificates"
|
||||
)
|
||||
public_key.verify(
|
||||
cert.cert.signature,
|
||||
cert.cert.tbs_certificate_bytes,
|
||||
@@ -183,6 +192,10 @@ def is_parent(module, cert, potential_parent):
|
||||
public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
|
||||
):
|
||||
if cert.cert.signature_hash_algorithm is None:
|
||||
raise AssertionError(
|
||||
"signature_hash_algorithm should be present for EC certificates"
|
||||
)
|
||||
public_key.verify(
|
||||
cert.cert.signature,
|
||||
cert.cert.tbs_certificate_bytes,
|
||||
@@ -213,11 +226,16 @@ def is_parent(module, cert, potential_parent):
|
||||
module.fail_json(msg=f"Unknown error on signature validation: {e}")
|
||||
|
||||
|
||||
def parse_PEM_list(module, text, source, fail_on_error=True):
|
||||
def parse_PEM_list(
|
||||
module: AnsibleModule,
|
||||
text: str,
|
||||
source: str | os.PathLike,
|
||||
fail_on_error: bool = True,
|
||||
) -> list[Certificate]:
|
||||
"""
|
||||
Parse concatenated PEM certificates. Return list of ``Certificate`` objects.
|
||||
"""
|
||||
result = []
|
||||
result: list[Certificate] = []
|
||||
for cert_pem in split_pem_list(text):
|
||||
# Try to load PEM certificate
|
||||
try:
|
||||
@@ -232,7 +250,9 @@ def parse_PEM_list(module, text, source, fail_on_error=True):
|
||||
return result
|
||||
|
||||
|
||||
def load_PEM_list(module, path, fail_on_error=True):
|
||||
def load_PEM_list(
|
||||
module: AnsibleModule, path: str | os.PathLike, fail_on_error: bool = True
|
||||
) -> list[Certificate]:
|
||||
"""
|
||||
Load concatenated PEM certificates from file. Return list of ``Certificate`` objects.
|
||||
"""
|
||||
@@ -258,13 +278,15 @@ class CertificateSet:
|
||||
Stores a set of certificates. Allows to search for parent (issuer of a certificate).
|
||||
"""
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.certificates = set()
|
||||
self.certificates_by_issuer = dict()
|
||||
self.certificate_by_cert = dict()
|
||||
self.certificates: set[Certificate] = set()
|
||||
self.certificates_by_issuer: dict[cryptography.x509.Name, list[Certificate]] = (
|
||||
{}
|
||||
)
|
||||
self.certificate_by_cert: dict[cryptography.x509.Certificate, Certificate] = {}
|
||||
|
||||
def _load_file(self, path):
|
||||
def _load_file(self, path: str | os.PathLike) -> None:
|
||||
certs = load_PEM_list(self.module, path, fail_on_error=False)
|
||||
for cert in certs:
|
||||
self.certificates.add(cert)
|
||||
@@ -273,7 +295,7 @@ class CertificateSet:
|
||||
self.certificates_by_issuer[cert.cert.subject].append(cert)
|
||||
self.certificate_by_cert[cert.cert] = cert
|
||||
|
||||
def load(self, path):
|
||||
def load(self, path: str | os.PathLike) -> None:
|
||||
"""
|
||||
Load lists of PEM certificates from a file or a directory.
|
||||
"""
|
||||
@@ -285,7 +307,7 @@ class CertificateSet:
|
||||
else:
|
||||
self._load_file(b_path)
|
||||
|
||||
def find_parent(self, cert):
|
||||
def find_parent(self, cert: Certificate) -> Certificate | None:
|
||||
"""
|
||||
Search for the parent (issuer) of a certificate. Return ``None`` if none was found.
|
||||
"""
|
||||
@@ -296,14 +318,18 @@ class CertificateSet:
|
||||
return None
|
||||
|
||||
|
||||
def format_cert(cert):
|
||||
def format_cert(cert: Certificate) -> str:
|
||||
"""
|
||||
Return human readable representation of certificate for error messages.
|
||||
"""
|
||||
return str(cert.cert)
|
||||
|
||||
|
||||
def check_cycle(module, occured_certificates, next):
|
||||
def check_cycle(
|
||||
module: AnsibleModule,
|
||||
occured_certificates: set[cryptography.x509.Certificate],
|
||||
next: Certificate,
|
||||
) -> None:
|
||||
"""
|
||||
Make sure that next is not in occured_certificates so far, and add it.
|
||||
"""
|
||||
@@ -313,7 +339,7 @@ def check_cycle(module, occured_certificates, next):
|
||||
occured_certificates.add(next_cert)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
input_chain=dict(type="str", required=True),
|
||||
@@ -354,10 +380,10 @@ def main():
|
||||
roots.load(path)
|
||||
|
||||
# Try to complete chain
|
||||
current = chain[-1]
|
||||
current: Certificate | None = chain[-1]
|
||||
completed = []
|
||||
occured_certificates = set([cert.cert for cert in chain])
|
||||
if current.cert in roots.certificate_by_cert:
|
||||
if current and current.cert in roots.certificate_by_cert:
|
||||
# Do not try to complete the chain when it is already ending with a root certificate
|
||||
current = None
|
||||
while current:
|
||||
|
||||
@@ -152,10 +152,13 @@ openssl:
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
CRYPTOGRAPHY_VERSION: str | None
|
||||
CRYPTOGRAPHY_IMP_ERR: str | None
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography.exceptions import UnsupportedAlgorithm
|
||||
@@ -165,10 +168,10 @@ try:
|
||||
# only got added in 0.2, so let's guard the import
|
||||
from cryptography.exceptions import InternalError as CryptographyInternalError
|
||||
except ImportError:
|
||||
CryptographyInternalError = Exception
|
||||
CryptographyInternalError = Exception # type: ignore
|
||||
except ImportError:
|
||||
UnsupportedAlgorithm = Exception
|
||||
CryptographyInternalError = Exception
|
||||
UnsupportedAlgorithm = Exception # type: ignore
|
||||
CryptographyInternalError = Exception # type: ignore
|
||||
HAS_CRYPTOGRAPHY = False
|
||||
CRYPTOGRAPHY_VERSION = None
|
||||
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
|
||||
@@ -201,8 +204,8 @@ CURVES = (
|
||||
)
|
||||
|
||||
|
||||
def add_crypto_information(module):
|
||||
result = {}
|
||||
def add_crypto_information(module: AnsibleModule) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY
|
||||
if not HAS_CRYPTOGRAPHY:
|
||||
result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR
|
||||
@@ -397,9 +400,9 @@ def add_crypto_information(module):
|
||||
return result
|
||||
|
||||
|
||||
def add_openssl_information(module):
|
||||
def add_openssl_information(module: AnsibleModule) -> dict[str, t.Any]:
|
||||
openssl_binary = module.get_bin_path("openssl")
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"openssl_present": openssl_binary is not None,
|
||||
}
|
||||
if openssl_binary is None:
|
||||
@@ -426,9 +429,9 @@ INFO_FUNCTIONS = (
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(argument_spec={}, supports_check_mode=True)
|
||||
result = {}
|
||||
result: dict[str, t.Any] = {}
|
||||
for fn in INFO_FUNCTIONS:
|
||||
result.update(fn(module))
|
||||
module.exit_json(**result)
|
||||
|
||||
@@ -550,6 +550,7 @@ import datetime
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -572,7 +573,7 @@ from ansible_collections.community.crypto.plugins.module_utils.io import write_f
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
|
||||
def validate_cert_expiry(cert_expiry):
|
||||
def validate_cert_expiry(cert_expiry: str) -> bool:
|
||||
search_string_partial = re.compile(
|
||||
r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z"
|
||||
)
|
||||
@@ -587,7 +588,7 @@ def validate_cert_expiry(cert_expiry):
|
||||
return False
|
||||
|
||||
|
||||
def calculate_cert_days(expires_after):
|
||||
def calculate_cert_days(expires_after: str | None) -> int:
|
||||
cert_days = 0
|
||||
if expires_after:
|
||||
expires_after_datetime = datetime.datetime.strptime(
|
||||
@@ -600,7 +601,9 @@ def calculate_cert_days(expires_after):
|
||||
# Populate the value of body[dict_param_name] with the JSON equivalent of
|
||||
# module parameter of param_name if that parameter is present, otherwise leave field
|
||||
# out of resulting dict
|
||||
def convert_module_param_to_json_bool(module, dict_param_name, param_name):
|
||||
def convert_module_param_to_json_bool(
|
||||
module: AnsibleModule, dict_param_name: str, param_name: str
|
||||
) -> dict[str, str]:
|
||||
body = {}
|
||||
if module.params[param_name] is not None:
|
||||
if module.params[param_name]:
|
||||
@@ -886,7 +889,7 @@ class EcsCertificate:
|
||||
return result
|
||||
|
||||
|
||||
def custom_fields_spec():
|
||||
def custom_fields_spec() -> dict[str, dict[str, str]]:
|
||||
return dict(
|
||||
text1=dict(type="str"),
|
||||
text2=dict(type="str"),
|
||||
@@ -926,7 +929,7 @@ def custom_fields_spec():
|
||||
)
|
||||
|
||||
|
||||
def ecs_certificate_argument_spec():
|
||||
def ecs_certificate_argument_spec() -> dict[str, dict[str, t.Any]]:
|
||||
return dict(
|
||||
backup=dict(type="bool", default=False),
|
||||
force=dict(type="bool", default=False),
|
||||
@@ -979,7 +982,7 @@ def ecs_certificate_argument_spec():
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
ecs_argument_spec = ecs_client_argument_spec()
|
||||
ecs_argument_spec.update(ecs_certificate_argument_spec())
|
||||
module = AnsibleModule(
|
||||
|
||||
@@ -218,6 +218,7 @@ ev_days_remaining:
|
||||
|
||||
import datetime
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
|
||||
@@ -228,7 +229,7 @@ from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
|
||||
)
|
||||
|
||||
|
||||
def calculate_days_remaining(expiry_date):
|
||||
def calculate_days_remaining(expiry_date: str | None) -> int | None:
|
||||
days_remaining = None
|
||||
if expiry_date:
|
||||
expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ")
|
||||
@@ -403,8 +404,8 @@ class EcsDomain:
|
||||
msg=f"Failed to request domain validation from Entrust (ECS) {e.message}"
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
result = {
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": self.changed,
|
||||
"client_id": self.client_id,
|
||||
"domain_status": self.domain_status,
|
||||
@@ -436,7 +437,7 @@ class EcsDomain:
|
||||
return result
|
||||
|
||||
|
||||
def ecs_domain_argument_spec():
|
||||
def ecs_domain_argument_spec() -> dict[str, dict[str, t.Any]]:
|
||||
return dict(
|
||||
client_id=dict(type="int", default=1),
|
||||
domain_name=dict(type="str", required=True),
|
||||
@@ -447,7 +448,7 @@ def ecs_domain_argument_spec():
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
ecs_argument_spec = ecs_client_argument_spec()
|
||||
ecs_argument_spec.update(ecs_domain_argument_spec())
|
||||
module = AnsibleModule(
|
||||
|
||||
@@ -268,6 +268,7 @@ import atexit
|
||||
import base64
|
||||
import ssl
|
||||
import sys
|
||||
import typing as t
|
||||
from os.path import isfile
|
||||
from socket import create_connection, setdefaulttimeout, socket
|
||||
from ssl import (
|
||||
@@ -305,7 +306,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def send_starttls_packet(sock, server_type):
|
||||
def send_starttls_packet(sock: socket, server_type: t.Literal["mysql"]) -> None:
|
||||
if server_type == "mysql":
|
||||
ssl_request_packet = (
|
||||
b"\x20\x00\x00\x01\x85\xae\x7f\x00"
|
||||
@@ -321,7 +322,7 @@ def send_starttls_packet(sock, server_type):
|
||||
sock.send(ssl_request_packet)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
ca_cert=dict(type="path"),
|
||||
@@ -342,18 +343,18 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
ca_cert = module.params.get("ca_cert")
|
||||
host = module.params.get("host")
|
||||
port = module.params.get("port")
|
||||
proxy_host = module.params.get("proxy_host")
|
||||
proxy_port = module.params.get("proxy_port")
|
||||
timeout = module.params.get("timeout")
|
||||
server_name = module.params.get("server_name")
|
||||
start_tls_server_type = module.params.get("starttls")
|
||||
ciphers = module.params.get("ciphers")
|
||||
asn1_base64 = module.params["asn1_base64"]
|
||||
tls_ctx_options = module.params["tls_ctx_options"]
|
||||
get_certificate_chain = module.params["get_certificate_chain"]
|
||||
ca_cert: str | None = module.params.get("ca_cert")
|
||||
host: str = module.params.get("host")
|
||||
port: int = module.params.get("port")
|
||||
proxy_host: str | None = module.params.get("proxy_host")
|
||||
proxy_port: int | None = module.params.get("proxy_port")
|
||||
timeout: int = module.params.get("timeout")
|
||||
server_name: str | None = module.params.get("server_name")
|
||||
start_tls_server_type: t.Literal["mysql"] | None = module.params.get("starttls")
|
||||
ciphers: list[str] | None = module.params.get("ciphers")
|
||||
asn1_base64: bool = module.params["asn1_base64"]
|
||||
tls_ctx_options: list[str | bytes | int] | None = module.params["tls_ctx_options"]
|
||||
get_certificate_chain: bool = module.params["get_certificate_chain"]
|
||||
|
||||
if get_certificate_chain and sys.version_info < (3, 10):
|
||||
module.fail_json(
|
||||
@@ -365,9 +366,9 @@ def main():
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
result = dict(
|
||||
changed=False,
|
||||
)
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": False,
|
||||
}
|
||||
|
||||
if timeout:
|
||||
setdefaulttimeout(timeout)
|
||||
@@ -409,7 +410,7 @@ def main():
|
||||
|
||||
if tls_ctx_options is not None:
|
||||
# Clear default ctx options
|
||||
ctx.options = 0
|
||||
ctx.options = 0 # type: ignore
|
||||
|
||||
# For each item in the tls_ctx_options list
|
||||
for tls_ctx_option in tls_ctx_options:
|
||||
@@ -450,8 +451,10 @@ def main():
|
||||
)
|
||||
|
||||
tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host)
|
||||
cert = tls_sock.getpeercert(True)
|
||||
cert = DER_cert_to_PEM_cert(cert)
|
||||
cert_der = tls_sock.getpeercert(True)
|
||||
if cert_der is None:
|
||||
raise Exception("Unexpected error: no peer certificate has been returned")
|
||||
cert: str = DER_cert_to_PEM_cert(cert_der)
|
||||
|
||||
if get_certificate_chain:
|
||||
if sys.version_info < (3, 13):
|
||||
@@ -474,7 +477,7 @@ def main():
|
||||
# Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to
|
||||
# be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves
|
||||
# if they are not byte strings to work around this.
|
||||
def _convert_chain(chain):
|
||||
def _convert_chain(chain: list[bytes]) -> list[bytes]:
|
||||
return [
|
||||
(
|
||||
c
|
||||
@@ -514,13 +517,13 @@ def main():
|
||||
result["extensions"] = []
|
||||
for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items():
|
||||
oid = cryptography.x509.oid.ObjectIdentifier(dotted_number)
|
||||
ext = {
|
||||
ext: dict[str, t.Any] = {
|
||||
"critical": entry["critical"],
|
||||
"asn1_data": entry["value"],
|
||||
"name": cryptography_oid_to_name(oid, short=True),
|
||||
}
|
||||
if not asn1_base64:
|
||||
ext["asn1_data"] = base64.b64decode(ext["asn1_data"])
|
||||
ext["asn1_data"] = base64.b64decode(entry["value"]) # type: ignore
|
||||
result["extensions"].append(ext)
|
||||
|
||||
result["issuer"] = {}
|
||||
|
||||
@@ -420,16 +420,13 @@ name:
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
import typing as t
|
||||
from base64 import b64decode
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
|
||||
|
||||
RETURN_CODE = 0
|
||||
STDOUT = 1
|
||||
STDERR = 2
|
||||
|
||||
# used to get <luks-name> out of lsblk output in format 'crypt <luks-name>'
|
||||
# regex takes care of any possible blank characters
|
||||
LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$")
|
||||
@@ -456,7 +453,7 @@ LUKS2_HEADER_OFFSETS = [
|
||||
LUKS2_HEADER2 = b"SKUL\xba\xbe"
|
||||
|
||||
|
||||
def wipe_luks_headers(device):
|
||||
def wipe_luks_headers(device: str) -> None:
|
||||
wipe_offsets = []
|
||||
with open(device, "rb") as f:
|
||||
# f.seek(0)
|
||||
@@ -478,12 +475,12 @@ def wipe_luks_headers(device):
|
||||
|
||||
class Handler:
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self._module = module
|
||||
self._lsblk_bin = self._module.get_bin_path("lsblk", True)
|
||||
self._passphrase_encoding = module.params["passphrase_encoding"]
|
||||
|
||||
def get_passphrase_from_module_params(self, parameter_name):
|
||||
def get_passphrase_from_module_params(self, parameter_name: str) -> bytes | None:
|
||||
passphrase = self._module.params[parameter_name]
|
||||
if passphrase is None:
|
||||
return None
|
||||
@@ -496,88 +493,91 @@ class Handler:
|
||||
f"Error while base64-decoding '{parameter_name}': {exc}"
|
||||
)
|
||||
|
||||
def _run_command(self, command, data=None):
|
||||
def _run_command(
|
||||
self, command: list[str], data: bytes | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
return self._module.run_command(command, data=data, binary_data=True)
|
||||
|
||||
def get_device_by_uuid(self, uuid):
|
||||
def get_device_by_uuid(self, uuid: str | None) -> str | None:
|
||||
"""Returns the device that holds UUID passed by user"""
|
||||
self._blkid_bin = self._module.get_bin_path("blkid", True)
|
||||
uuid = self._module.params["uuid"]
|
||||
if uuid is None:
|
||||
return None
|
||||
result = self._run_command([self._blkid_bin, "--uuid", uuid])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command([self._blkid_bin, "--uuid", uuid])
|
||||
if rc != 0:
|
||||
return None
|
||||
return result[STDOUT].strip()
|
||||
return stdout.strip()
|
||||
|
||||
def get_device_by_label(self, label):
|
||||
def get_device_by_label(self, label: str) -> str | None:
|
||||
"""Returns the device that holds label passed by user"""
|
||||
self._blkid_bin = self._module.get_bin_path("blkid", True)
|
||||
label = self._module.params["label"]
|
||||
if label is None:
|
||||
return None
|
||||
result = self._run_command([self._blkid_bin, "--label", label])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command([self._blkid_bin, "--label", label])
|
||||
if rc != 0:
|
||||
return None
|
||||
return result[STDOUT].strip()
|
||||
return stdout.strip()
|
||||
|
||||
def generate_luks_name(self, device):
|
||||
def generate_luks_name(self, device: str) -> str:
|
||||
"""Generate name for luks based on device UUID ('luks-<UUID>').
|
||||
Raises ValueError when obtaining of UUID fails.
|
||||
"""
|
||||
result = self._run_command([self._lsblk_bin, "-n", device, "-o", "UUID"])
|
||||
rc, stdout, stderr = self._run_command(
|
||||
[self._lsblk_bin, "-n", device, "-o", "UUID"]
|
||||
)
|
||||
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(
|
||||
f"Error while generating LUKS name for {device}: {result[STDERR]}"
|
||||
)
|
||||
dev_uuid = result[STDOUT].strip()
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while generating LUKS name for {device}: {stderr}")
|
||||
dev_uuid = stdout.strip()
|
||||
return f"luks-{dev_uuid}"
|
||||
|
||||
|
||||
class CryptHandler(Handler):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CryptHandler, self).__init__(module)
|
||||
self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True)
|
||||
|
||||
def get_container_name_by_device(self, device):
|
||||
def get_container_name_by_device(self, device: str) -> str | None:
|
||||
"""obtain LUKS container name based on the device where it is located
|
||||
return None if not found
|
||||
raise ValueError if lsblk command fails
|
||||
"""
|
||||
result = self._run_command([self._lsblk_bin, device, "-nlo", "type,name"])
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(
|
||||
f"Error while obtaining LUKS name for {device}: {result[STDERR]}"
|
||||
)
|
||||
rc, stdout, stderr = self._run_command(
|
||||
[self._lsblk_bin, device, "-nlo", "type,name"]
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while obtaining LUKS name for {device}: {stderr}")
|
||||
|
||||
for line in result[STDOUT].splitlines(False):
|
||||
for line in stdout.splitlines(False):
|
||||
m = LUKS_NAME_REGEX.match(line)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return None
|
||||
|
||||
def get_container_device_by_name(self, name):
|
||||
def get_container_device_by_name(self, name: str) -> str | None:
|
||||
"""obtain device name based on the LUKS container name
|
||||
return None if not found
|
||||
raise ValueError if lsblk command fails
|
||||
"""
|
||||
# apparently each device can have only one LUKS container on it
|
||||
result = self._run_command([self._cryptsetup_bin, "status", name])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command([self._cryptsetup_bin, "status", name])
|
||||
if rc != 0:
|
||||
return None
|
||||
|
||||
m = LUKS_DEVICE_REGEX.search(result[STDOUT])
|
||||
m = LUKS_DEVICE_REGEX.search(stdout)
|
||||
if not m:
|
||||
return None
|
||||
device = m.group(1)
|
||||
return device
|
||||
|
||||
def is_luks(self, device):
|
||||
def is_luks(self, device: str) -> bool:
|
||||
"""check if the LUKS container does exist"""
|
||||
result = self._run_command([self._cryptsetup_bin, "isLuks", device])
|
||||
return result[RETURN_CODE] == 0
|
||||
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "isLuks", device])
|
||||
return rc == 0
|
||||
|
||||
def get_luks_type(self, device):
|
||||
def get_luks_type(self, device: str) -> t.Literal["luks1", "luks2"] | None:
|
||||
"""get the luks type of a device"""
|
||||
if self.is_luks(device):
|
||||
with open(device, "rb") as f:
|
||||
@@ -589,16 +589,18 @@ class CryptHandler(Handler):
|
||||
return "luks1"
|
||||
return None
|
||||
|
||||
def is_luks_slot_set(self, device, keyslot):
|
||||
def is_luks_slot_set(self, device: str, keyslot: int) -> bool:
|
||||
"""check if a keyslot is set"""
|
||||
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command(
|
||||
[self._cryptsetup_bin, "luksDump", device]
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while dumping LUKS header from {device}")
|
||||
result_luks1 = f"Key Slot {keyslot}: ENABLED" in result[STDOUT]
|
||||
result_luks2 = f" {keyslot}: luks2" in result[STDOUT]
|
||||
result_luks1 = f"Key Slot {keyslot}: ENABLED" in stdout
|
||||
result_luks2 = f" {keyslot}: luks2" in stdout
|
||||
return result_luks1 or result_luks2
|
||||
|
||||
def _add_pbkdf_options(self, options, pbkdf):
|
||||
def _add_pbkdf_options(self, options: list[str], pbkdf: dict[str, t.Any]) -> None:
|
||||
if pbkdf["iteration_time"] is not None:
|
||||
options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))])
|
||||
if pbkdf["iteration_count"] is not None:
|
||||
@@ -612,16 +614,16 @@ class CryptHandler(Handler):
|
||||
|
||||
def run_luks_create(
|
||||
self,
|
||||
device,
|
||||
keyfile,
|
||||
passphrase,
|
||||
keyslot,
|
||||
keysize,
|
||||
cipher,
|
||||
hash_,
|
||||
sector_size,
|
||||
pbkdf,
|
||||
):
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
keyslot: int | None,
|
||||
keysize: int | None,
|
||||
cipher: str | None,
|
||||
hash_: str | None,
|
||||
sector_size: str | None,
|
||||
pbkdf: dict[str, t.Any] | None,
|
||||
) -> None:
|
||||
# create a new luks container; use batch mode to auto confirm
|
||||
luks_type = self._module.params["type"]
|
||||
label = self._module.params["label"]
|
||||
@@ -653,23 +655,23 @@ class CryptHandler(Handler):
|
||||
else:
|
||||
args.append("-")
|
||||
|
||||
result = self._run_command(args, data=passphrase)
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(f"Error while creating LUKS on {device}: {result[STDERR]}")
|
||||
rc, dummy, stderr = self._run_command(args, data=passphrase)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while creating LUKS on {device}: {stderr}")
|
||||
|
||||
def run_luks_open(
|
||||
self,
|
||||
device,
|
||||
keyfile,
|
||||
passphrase,
|
||||
perf_same_cpu_crypt,
|
||||
perf_submit_from_crypt_cpus,
|
||||
perf_no_read_workqueue,
|
||||
perf_no_write_workqueue,
|
||||
persistent,
|
||||
allow_discards,
|
||||
name,
|
||||
):
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
perf_same_cpu_crypt: bool,
|
||||
perf_submit_from_crypt_cpus: bool,
|
||||
perf_no_read_workqueue: bool,
|
||||
perf_no_write_workqueue: bool,
|
||||
persistent: bool,
|
||||
allow_discards: bool,
|
||||
name: str,
|
||||
) -> None:
|
||||
args = [self._cryptsetup_bin]
|
||||
if keyfile:
|
||||
args.extend(["--key-file", keyfile])
|
||||
@@ -689,27 +691,27 @@ class CryptHandler(Handler):
|
||||
args.extend(["--allow-discards"])
|
||||
args.extend(["open", "--type", "luks", device, name])
|
||||
|
||||
result = self._run_command(args, data=passphrase)
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, dummy, stderr = self._run_command(args, data=passphrase)
|
||||
if rc != 0:
|
||||
raise ValueError(
|
||||
f"Error while opening LUKS container on {device}: {result[STDERR]}"
|
||||
f"Error while opening LUKS container on {device}: {stderr}"
|
||||
)
|
||||
|
||||
def run_luks_close(self, name):
|
||||
result = self._run_command([self._cryptsetup_bin, "close", name])
|
||||
if result[RETURN_CODE] != 0:
|
||||
def run_luks_close(self, name: str) -> None:
|
||||
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "close", name])
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while closing LUKS container {name}")
|
||||
|
||||
def run_luks_remove(self, device):
|
||||
def run_luks_remove(self, device: str) -> None:
|
||||
wipefs_bin = self._module.get_bin_path("wipefs", True)
|
||||
|
||||
name = self.get_container_name_by_device(device)
|
||||
if name is not None:
|
||||
self.run_luks_close(name)
|
||||
result = self._run_command([wipefs_bin, "--all", device])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, dummy, stderr = self._run_command([wipefs_bin, "--all", device])
|
||||
if rc != 0:
|
||||
raise ValueError(
|
||||
f"Error while wiping LUKS container signatures for {device}: {result[STDERR]}"
|
||||
f"Error while wiping LUKS container signatures for {device}: {stderr}"
|
||||
)
|
||||
|
||||
# For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not**
|
||||
@@ -724,14 +726,14 @@ class CryptHandler(Handler):
|
||||
|
||||
def run_luks_add_key(
|
||||
self,
|
||||
device,
|
||||
keyfile,
|
||||
passphrase,
|
||||
new_keyfile,
|
||||
new_passphrase,
|
||||
new_keyslot,
|
||||
pbkdf,
|
||||
):
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
new_keyfile: str | None,
|
||||
new_passphrase: bytes | None,
|
||||
new_keyslot: int | None,
|
||||
pbkdf: dict[str, t.Any] | None,
|
||||
) -> None:
|
||||
"""Add new key from a keyfile or passphrase to given 'device';
|
||||
authentication done using 'keyfile' or 'passphrase'.
|
||||
Raises ValueError when command fails.
|
||||
@@ -746,36 +748,47 @@ class CryptHandler(Handler):
|
||||
|
||||
if keyfile:
|
||||
args.extend(["--key-file", keyfile])
|
||||
else:
|
||||
elif passphrase is not None:
|
||||
args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))])
|
||||
data.append(passphrase)
|
||||
else:
|
||||
raise ValueError("Need passphrase or keyfile")
|
||||
|
||||
if new_keyfile:
|
||||
args.append(new_keyfile)
|
||||
else:
|
||||
elif new_passphrase is not None:
|
||||
args.append("-")
|
||||
data.append(new_passphrase)
|
||||
else:
|
||||
raise ValueError("Need new passphrase or new keyfile")
|
||||
|
||||
result = self._run_command(args, data=b"".join(data) or None)
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, dummy, stderr = self._run_command(args, data=b"".join(data) or None)
|
||||
if rc != 0:
|
||||
raise ValueError(
|
||||
f"Error while adding new LUKS keyslot to {device}: {result[STDERR]}"
|
||||
f"Error while adding new LUKS keyslot to {device}: {stderr}"
|
||||
)
|
||||
|
||||
def run_luks_remove_key(
|
||||
self, device, keyfile, passphrase, keyslot, force_remove_last_key=False
|
||||
):
|
||||
self,
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
keyslot: int | None,
|
||||
force_remove_last_key: bool = False,
|
||||
) -> None:
|
||||
"""Remove key from given device
|
||||
Raises ValueError when command fails
|
||||
"""
|
||||
if not force_remove_last_key:
|
||||
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command(
|
||||
[self._cryptsetup_bin, "luksDump", device]
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while dumping LUKS header from {device}")
|
||||
keyslot_count = 0
|
||||
keyslot_area = False
|
||||
keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED")
|
||||
for line in result[STDOUT].splitlines():
|
||||
for line in stdout.splitlines():
|
||||
if line.startswith("Keyslots:"):
|
||||
keyslot_area = True
|
||||
elif line.startswith(" "):
|
||||
@@ -808,13 +821,17 @@ class CryptHandler(Handler):
|
||||
# Since we supply -q no passphrase is needed
|
||||
args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)]
|
||||
passphrase = None
|
||||
result = self._run_command(args, data=passphrase)
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(
|
||||
f"Error while removing LUKS key from {device}: {result[STDERR]}"
|
||||
)
|
||||
rc, dummy, stderr = self._run_command(args, data=passphrase)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while removing LUKS key from {device}: {stderr}")
|
||||
|
||||
def luks_test_key(self, device, keyfile, passphrase, keyslot=None):
|
||||
def luks_test_key(
|
||||
self,
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
keyslot: int | None = None,
|
||||
) -> bool:
|
||||
"""Check whether the keyfile or passphrase works.
|
||||
Raises ValueError when command fails.
|
||||
"""
|
||||
@@ -830,42 +847,37 @@ class CryptHandler(Handler):
|
||||
if keyslot is not None:
|
||||
args.extend(["--key-slot", str(keyslot)])
|
||||
|
||||
result = self._run_command(args, data=data)
|
||||
if result[RETURN_CODE] == 0:
|
||||
rc, stdout, stderr = self._run_command(args, data=data)
|
||||
if rc == 0:
|
||||
return True
|
||||
for output in (STDOUT, STDERR):
|
||||
if "No key available with this passphrase" in result[output]:
|
||||
for output in (stdout, stderr):
|
||||
if "No key available with this passphrase" in output:
|
||||
return False
|
||||
if "No usable keyslot is available." in result[output]:
|
||||
if "No usable keyslot is available." in output:
|
||||
return False
|
||||
|
||||
# This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available'
|
||||
# when using the --key-slot parameter in combination with --test-passphrase
|
||||
if (
|
||||
result[RETURN_CODE] == 1
|
||||
and keyslot is not None
|
||||
and result[STDOUT] == ""
|
||||
and result[STDERR] == ""
|
||||
):
|
||||
if rc == 1 and keyslot is not None and stdout == "" and stderr == "":
|
||||
return False
|
||||
|
||||
raise ValueError(
|
||||
f"Error while testing whether keyslot exists on {device}: {result[STDERR]}"
|
||||
f"Error while testing whether keyslot exists on {device}: {stderr}"
|
||||
)
|
||||
|
||||
|
||||
class ConditionsHandler(Handler):
|
||||
|
||||
def __init__(self, module, crypthandler):
|
||||
def __init__(self, module: AnsibleModule, crypthandler: CryptHandler) -> None:
|
||||
super(ConditionsHandler, self).__init__(module)
|
||||
self._crypthandler = crypthandler
|
||||
self.device = self.get_device_name()
|
||||
|
||||
def get_device_name(self):
|
||||
device = self._module.params.get("device")
|
||||
label = self._module.params.get("label")
|
||||
uuid = self._module.params.get("uuid")
|
||||
name = self._module.params.get("name")
|
||||
def get_device_name(self) -> str | None:
|
||||
device: str | None = self._module.params.get("device")
|
||||
label: str | None = self._module.params.get("label")
|
||||
uuid: str | None = self._module.params.get("uuid")
|
||||
name: str | None = self._module.params.get("name")
|
||||
|
||||
if device is None and label is not None:
|
||||
device = self.get_device_by_label(label)
|
||||
@@ -876,7 +888,7 @@ class ConditionsHandler(Handler):
|
||||
|
||||
return device
|
||||
|
||||
def luks_create(self):
|
||||
def luks_create(self) -> bool:
|
||||
return (
|
||||
self.device is not None
|
||||
and (
|
||||
@@ -887,7 +899,7 @@ class ConditionsHandler(Handler):
|
||||
and not self._crypthandler.is_luks(self.device)
|
||||
)
|
||||
|
||||
def opened_luks_name(self):
|
||||
def opened_luks_name(self, device: str) -> str | None:
|
||||
"""If luks is already opened, return its name.
|
||||
If 'name' parameter is specified and differs
|
||||
from obtained value, fail.
|
||||
@@ -897,7 +909,7 @@ class ConditionsHandler(Handler):
|
||||
return None
|
||||
|
||||
# try to obtain luks name - it may be already opened
|
||||
name = self._crypthandler.get_container_name_by_device(self.device)
|
||||
name = self._crypthandler.get_container_name_by_device(device)
|
||||
|
||||
if name is None:
|
||||
# container is not open
|
||||
@@ -917,7 +929,7 @@ class ConditionsHandler(Handler):
|
||||
# container is opened and the names match
|
||||
return name
|
||||
|
||||
def luks_open(self):
|
||||
def luks_open(self) -> bool:
|
||||
if (
|
||||
(
|
||||
self._module.params["keyfile"] is None
|
||||
@@ -929,13 +941,13 @@ class ConditionsHandler(Handler):
|
||||
# conditions for open not fulfilled
|
||||
return False
|
||||
|
||||
name = self.opened_luks_name()
|
||||
name = self.opened_luks_name(self.device)
|
||||
|
||||
if name is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def luks_close(self):
|
||||
def luks_close(self) -> bool:
|
||||
if (
|
||||
self._module.params["name"] is None and self.device is None
|
||||
) or self._module.params["state"] != "closed":
|
||||
@@ -948,15 +960,17 @@ class ConditionsHandler(Handler):
|
||||
luks_is_open = name is not None
|
||||
|
||||
if self._module.params["name"] is not None:
|
||||
self.device = self._crypthandler.get_container_device_by_name(
|
||||
device = self._crypthandler.get_container_device_by_name(
|
||||
self._module.params["name"]
|
||||
)
|
||||
# successfully getting device based on name means that luks is open
|
||||
luks_is_open = self.device is not None
|
||||
luks_is_open = device is not None
|
||||
if device is not None:
|
||||
self.device = device
|
||||
|
||||
return luks_is_open
|
||||
|
||||
def luks_add_key(self):
|
||||
def luks_add_key(self) -> bool:
|
||||
if (
|
||||
self.device is None
|
||||
or (
|
||||
@@ -995,7 +1009,7 @@ class ConditionsHandler(Handler):
|
||||
|
||||
return not key_present
|
||||
|
||||
def luks_remove_key(self):
|
||||
def luks_remove_key(self) -> bool:
|
||||
if self.device is None or (
|
||||
self._module.params["remove_keyfile"] is None
|
||||
and self._module.params["remove_passphrase"] is None
|
||||
@@ -1037,14 +1051,16 @@ class ConditionsHandler(Handler):
|
||||
self.get_passphrase_from_module_params("remove_passphrase"),
|
||||
)
|
||||
|
||||
def luks_remove(self):
|
||||
def luks_remove(self) -> bool:
|
||||
return (
|
||||
self.device is not None
|
||||
and self._module.params["state"] == "absent"
|
||||
and self._crypthandler.is_luks(self.device)
|
||||
)
|
||||
|
||||
def validate_keyslot(self, param, luks_type):
|
||||
def validate_keyslot(
|
||||
self, param: str, luks_type: t.Literal["luks1", "luks2"] | None
|
||||
) -> None:
|
||||
if self._module.params[param] is not None:
|
||||
if luks_type is None and param == "keyslot":
|
||||
if 8 <= self._module.params[param] <= 31:
|
||||
@@ -1066,7 +1082,7 @@ class ConditionsHandler(Handler):
|
||||
)
|
||||
|
||||
|
||||
def run_module():
|
||||
def run_module() -> t.NoReturn:
|
||||
# available arguments/parameters that a user can pass
|
||||
module_args = dict(
|
||||
state=dict(
|
||||
@@ -1122,7 +1138,7 @@ def run_module():
|
||||
]
|
||||
|
||||
# seed the result dict in the object
|
||||
result = dict(changed=False, name=None)
|
||||
result: dict[str, t.Any] = {"changed": False, "name": None}
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=module_args,
|
||||
@@ -1142,19 +1158,26 @@ def run_module():
|
||||
except Exception as e:
|
||||
module.fail_json(msg=str(e))
|
||||
|
||||
crypt = CryptHandler(module)
|
||||
conditions = ConditionsHandler(module, crypt)
|
||||
|
||||
# conditions not allowed to run
|
||||
if module.params["label"] is not None and module.params["type"] == "luks1":
|
||||
module.fail_json(msg="You cannot combine type luks1 with the label option.")
|
||||
|
||||
crypt = CryptHandler(module)
|
||||
try:
|
||||
conditions = ConditionsHandler(module, crypt)
|
||||
except ValueError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
if (
|
||||
module.params["keyslot"] is not None
|
||||
or module.params["new_keyslot"] is not None
|
||||
or module.params["remove_keyslot"] is not None
|
||||
):
|
||||
luks_type = crypt.get_luks_type(conditions.get_device_name())
|
||||
luks_type = (
|
||||
crypt.get_luks_type(conditions.device)
|
||||
if conditions.device is not None
|
||||
else None
|
||||
)
|
||||
if luks_type is None and module.params["type"] is not None:
|
||||
luks_type = module.params["type"]
|
||||
for param in ["keyslot", "new_keyslot", "remove_keyslot"]:
|
||||
@@ -1175,6 +1198,7 @@ def run_module():
|
||||
|
||||
# luks create
|
||||
if conditions.luks_create():
|
||||
assert conditions.device # ensured in conditions.luks_create()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_create(
|
||||
@@ -1196,11 +1220,13 @@ def run_module():
|
||||
|
||||
# luks open
|
||||
|
||||
name = conditions.opened_luks_name()
|
||||
if name is not None:
|
||||
result["name"] = name
|
||||
if conditions.device is not None:
|
||||
name = conditions.opened_luks_name(conditions.device)
|
||||
if name is not None:
|
||||
result["name"] = name
|
||||
|
||||
if conditions.luks_open():
|
||||
assert conditions.device # ensured in conditions.luks_open()
|
||||
name = module.params["name"]
|
||||
if name is None:
|
||||
try:
|
||||
@@ -1237,6 +1263,8 @@ def run_module():
|
||||
module.fail_json(msg=f"luks_device error: {e}")
|
||||
else:
|
||||
name = module.params["name"]
|
||||
if name is None:
|
||||
module.fail_json(msg="Cannot determine name to close device")
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_close(name)
|
||||
@@ -1249,6 +1277,7 @@ def run_module():
|
||||
|
||||
# luks add key
|
||||
if conditions.luks_add_key():
|
||||
assert conditions.device # ensured in conditions.luks_add_key()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_add_key(
|
||||
@@ -1268,6 +1297,7 @@ def run_module():
|
||||
|
||||
# luks remove key
|
||||
if conditions.luks_remove_key():
|
||||
assert conditions.device # ensured in conditions.luks_remove_key()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
last_key = module.params["force_remove_last_key"]
|
||||
@@ -1286,6 +1316,7 @@ def run_module():
|
||||
|
||||
# luks remove
|
||||
if conditions.luks_remove():
|
||||
assert conditions.device # ensured in conditions.luks_remove()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_remove(conditions.device)
|
||||
@@ -1299,7 +1330,7 @@ def run_module():
|
||||
module.exit_json(**result)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
run_module()
|
||||
|
||||
|
||||
|
||||
@@ -284,6 +284,7 @@ info:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
|
||||
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
|
||||
|
||||
class Certificate(OpensshModule):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(Certificate, self).__init__(module)
|
||||
self.ssh_keygen = KeygenCommand(self.module)
|
||||
|
||||
self.identifier = self.module.params["identifier"] or ""
|
||||
self.options = self.module.params["options"] or []
|
||||
self.path = self.module.params["path"]
|
||||
self.pkcs11_provider = self.module.params["pkcs11_provider"]
|
||||
self.principals = self.module.params["principals"] or []
|
||||
self.public_key = self.module.params["public_key"]
|
||||
self.regenerate = (
|
||||
self.identifier: str = self.module.params["identifier"] or ""
|
||||
self.options: list[str] = self.module.params["options"] or []
|
||||
self.path: str = self.module.params["path"]
|
||||
self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
|
||||
self.principals: list[str] = self.module.params["principals"] or []
|
||||
self.public_key: str | None = self.module.params["public_key"]
|
||||
self.regenerate: t.Literal[
|
||||
"never",
|
||||
"fail",
|
||||
"partial_idempotence",
|
||||
"full_idempotence",
|
||||
"always",
|
||||
] = (
|
||||
self.module.params["regenerate"]
|
||||
if not self.module.params["force"]
|
||||
else "always"
|
||||
)
|
||||
self.serial_number = self.module.params["serial_number"]
|
||||
self.signature_algorithm = self.module.params["signature_algorithm"]
|
||||
self.signing_key = self.module.params["signing_key"]
|
||||
self.state = self.module.params["state"]
|
||||
self.type = self.module.params["type"]
|
||||
self.use_agent = self.module.params["use_agent"]
|
||||
self.valid_at = self.module.params["valid_at"]
|
||||
self.ignore_timestamps = self.module.params["ignore_timestamps"]
|
||||
self.serial_number: int | None = self.module.params["serial_number"]
|
||||
self.signature_algorithm: (
|
||||
t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
|
||||
) = self.module.params["signature_algorithm"]
|
||||
self.signing_key: str | None = self.module.params["signing_key"]
|
||||
self.state: t.Literal["absent", "present"] = self.module.params["state"]
|
||||
self.type: t.Literal["host", "user"] | None = self.module.params["type"]
|
||||
self.use_agent: bool = self.module.params["use_agent"]
|
||||
self.valid_at: str | None = self.module.params["valid_at"]
|
||||
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
|
||||
|
||||
self._check_if_base_dir(self.path)
|
||||
|
||||
if self.state == "present":
|
||||
self._validate_parameters()
|
||||
|
||||
self.data = None
|
||||
self.original_data = None
|
||||
self.data: OpensshCertificate | None = None
|
||||
self.original_data: OpensshCertificate | None = None
|
||||
if self._exists():
|
||||
self._load_certificate()
|
||||
|
||||
self.time_parameters = None
|
||||
self.time_parameters: OpensshCertificateTimeParameters | None = None
|
||||
if self.state == "present":
|
||||
self._set_time_parameters()
|
||||
|
||||
def _validate_parameters(self):
|
||||
def _validate_parameters(self) -> None:
|
||||
for path in (self.public_key, self.signing_key):
|
||||
self._check_if_base_dir(path)
|
||||
if (
|
||||
path is not None
|
||||
): # should never be None, but the type checker doesn't know
|
||||
self._check_if_base_dir(path)
|
||||
|
||||
if self.options and self.type == "host":
|
||||
self.module.fail_json(
|
||||
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
|
||||
if self.use_agent:
|
||||
self._use_agent_available()
|
||||
|
||||
def _use_agent_available(self):
|
||||
def _use_agent_available(self) -> None:
|
||||
ssh_version = self._get_ssh_version()
|
||||
if not ssh_version:
|
||||
self.module.fail_json(msg="Failed to determine ssh version")
|
||||
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
|
||||
+ f" Your version is: {ssh_version}"
|
||||
)
|
||||
|
||||
def _exists(self):
|
||||
def _exists(self) -> bool:
|
||||
return os.path.exists(self.path)
|
||||
|
||||
def _load_certificate(self):
|
||||
def _load_certificate(self) -> None:
|
||||
try:
|
||||
self.original_data = OpensshCertificate.load(self.path)
|
||||
except (TypeError, ValueError) as e:
|
||||
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
|
||||
self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
|
||||
self.module.warn(f"Unable to read existing certificate: {e}")
|
||||
|
||||
def _set_time_parameters(self):
|
||||
def _set_time_parameters(self) -> None:
|
||||
try:
|
||||
self.time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.module.params["valid_from"],
|
||||
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
|
||||
except ValueError as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
if self.state == "present":
|
||||
if self._should_generate():
|
||||
self._generate()
|
||||
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
|
||||
if self._exists():
|
||||
self._remove()
|
||||
|
||||
def _should_generate(self):
|
||||
def _should_generate(self) -> bool:
|
||||
if self.regenerate == "never":
|
||||
return self.original_data is None
|
||||
elif self.regenerate == "fail":
|
||||
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _is_fully_valid(self):
|
||||
def _is_fully_valid(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
if self.public_key is None:
|
||||
raise AssertionError("Contract violation public_key not provided")
|
||||
if self.signing_key is None:
|
||||
raise AssertionError("Contract violation signing_key not provided")
|
||||
return self._is_partially_valid() and all(
|
||||
[
|
||||
self._compare_options() if self.original_data.type == "user" else True,
|
||||
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _is_partially_valid(self):
|
||||
def _is_partially_valid(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
return all(
|
||||
[
|
||||
set(self.original_data.principals) == set(self.principals),
|
||||
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _compare_time_parameters(self):
|
||||
def _compare_time_parameters(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
try:
|
||||
original_time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.original_data.valid_after,
|
||||
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _compare_options(self):
|
||||
def _compare_options(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
try:
|
||||
critical_options, extensions = parse_option_list(self.options)
|
||||
except ValueError as e:
|
||||
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _get_key_fingerprint(self, path):
|
||||
def _get_key_fingerprint(self, path: str) -> str:
|
||||
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
|
||||
return PrivateKey.from_string(private_key_content).fingerprint
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _generate(self):
|
||||
def _generate(self) -> None:
|
||||
try:
|
||||
temp_certificate = self._generate_temp_certificate()
|
||||
self._safe_secure_move([(temp_certificate, self.path)])
|
||||
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
|
||||
except (TypeError, ValueError) as e:
|
||||
self.module.fail_json(msg=f"Unable to read new certificate: {e}")
|
||||
|
||||
def _generate_temp_certificate(self):
|
||||
def _generate_temp_certificate(self) -> str:
|
||||
if self.public_key is None:
|
||||
raise AssertionError("Contract violation public_key not provided")
|
||||
if self.signing_key is None:
|
||||
raise AssertionError("Contract violation signing_key not provided")
|
||||
if self.time_parameters is None:
|
||||
raise AssertionError("Contract violation time_parameters not provided")
|
||||
|
||||
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
|
||||
|
||||
try:
|
||||
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _remove(self):
|
||||
def _remove(self) -> None:
|
||||
try:
|
||||
os.remove(self.path)
|
||||
except OSError as e:
|
||||
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
|
||||
|
||||
@property
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
if self.state != "present":
|
||||
return {}
|
||||
|
||||
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
|
||||
}
|
||||
|
||||
@property
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"before": get_cert_dict(self.original_data),
|
||||
"after": get_cert_dict(self.data),
|
||||
}
|
||||
|
||||
|
||||
def format_cert_info(cert_info):
|
||||
def format_cert_info(cert_info: str) -> list[str]:
|
||||
result = []
|
||||
string = ""
|
||||
|
||||
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
|
||||
return result
|
||||
|
||||
|
||||
def get_cert_dict(data):
|
||||
def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return {}
|
||||
|
||||
@@ -590,7 +621,7 @@ def get_cert_dict(data):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
force=dict(type="bool", default=False),
|
||||
|
||||
@@ -198,13 +198,15 @@ comment:
|
||||
sample: test@comment
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import (
|
||||
select_backend,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
|
||||
@@ -239,6 +239,7 @@ csr:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -256,9 +257,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
|
||||
CertificateSigningRequestBackend,
|
||||
)
|
||||
|
||||
|
||||
class CertificateSigningRequestModule(OpenSSLObject):
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
|
||||
) -> None:
|
||||
super(CertificateSigningRequestModule, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -269,11 +279,11 @@ class CertificateSigningRequestModule(OpenSSLObject):
|
||||
self.return_content = module.params["return_content"]
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup_file: str | None = None
|
||||
|
||||
self.module_backend.set_existing(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the certificate signing request."""
|
||||
if self.force or self.module_backend.needs_regeneration():
|
||||
if not self.check_mode:
|
||||
@@ -292,13 +302,13 @@ class CertificateSigningRequestModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
self.module_backend.set_existing(None)
|
||||
if self.backup and not self.check_mode:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(CertificateSigningRequestModule, self).remove(module)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = self.module_backend.dump(include_csr=self.return_content)
|
||||
result.update(
|
||||
@@ -312,7 +322,7 @@ class CertificateSigningRequestModule(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_csr_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -308,6 +308,7 @@ authority_cert_serial_number:
|
||||
sample: 12345
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -318,7 +319,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -335,11 +336,15 @@ def main():
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
if module.params["content"] is not None:
|
||||
data = module.params["content"].encode("utf-8")
|
||||
content: str | None = module.params["content"]
|
||||
path: str | None = module.params["path"]
|
||||
if content is not None:
|
||||
data = content.encode("utf-8")
|
||||
else:
|
||||
if path is None:
|
||||
module.fail_json(msg="One of content and path must be provided")
|
||||
try:
|
||||
with open(module.params["path"], "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(msg=f"Error while reading CSR file from disk: {e}")
|
||||
|
||||
@@ -127,6 +127,8 @@ csr:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -136,8 +138,17 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
|
||||
CertificateSigningRequestBackend,
|
||||
)
|
||||
|
||||
|
||||
class CertificateSigningRequestModule:
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
|
||||
) -> None:
|
||||
self.check_mode = module.check_mode
|
||||
self.module = module
|
||||
self.module_backend = module_backend
|
||||
@@ -145,13 +156,13 @@ class CertificateSigningRequestModule:
|
||||
if module.params["content"] is not None:
|
||||
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the certificate signing request."""
|
||||
if self.module_backend.needs_regeneration():
|
||||
self.module_backend.generate_csr()
|
||||
self.changed = True
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = self.module_backend.dump(include_csr=True)
|
||||
result.update(
|
||||
@@ -162,7 +173,7 @@ class CertificateSigningRequestModule:
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_csr_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -132,6 +132,7 @@ import abc
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
@@ -173,23 +174,23 @@ class DHParameterError(Exception):
|
||||
|
||||
class DHParameterBase:
|
||||
|
||||
def __init__(self, module):
|
||||
self.state = module.params["state"]
|
||||
self.path = module.params["path"]
|
||||
self.size = module.params["size"]
|
||||
self.force = module.params["force"]
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.state: t.Literal["absent", "present"] = module.params["state"]
|
||||
self.path: str = module.params["path"]
|
||||
self.size: int = module.params["size"]
|
||||
self.force: bool = module.params["force"]
|
||||
self.changed = False
|
||||
self.return_content = module.params["return_content"]
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
pass
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate DH params."""
|
||||
changed = False
|
||||
|
||||
@@ -206,7 +207,7 @@ class DHParameterBase:
|
||||
|
||||
self.changed = changed
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
try:
|
||||
@@ -215,28 +216,27 @@ class DHParameterBase:
|
||||
except OSError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
def check(self, module):
|
||||
def check(self, module: AnsibleModule) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
if self.force:
|
||||
return False
|
||||
return self._check_params_valid(module) and self._check_fs_attributes(module)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
pass
|
||||
|
||||
def _check_fs_attributes(self, module):
|
||||
def _check_fs_attributes(self, module: AnsibleModule) -> bool:
|
||||
"""Checks (and changes if not in check mode!) fs attributes"""
|
||||
file_args = module.load_file_common_arguments(module.params)
|
||||
if module.check_file_absent_if_check_mode(file_args["path"]):
|
||||
return False
|
||||
return not module.set_fs_attributes_if_different(file_args, False)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"size": self.size,
|
||||
"filename": self.path,
|
||||
"changed": self.changed,
|
||||
@@ -252,25 +252,24 @@ class DHParameterBase:
|
||||
|
||||
class DHParameterAbsent(DHParameterBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(DHParameterAbsent, self).__init__(module)
|
||||
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
pass
|
||||
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
class DHParameterOpenSSL(DHParameterBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(DHParameterOpenSSL, self).__init__(module)
|
||||
self.openssl_bin = module.get_bin_path("openssl", True)
|
||||
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
# create a tempfile
|
||||
fd, tmpsrc = tempfile.mkstemp()
|
||||
@@ -288,7 +287,7 @@ class DHParameterOpenSSL(DHParameterBase):
|
||||
except Exception as e:
|
||||
module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}")
|
||||
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
command = [
|
||||
self.openssl_bin,
|
||||
@@ -321,10 +320,10 @@ class DHParameterOpenSSL(DHParameterBase):
|
||||
|
||||
class DHParameterCryptography(DHParameterBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(DHParameterCryptography, self).__init__(module)
|
||||
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
# Generate parameters
|
||||
params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters(
|
||||
@@ -341,7 +340,7 @@ class DHParameterCryptography(DHParameterBase):
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
write_file(module, result)
|
||||
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
# Load parameters
|
||||
try:
|
||||
@@ -357,7 +356,7 @@ class DHParameterCryptography(DHParameterBase):
|
||||
return bits == self.size
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
"""Main function"""
|
||||
|
||||
module = AnsibleModule(
|
||||
@@ -383,6 +382,7 @@ def main():
|
||||
msg=f"The directory '{base_dir}' does not exist or the file is not a directory",
|
||||
)
|
||||
|
||||
dhparam: DHParameterOpenSSL | DHParameterCryptography | DHParameterAbsent
|
||||
if module.params["state"] == "present":
|
||||
backend = module.params["select_crypto_backend"]
|
||||
if backend == "auto":
|
||||
|
||||
@@ -280,6 +280,7 @@ import itertools
|
||||
import os
|
||||
import stat
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
@@ -296,7 +297,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
OpenSSLObject,
|
||||
load_certificate,
|
||||
load_privatekey,
|
||||
load_certificate_issuer_privatekey,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -320,6 +321,7 @@ except ImportError:
|
||||
|
||||
CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None
|
||||
try:
|
||||
import cryptography.x509
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.serialization.pkcs12 import PBES
|
||||
|
||||
@@ -333,8 +335,22 @@ except Exception:
|
||||
else:
|
||||
CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..module_utils.crypto.cryptography_support import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
def load_certificate_set(filename):
|
||||
PKCS12 = tuple[
|
||||
t.Union[CertificateIssuerPrivateKeyTypes, None],
|
||||
t.Union[cryptography.x509.Certificate, None],
|
||||
list[cryptography.x509.Certificate],
|
||||
t.Union[bytes, None],
|
||||
]
|
||||
|
||||
|
||||
def load_certificate_set(
|
||||
filename: str | os.PathLike,
|
||||
) -> list[cryptography.x509.Certificate]:
|
||||
"""
|
||||
Load list of concatenated PEM files, and return a list of parsed certificates.
|
||||
"""
|
||||
@@ -351,70 +367,80 @@ class PkcsError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class Pkcs(OpenSSLObject):
|
||||
def __init__(self, module, iter_size_default=2048):
|
||||
path: str
|
||||
|
||||
def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None:
|
||||
super(Pkcs, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
module.params["force"],
|
||||
module.check_mode,
|
||||
)
|
||||
self.action = module.params["action"]
|
||||
self.other_certificates = module.params["other_certificates"]
|
||||
self.other_certificates_parse_all = module.params[
|
||||
self.action: t.Literal["export", "parse"] = module.params["action"]
|
||||
self.other_certificates: list[cryptography.x509.Certificate] = []
|
||||
self.other_certificates_str: list[str] | None = module.params[
|
||||
"other_certificates"
|
||||
]
|
||||
self.other_certificates_parse_all: bool = module.params[
|
||||
"other_certificates_parse_all"
|
||||
]
|
||||
self.other_certificates_content = module.params["other_certificates_content"]
|
||||
self.certificate_path = module.params["certificate_path"]
|
||||
self.certificate_content = module.params["certificate_content"]
|
||||
self.friendly_name = module.params["friendly_name"]
|
||||
self.iter_size = module.params["iter_size"] or iter_size_default
|
||||
self.maciter_size = module.params["maciter_size"] or 1
|
||||
self.encryption_level = module.params["encryption_level"]
|
||||
self.other_certificates_content: list[str] | None = module.params[
|
||||
"other_certificates_content"
|
||||
]
|
||||
self.certificate_path: str | None = module.params["certificate_path"]
|
||||
certificate_content: str | None = module.params["certificate_content"]
|
||||
self.friendly_name: str | None = module.params["friendly_name"]
|
||||
self.iter_size: int = module.params["iter_size"] or iter_size_default
|
||||
self.maciter_size: int = module.params["maciter_size"] or 1
|
||||
self.encryption_level: t.Literal["auto", "compatibility2022"] = module.params[
|
||||
"encryption_level"
|
||||
]
|
||||
self.passphrase = module.params["passphrase"]
|
||||
self.pkcs12 = None
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
self.pkcs12_bytes = None
|
||||
self.return_content = module.params["return_content"]
|
||||
self.src = module.params["src"]
|
||||
self.pkcs12: PKCS12 | None = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
self.pkcs12_bytes: bytes | None = None
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.src: str | None = module.params["src"]
|
||||
|
||||
if module.params["mode"] is None:
|
||||
module.params["mode"] = "0400"
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
self.certificate_content: bytes | None = None
|
||||
if self.certificate_path is not None:
|
||||
try:
|
||||
with open(self.certificate_path, "rb") as fh:
|
||||
self.certificate_content = fh.read()
|
||||
except (IOError, OSError) as exc:
|
||||
raise PkcsError(exc)
|
||||
elif self.certificate_content is not None:
|
||||
self.certificate_content = to_bytes(self.certificate_content)
|
||||
elif certificate_content is not None:
|
||||
self.certificate_content = to_bytes(certificate_content)
|
||||
|
||||
self.privatekey_content: bytes | None = None
|
||||
if self.privatekey_path is not None:
|
||||
try:
|
||||
with open(self.privatekey_path, "rb") as fh:
|
||||
self.privatekey_content = fh.read()
|
||||
except (IOError, OSError) as exc:
|
||||
raise PkcsError(exc)
|
||||
elif self.privatekey_content is not None:
|
||||
self.privatekey_content = to_bytes(self.privatekey_content)
|
||||
elif privatekey_content is not None:
|
||||
self.privatekey_content = to_bytes(privatekey_content)
|
||||
|
||||
if self.other_certificates:
|
||||
if self.other_certificates_str:
|
||||
if self.other_certificates_parse_all:
|
||||
filenames = list(self.other_certificates)
|
||||
self.other_certificates = []
|
||||
for other_cert_bundle in filenames:
|
||||
for other_cert_bundle in self.other_certificates_str:
|
||||
self.other_certificates.extend(
|
||||
load_certificate_set(other_cert_bundle)
|
||||
)
|
||||
else:
|
||||
self.other_certificates = [
|
||||
load_certificate(other_cert)
|
||||
for other_cert in self.other_certificates
|
||||
for other_cert in self.other_certificates_str
|
||||
]
|
||||
elif self.other_certificates_content:
|
||||
certs = self.other_certificates_content
|
||||
@@ -430,40 +456,42 @@ class Pkcs(OpenSSLObject):
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_bytes(self, module):
|
||||
def generate_bytes(self, module: AnsibleModule) -> bytes:
|
||||
"""Generate PKCS#12 file archive."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
|
||||
bytes | None,
|
||||
bytes | None,
|
||||
list[bytes],
|
||||
bytes | None,
|
||||
]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_bytes(self, pkcs12_content):
|
||||
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dump_privatekey(self, pkcs12):
|
||||
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dump_certificate(self, pkcs12):
|
||||
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dump_other_certificates(self, pkcs12):
|
||||
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_friendly_name(self, pkcs12):
|
||||
pass
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
state_and_perms = super(Pkcs, self).check(module, perms_required)
|
||||
|
||||
def _check_pkey_passphrase():
|
||||
def _check_pkey_passphrase() -> bool:
|
||||
if self.privatekey_passphrase:
|
||||
try:
|
||||
load_privatekey(
|
||||
None,
|
||||
load_certificate_issuer_privatekey(
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
)
|
||||
@@ -476,6 +504,7 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
if os.path.exists(self.path) and module.params["action"] == "export":
|
||||
self.generate_bytes(module) # ignore result
|
||||
assert self.pkcs12 is not None
|
||||
self.src = self.path
|
||||
try:
|
||||
(
|
||||
@@ -524,7 +553,7 @@ class Pkcs(OpenSSLObject):
|
||||
return False
|
||||
elif (
|
||||
module.params["action"] == "parse"
|
||||
and os.path.exists(self.src)
|
||||
and os.path.exists(self.src or "")
|
||||
and os.path.exists(self.path)
|
||||
):
|
||||
try:
|
||||
@@ -548,10 +577,10 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
return _check_pkey_passphrase()
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"filename": self.path,
|
||||
}
|
||||
if self.privatekey_path:
|
||||
@@ -567,13 +596,20 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
return result
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(Pkcs, self).remove(module)
|
||||
|
||||
def parse(self):
|
||||
def parse(self) -> tuple[
|
||||
bytes | None,
|
||||
bytes | None,
|
||||
list[bytes],
|
||||
bytes | None,
|
||||
]:
|
||||
"""Read PKCS#12 file."""
|
||||
if self.src is None:
|
||||
raise AssertionError("Contract violation: src is None")
|
||||
|
||||
try:
|
||||
with open(self.src, "rb") as pkcs12_fh:
|
||||
@@ -582,10 +618,13 @@ class Pkcs(OpenSSLObject):
|
||||
except IOError as exc:
|
||||
raise PkcsError(exc)
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
|
||||
def write(self, module, content, mode=None):
|
||||
def write(
|
||||
self, module: AnsibleModule, content: bytes, mode: int | str | None = None
|
||||
) -> None:
|
||||
"""Write the PKCS#12 file."""
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
@@ -595,7 +634,7 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
|
||||
class PkcsCryptography(Pkcs):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(PkcsCryptography, self).__init__(module, iter_size_default=50000)
|
||||
if (
|
||||
self.encryption_level == "compatibility2022"
|
||||
@@ -607,13 +646,12 @@ class PkcsCryptography(Pkcs):
|
||||
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
|
||||
)
|
||||
|
||||
def generate_bytes(self, module):
|
||||
def generate_bytes(self, module: AnsibleModule) -> bytes:
|
||||
"""Generate PKCS#12 file archive."""
|
||||
pkey = None
|
||||
if self.privatekey_content:
|
||||
try:
|
||||
pkey = load_privatekey(
|
||||
None,
|
||||
pkey = load_certificate_issuer_privatekey(
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
)
|
||||
@@ -631,6 +669,7 @@ class PkcsCryptography(Pkcs):
|
||||
# Store fake object which can be used to retrieve the components back
|
||||
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
|
||||
|
||||
encryption: serialization.KeySerializationEncryption
|
||||
if not self.passphrase:
|
||||
encryption = serialization.NoEncryption()
|
||||
elif self.encryption_level == "compatibility2022":
|
||||
@@ -654,7 +693,12 @@ class PkcsCryptography(Pkcs):
|
||||
encryption,
|
||||
)
|
||||
|
||||
def parse_bytes(self, pkcs12_content):
|
||||
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
|
||||
bytes | None,
|
||||
bytes | None,
|
||||
list[bytes],
|
||||
bytes | None,
|
||||
]:
|
||||
try:
|
||||
private_key, certificate, additional_certificates, friendly_name = (
|
||||
parse_pkcs12(pkcs12_content, self.passphrase)
|
||||
@@ -683,11 +727,7 @@ class PkcsCryptography(Pkcs):
|
||||
except ValueError as exc:
|
||||
raise PkcsError(exc)
|
||||
|
||||
# The following methods will get self.pkcs12 passed, which is computed as:
|
||||
#
|
||||
# self.pkcs12 = (pkey, cert, self.other_certificates, self.friendly_name)
|
||||
|
||||
def _dump_privatekey(self, pkcs12):
|
||||
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
|
||||
return (
|
||||
pkcs12[0].private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
@@ -698,27 +738,27 @@ class PkcsCryptography(Pkcs):
|
||||
else None
|
||||
)
|
||||
|
||||
def _dump_certificate(self, pkcs12):
|
||||
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
|
||||
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
|
||||
|
||||
def _dump_other_certificates(self, pkcs12):
|
||||
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
|
||||
return [
|
||||
other_cert.public_bytes(serialization.Encoding.PEM)
|
||||
for other_cert in pkcs12[2]
|
||||
]
|
||||
|
||||
def _get_friendly_name(self, pkcs12):
|
||||
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
|
||||
return pkcs12[3]
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(module: AnsibleModule) -> Pkcs:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PkcsCryptography(module)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = dict(
|
||||
action=dict(type="str", default="export", choices=["export", "parse"]),
|
||||
other_certificates=dict(
|
||||
|
||||
@@ -155,6 +155,7 @@ privatekey:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -172,9 +173,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
|
||||
PrivateKeyBackend,
|
||||
)
|
||||
|
||||
|
||||
class PrivateKeyModule(OpenSSLObject):
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: PrivateKeyBackend
|
||||
) -> None:
|
||||
super(PrivateKeyModule, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -182,19 +192,19 @@ class PrivateKeyModule(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
self.module_backend = module_backend
|
||||
self.return_content = module.params["return_content"]
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
if self.force:
|
||||
module_backend.regenerate = "always"
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: str | None = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
if module.params["mode"] is None:
|
||||
module.params["mode"] = "0600"
|
||||
|
||||
module_backend.set_existing(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate a keypair."""
|
||||
|
||||
if self.module_backend.needs_regeneration():
|
||||
@@ -228,13 +238,13 @@ class PrivateKeyModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
self.module_backend.set_existing(None)
|
||||
if self.backup and not self.check_mode:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(PrivateKeyModule, self).remove(module)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = self.module_backend.dump(include_key=self.return_content)
|
||||
@@ -246,7 +256,7 @@ class PrivateKeyModule(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
argument_spec = get_privatekey_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
|
||||
@@ -60,6 +60,7 @@ backup_file:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -77,8 +78,17 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey_convert import (
|
||||
PrivateKeyConvertBackend,
|
||||
)
|
||||
|
||||
|
||||
class PrivateKeyConvertModule(OpenSSLObject):
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: PrivateKeyConvertBackend
|
||||
) -> None:
|
||||
super(PrivateKeyConvertModule, self).__init__(
|
||||
module.params["dest_path"],
|
||||
"present",
|
||||
@@ -87,8 +97,8 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
)
|
||||
self.module_backend = module_backend
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
module.params["path"] = module.params["dest_path"]
|
||||
if module.params["mode"] is None:
|
||||
@@ -96,12 +106,14 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
|
||||
module_backend.set_existing_destination(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Do conversion."""
|
||||
|
||||
if self.module_backend.needs_conversion():
|
||||
# Convert
|
||||
privatekey_data = self.module_backend.get_private_key_data()
|
||||
if privatekey_data is None:
|
||||
raise AssertionError("Contract violation: privatekey_data is None")
|
||||
if not self.check_mode:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
@@ -116,7 +128,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = self.module_backend.dump()
|
||||
@@ -127,7 +139,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
argument_spec = get_privatekey_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
|
||||
@@ -200,6 +200,7 @@ private_data:
|
||||
type: dict
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -212,7 +213,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -243,7 +244,7 @@ def main():
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(
|
||||
msg=f"Error while reading private key file from disk: {e}", **result
|
||||
msg=f"Error while reading private key file from disk: {e}", **result # type: ignore
|
||||
)
|
||||
|
||||
result["can_load_key"] = True
|
||||
@@ -261,10 +262,10 @@ def main():
|
||||
module.exit_json(**result)
|
||||
except PrivateKeyParseError as exc:
|
||||
result.update(exc.result)
|
||||
module.fail_json(msg=exc.error_message, **result)
|
||||
module.fail_json(msg=exc.error_message, **result) # type: ignore
|
||||
except PrivateKeyConsistencyError as exc:
|
||||
result.update(exc.result)
|
||||
module.fail_json(msg=exc.error_message, **result)
|
||||
module.fail_json(msg=exc.error_message, **result) # type: ignore
|
||||
except OpenSSLObjectError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ publickey:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -218,6 +219,12 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
PublicKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
class PublicKeyError(OpenSSLObjectError):
|
||||
pass
|
||||
@@ -225,7 +232,7 @@ class PublicKeyError(OpenSSLObjectError):
|
||||
|
||||
class PublicKey(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(PublicKey, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -233,27 +240,29 @@ class PublicKey(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
self.module = module
|
||||
self.format = module.params["format"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.privatekey = None
|
||||
self.publickey_bytes = None
|
||||
self.return_content = module.params["return_content"]
|
||||
self.fingerprint = {}
|
||||
self.format: t.Literal["OpenSSH", "PEM"] = module.params["format"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.privatekey: PrivateKeyTypes | None = None
|
||||
self.publickey_bytes: bytes | None = None
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.fingerprint: dict[str, str] = {}
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
result = dict(can_parse_key=False)
|
||||
return {}
|
||||
result = {"can_parse_key": False}
|
||||
try:
|
||||
result.update(
|
||||
get_publickey_info(
|
||||
@@ -267,7 +276,7 @@ class PublicKey(OpenSSLObject):
|
||||
pass
|
||||
return result
|
||||
|
||||
def _create_publickey(self, module):
|
||||
def _create_publickey(self, module: AnsibleModule) -> bytes:
|
||||
self.privatekey = load_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
@@ -284,10 +293,12 @@ class PublicKey(OpenSSLObject):
|
||||
crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the public key."""
|
||||
|
||||
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
|
||||
if self.privatekey_path is not None and not os.path.exists(
|
||||
self.privatekey_path
|
||||
):
|
||||
raise PublicKeyError(
|
||||
f"The private key {self.privatekey_path} does not exist"
|
||||
)
|
||||
@@ -320,17 +331,18 @@ class PublicKey(OpenSSLObject):
|
||||
elif module.set_fs_attributes_if_different(file_args, False):
|
||||
self.changed = True
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
state_and_perms = super(PublicKey, self).check(module, perms_required)
|
||||
|
||||
def _check_privatekey():
|
||||
if self.privatekey_content is None and not os.path.exists(
|
||||
def _check_privatekey() -> bool:
|
||||
if self.privatekey_path is not None and not os.path.exists(
|
||||
self.privatekey_path
|
||||
):
|
||||
return False
|
||||
|
||||
current_publickey: PublicKeyTypes
|
||||
try:
|
||||
with open(self.path, "rb") as public_key_fh:
|
||||
publickey_content = public_key_fh.read()
|
||||
@@ -369,15 +381,15 @@ class PublicKey(OpenSSLObject):
|
||||
|
||||
return _check_privatekey()
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(PublicKey, self).remove(module)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
"filename": self.path,
|
||||
"format": self.format,
|
||||
@@ -403,7 +415,7 @@ class PublicKey(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
|
||||
@@ -152,6 +152,7 @@ public_data:
|
||||
returned: When RV(type=DSA) or RV(type=ECC)
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -163,7 +164,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -191,7 +192,7 @@ def main():
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(
|
||||
msg=f"Error while reading public key file from disk: {e}", **result
|
||||
msg=f"Error while reading public key file from disk: {e}", **result # type: ignore
|
||||
)
|
||||
|
||||
module_backend = select_backend(module, data)
|
||||
@@ -201,7 +202,7 @@ def main():
|
||||
module.exit_json(**result)
|
||||
except PublicKeyParseError as exc:
|
||||
result.update(exc.result)
|
||||
module.fail_json(msg=exc.error_message, **result)
|
||||
module.fail_json(msg=exc.error_message, **result) # type: ignore
|
||||
except OpenSSLObjectError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ signature:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -132,7 +133,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
|
||||
|
||||
class SignatureBase(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureBase, self).__init__(
|
||||
path=module.params["path"],
|
||||
state="present",
|
||||
@@ -140,32 +141,35 @@ class SignatureBase(OpenSSLObject):
|
||||
check_mode=module.check_mode,
|
||||
)
|
||||
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.module = module
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class SignatureCryptography(SignatureBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureCryptography, self).__init__(module)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> dict[str, t.Any]:
|
||||
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
|
||||
_hash = cryptography.hazmat.primitives.hashes.SHA256()
|
||||
|
||||
result = dict()
|
||||
result: dict[str, t.Any] = {}
|
||||
|
||||
try:
|
||||
with open(self.path, "rb") as f:
|
||||
@@ -223,7 +227,7 @@ class SignatureCryptography(SignatureBase):
|
||||
raise OpenSSLObjectError(e)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
privatekey_path=dict(type="path"),
|
||||
|
||||
@@ -88,6 +88,7 @@ valid:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -121,7 +122,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
|
||||
|
||||
class SignatureInfoBase(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureInfoBase, self).__init__(
|
||||
path=module.params["path"],
|
||||
state="present",
|
||||
@@ -129,32 +130,35 @@ class SignatureInfoBase(OpenSSLObject):
|
||||
check_mode=module.check_mode,
|
||||
)
|
||||
|
||||
self.signature = module.params["signature"]
|
||||
self.certificate_path = module.params["certificate_path"]
|
||||
self.certificate_content = module.params["certificate_content"]
|
||||
if self.certificate_content is not None:
|
||||
self.certificate_content = self.certificate_content.encode("utf-8")
|
||||
self.module = module
|
||||
self.signature: str = module.params["signature"]
|
||||
self.certificate_path: str | None = module.params["certificate_path"]
|
||||
certificate_content: str | None = module.params["certificate_content"]
|
||||
if certificate_content is not None:
|
||||
self.certificate_content: bytes | None = certificate_content.encode("utf-8")
|
||||
else:
|
||||
self.certificate_content = None
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class SignatureInfoCryptography(SignatureInfoBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureInfoCryptography, self).__init__(module)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> dict[str, t.Any]:
|
||||
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
|
||||
_hash = cryptography.hazmat.primitives.hashes.SHA256()
|
||||
|
||||
result = dict()
|
||||
result: dict[str, t.Any] = {}
|
||||
|
||||
try:
|
||||
with open(self.path, "rb") as f:
|
||||
@@ -228,7 +232,7 @@ class SignatureInfoCryptography(SignatureInfoBase):
|
||||
raise OpenSSLObjectError(e)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
certificate_path=dict(type="path"),
|
||||
|
||||
@@ -224,6 +224,7 @@ certificate:
|
||||
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -257,8 +258,15 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
|
||||
CertificateBackend,
|
||||
)
|
||||
|
||||
|
||||
class CertificateAbsent(OpenSSLObject):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CertificateAbsent, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -266,19 +274,19 @@ class CertificateAbsent(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
self.module = module
|
||||
self.return_content = module.params["return_content"]
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
pass
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(CertificateAbsent, self).remove(module)
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = {
|
||||
"changed": self.changed,
|
||||
"filename": self.path,
|
||||
@@ -296,7 +304,7 @@ class CertificateAbsent(OpenSSLObject):
|
||||
class GenericCertificate(OpenSSLObject):
|
||||
"""Retrieve a certificate using the given module backend."""
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
|
||||
super(GenericCertificate, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -311,7 +319,7 @@ class GenericCertificate(OpenSSLObject):
|
||||
self.module_backend = module_backend
|
||||
self.module_backend.set_existing(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
if self.module_backend.needs_regeneration():
|
||||
if not self.check_mode:
|
||||
self.module_backend.generate_certificate()
|
||||
@@ -329,14 +337,14 @@ class GenericCertificate(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
return (
|
||||
super(GenericCertificate, self).check(module, perms_required)
|
||||
and not self.module_backend.needs_regeneration()
|
||||
)
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = self.module_backend.dump(include_certificate=self.return_content)
|
||||
result.update(
|
||||
{
|
||||
@@ -349,7 +357,7 @@ class GenericCertificate(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_certificate_argument_spec()
|
||||
add_acme_provider_to_argument_spec(argument_spec)
|
||||
add_entrust_provider_to_argument_spec(argument_spec)
|
||||
@@ -363,13 +371,14 @@ def main():
|
||||
return_content=dict(type="bool", default=False),
|
||||
)
|
||||
)
|
||||
argument_spec.required_if.append(["state", "present", ["provider"]])
|
||||
argument_spec.required_if.append(("state", "present", ["provider"]))
|
||||
module = argument_spec.create_ansible_module(
|
||||
add_file_common_args=True,
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
try:
|
||||
certificate: GenericCertificate | CertificateAbsent
|
||||
if module.params["state"] == "absent":
|
||||
certificate = CertificateAbsent(module)
|
||||
|
||||
@@ -389,7 +398,13 @@ def main():
|
||||
)
|
||||
|
||||
provider = module.params["provider"]
|
||||
provider_map = {
|
||||
provider_map: dict[
|
||||
str,
|
||||
type[AcmeCertificateProvider]
|
||||
| type[EntrustCertificateProvider]
|
||||
| type[OwnCACertificateProvider]
|
||||
| type[SelfSignedCertificateProvider],
|
||||
] = {
|
||||
"acme": AcmeCertificateProvider,
|
||||
"entrust": EntrustCertificateProvider,
|
||||
"ownca": OwnCACertificateProvider,
|
||||
|
||||
@@ -106,6 +106,7 @@ backup_file:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
@@ -142,8 +143,12 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def parse_certificate(input, strict=False):
|
||||
input_format = "pem" if identify_pem_format(input) else "der"
|
||||
def parse_certificate(
|
||||
input: bytes, strict: bool = False
|
||||
) -> tuple[bytes, t.Literal["pem", "der"], str | None]:
|
||||
input_format: t.Literal["pem", "der"] = (
|
||||
"pem" if identify_pem_format(input) else "der"
|
||||
)
|
||||
if input_format == "pem":
|
||||
pems = split_pem_list(to_text(input))
|
||||
if len(pems) > 1 and strict:
|
||||
@@ -162,7 +167,7 @@ def parse_certificate(input, strict=False):
|
||||
|
||||
|
||||
class X509CertificateConvertModule(OpenSSLObject):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(X509CertificateConvertModule, self).__init__(
|
||||
module.params["dest_path"],
|
||||
"present",
|
||||
@@ -170,9 +175,9 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
|
||||
self.src_path = module.params["src_path"]
|
||||
self.src_content = module.params["src_content"]
|
||||
self.src_content_base64 = module.params["src_content_base64"]
|
||||
self.src_path: str | None = module.params["src_path"]
|
||||
self.src_content: str | None = module.params["src_content"]
|
||||
self.src_content_base64: bool = module.params["src_content_base64"]
|
||||
if self.src_content is not None:
|
||||
self.input = to_bytes(self.src_content)
|
||||
if self.src_content_base64:
|
||||
@@ -181,6 +186,8 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
except Exception as exc:
|
||||
module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}")
|
||||
else:
|
||||
if self.src_path is None:
|
||||
module.fail_json(msg="One of src_path and src_content must be provided")
|
||||
try:
|
||||
with open(self.src_path, "rb") as f:
|
||||
self.input = f.read()
|
||||
@@ -189,8 +196,8 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
msg=f"Failure while reading file {self.src_path}: {exc}"
|
||||
)
|
||||
|
||||
self.format = module.params["format"]
|
||||
self.strict = module.params["strict"]
|
||||
self.format: t.Literal["pem", "der"] = module.params["format"]
|
||||
self.strict: bool = module.params["strict"]
|
||||
self.wanted_pem_type = "CERTIFICATE"
|
||||
|
||||
try:
|
||||
@@ -203,8 +210,8 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
if module.params["verify_cert_parsable"]:
|
||||
self.verify_cert_parsable(module)
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
module.params["path"] = self.path
|
||||
|
||||
@@ -221,7 +228,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def verify_cert_parsable(self, module):
|
||||
def verify_cert_parsable(self, module: AnsibleModule) -> None:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
@@ -230,7 +237,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
except Exception as exc:
|
||||
module.fail_json(msg=f"Error while parsing certificate: {exc}")
|
||||
|
||||
def needs_conversion(self):
|
||||
def needs_conversion(self) -> bool:
|
||||
if self.dest_content is None or self.dest_content_format is None:
|
||||
return True
|
||||
if self.dest_content_format != self.format:
|
||||
@@ -241,7 +248,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_dest_certificate(self):
|
||||
def get_dest_certificate(self) -> bytes:
|
||||
if self.format == "der":
|
||||
return self.input
|
||||
data = to_bytes(base64.b64encode(self.input))
|
||||
@@ -250,7 +257,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n"))
|
||||
return b"\n".join(lines)
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Do conversion."""
|
||||
if self.needs_conversion():
|
||||
# Convert
|
||||
@@ -269,18 +276,18 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = dict(
|
||||
changed=self.changed,
|
||||
)
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": self.changed,
|
||||
}
|
||||
if self.backup_file:
|
||||
result["backup_file"] = self.backup_file
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = dict(
|
||||
src_path=dict(type="path"),
|
||||
src_content=dict(type="str"),
|
||||
|
||||
@@ -390,8 +390,10 @@ issuer_uri:
|
||||
version_added: 2.9.0
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -406,7 +408,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -424,18 +426,22 @@ def main():
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
if module.params["content"] is not None:
|
||||
data = module.params["content"].encode("utf-8")
|
||||
content: str | None = module.params["content"]
|
||||
path: str | None = module.params["path"]
|
||||
if content is not None:
|
||||
data = content.encode("utf-8")
|
||||
else:
|
||||
if path is None:
|
||||
module.fail_json(msg="One of path and content must be provided")
|
||||
try:
|
||||
with open(module.params["path"], "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(msg=f"Error while reading certificate file from disk: {e}")
|
||||
|
||||
module_backend = select_backend(module, data)
|
||||
|
||||
valid_at = module.params["valid_at"]
|
||||
valid_at: dict[str, t.Any] = module.params["valid_at"]
|
||||
if valid_at:
|
||||
for k, v in valid_at.items():
|
||||
if not isinstance(v, (str, bytes)):
|
||||
@@ -443,13 +449,11 @@ def main():
|
||||
msg=f"The value for valid_at.{k} must be of type string (got {type(v)})"
|
||||
)
|
||||
valid_at[k] = get_relative_time_option(
|
||||
v, f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
to_text(v), f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
)
|
||||
|
||||
try:
|
||||
result = module_backend.get_info(
|
||||
der_support_enabled=module.params["content"] is None
|
||||
)
|
||||
result = module_backend.get_info(der_support_enabled=content is None)
|
||||
|
||||
not_before = module_backend.get_not_before()
|
||||
not_after = module_backend.get_not_after()
|
||||
|
||||
@@ -118,6 +118,8 @@ certificate:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -139,23 +141,31 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
|
||||
CertificateBackend,
|
||||
)
|
||||
|
||||
|
||||
class GenericCertificate:
|
||||
"""Retrieve a certificate using the given module backend."""
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
|
||||
self.check_mode = module.check_mode
|
||||
self.module = module
|
||||
self.module_backend = module_backend
|
||||
self.changed = False
|
||||
if module.params["content"] is not None:
|
||||
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
|
||||
content: str | None = module.params["content"]
|
||||
if content is not None:
|
||||
self.module_backend.set_existing(content.encode("utf-8"))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
if self.module_backend.needs_regeneration():
|
||||
self.module_backend.generate_certificate()
|
||||
self.changed = True
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = self.module_backend.dump(include_certificate=True)
|
||||
result.update(
|
||||
{
|
||||
@@ -165,7 +175,7 @@ class GenericCertificate:
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_certificate_argument_spec()
|
||||
argument_spec.argument_spec["provider"]["required"] = True
|
||||
add_entrust_provider_to_argument_spec(argument_spec)
|
||||
@@ -182,7 +192,12 @@ def main():
|
||||
|
||||
try:
|
||||
provider = module.params["provider"]
|
||||
provider_map = {
|
||||
provider_map: dict[
|
||||
str,
|
||||
type[EntrustCertificateProvider]
|
||||
| type[OwnCACertificateProvider]
|
||||
| type[SelfSignedCertificateProvider],
|
||||
] = {
|
||||
"entrust": EntrustCertificateProvider,
|
||||
"ownca": OwnCACertificateProvider,
|
||||
"selfsigned": SelfSignedCertificateProvider,
|
||||
|
||||
@@ -425,6 +425,7 @@ crl:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
@@ -463,7 +464,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
OpenSSLObject,
|
||||
load_certificate,
|
||||
load_privatekey,
|
||||
load_certificate_issuer_privatekey,
|
||||
parse_name_field,
|
||||
parse_ordered_name_field,
|
||||
select_message_digest,
|
||||
@@ -495,6 +496,9 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
|
||||
class CRLError(OpenSSLObjectError):
|
||||
pass
|
||||
@@ -502,7 +506,7 @@ class CRLError(OpenSSLObjectError):
|
||||
|
||||
class CRL(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CRL, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -510,53 +514,69 @@ class CRL(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
|
||||
self.format = module.params["format"]
|
||||
self.format: t.Literal["pem", "der"] = module.params["format"]
|
||||
|
||||
self.update = module.params["crl_mode"] == "update"
|
||||
self.ignore_timestamps = module.params["ignore_timestamps"]
|
||||
self.return_content = module.params["return_content"]
|
||||
self.name_encoding = module.params["name_encoding"]
|
||||
self.serial_numbers_format = module.params["serial_numbers"]
|
||||
self.crl_content = None
|
||||
self.update: bool = module.params["crl_mode"] == "update"
|
||||
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params[
|
||||
"name_encoding"
|
||||
]
|
||||
self.serial_numbers_format: t.Literal["integer", "hex-octets"] = module.params[
|
||||
"serial_numbers"
|
||||
]
|
||||
self.crl_content: bytes | None = None
|
||||
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
|
||||
try:
|
||||
if module.params["issuer_ordered"]:
|
||||
issuer_ordered: list[dict[str, t.Any]] | None = module.params[
|
||||
"issuer_ordered"
|
||||
]
|
||||
issuer: dict[str, list[str | bytes] | bytes | str] | None = module.params[
|
||||
"issuer"
|
||||
]
|
||||
if issuer_ordered:
|
||||
self.issuer_ordered = True
|
||||
self.issuer = parse_ordered_name_field(
|
||||
module.params["issuer_ordered"], "issuer_ordered"
|
||||
)
|
||||
self.issuer = parse_ordered_name_field(issuer_ordered, "issuer_ordered")
|
||||
else:
|
||||
self.issuer_ordered = False
|
||||
self.issuer = parse_name_field(module.params["issuer"], "issuer")
|
||||
self.issuer = (
|
||||
parse_name_field(issuer, "issuer") if issuer is not None else []
|
||||
)
|
||||
except (TypeError, ValueError) as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
self.last_update = get_relative_time_option(
|
||||
self.last_update: datetime.datetime = get_relative_time_option(
|
||||
module.params["last_update"],
|
||||
"last_update",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.next_update = get_relative_time_option(
|
||||
self.next_update: datetime.datetime | None = get_relative_time_option(
|
||||
module.params["next_update"],
|
||||
"next_update",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
|
||||
self.digest = select_message_digest(module.params["digest"])
|
||||
if self.digest is None:
|
||||
digest = select_message_digest(module.params["digest"])
|
||||
if digest is None:
|
||||
raise CRLError(f'The digest "{module.params["digest"]}" is not supported')
|
||||
self.digest = digest
|
||||
|
||||
self.module = module
|
||||
|
||||
self.revoked_certificates = []
|
||||
for i, rc in enumerate(module.params["revoked_certificates"]):
|
||||
result = {
|
||||
revoked_certificates: list[dict[str, t.Any]] = module.params[
|
||||
"revoked_certificates"
|
||||
]
|
||||
for i, rc in enumerate(revoked_certificates):
|
||||
result: dict[str, t.Any] = {
|
||||
"serial_number": None,
|
||||
"revocation_date": None,
|
||||
"issuer": None,
|
||||
@@ -567,21 +587,25 @@ class CRL(OpenSSLObject):
|
||||
"invalidity_date_critical": False,
|
||||
}
|
||||
path_prefix = f"revoked_certificates[{i}]."
|
||||
if rc["path"] is not None or rc["content"] is not None:
|
||||
path: str | None = rc["path"]
|
||||
content_str: str | None = rc["content"]
|
||||
if path is not None or content_str is not None:
|
||||
# Load certificate from file or content
|
||||
try:
|
||||
if rc["content"] is not None:
|
||||
rc["content"] = rc["content"].encode("utf-8")
|
||||
cert = load_certificate(rc["path"], content=rc["content"])
|
||||
content: bytes | None = None
|
||||
if content_str is not None:
|
||||
content = content_str.encode("utf-8")
|
||||
rc["content"] = content
|
||||
cert = load_certificate(path, content=content)
|
||||
result["serial_number"] = cert.serial_number
|
||||
except OpenSSLObjectError as e:
|
||||
if rc["content"] is not None:
|
||||
if content_str is not None:
|
||||
module.fail_json(
|
||||
msg=f"Cannot parse certificate from {path_prefix}content: {e}"
|
||||
)
|
||||
else:
|
||||
module.fail_json(
|
||||
msg=f'Cannot read certificate "{rc["path"]}" from {path_prefix}path: {e}'
|
||||
msg=f'Cannot read certificate "{path}" from {path_prefix}path: {e}'
|
||||
)
|
||||
else:
|
||||
# Specify serial_number (and potentially issuer) directly
|
||||
@@ -611,11 +635,11 @@ class CRL(OpenSSLObject):
|
||||
result["invalidity_date_critical"] = rc["invalidity_date_critical"]
|
||||
self.revoked_certificates.append(result)
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
try:
|
||||
self.privatekey = load_privatekey(
|
||||
self.privatekey = load_certificate_issuer_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
@@ -643,7 +667,7 @@ class CRL(OpenSSLObject):
|
||||
|
||||
self.diff_after = self.diff_before = self._get_info(data)
|
||||
|
||||
def _parse_serial_number(self, value, index):
|
||||
def _parse_serial_number(self, value: t.Any, index: int) -> int:
|
||||
if self.serial_numbers_format == "integer":
|
||||
try:
|
||||
return check_type_int(value)
|
||||
@@ -662,22 +686,42 @@ class CRL(OpenSSLObject):
|
||||
f"Unexpected value {self.serial_numbers_format} of serial_numbers"
|
||||
)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
return {}
|
||||
try:
|
||||
result = get_crl_info(self.module, data)
|
||||
result["can_parse_crl"] = True
|
||||
return result
|
||||
except Exception:
|
||||
return dict(can_parse_crl=False)
|
||||
return {"can_parse_crl": False}
|
||||
|
||||
def remove(self):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = self.module.backup_local(self.path)
|
||||
super(CRL, self).remove(self.module)
|
||||
|
||||
def _compress_entry(self, entry):
|
||||
def _compress_entry(self, entry: dict[str, t.Any]) -> (
|
||||
tuple[
|
||||
int | None,
|
||||
tuple[str, ...] | None,
|
||||
bool,
|
||||
int | None,
|
||||
bool,
|
||||
datetime.datetime | None,
|
||||
bool,
|
||||
]
|
||||
| tuple[
|
||||
int | None,
|
||||
datetime.datetime | None,
|
||||
tuple[str, ...] | None,
|
||||
bool,
|
||||
int | None,
|
||||
bool,
|
||||
datetime.datetime | None,
|
||||
bool,
|
||||
]
|
||||
):
|
||||
issuer = None
|
||||
if entry["issuer"] is not None:
|
||||
# Normalize to IDNA. If this is used-provided, it was already converted to
|
||||
@@ -713,7 +757,12 @@ class CRL(OpenSSLObject):
|
||||
entry["invalidity_date_critical"],
|
||||
)
|
||||
|
||||
def check(self, module, perms_required=True, ignore_conversion=True):
|
||||
def check(
|
||||
self,
|
||||
module: AnsibleModule,
|
||||
perms_required: bool = True,
|
||||
ignore_conversion: bool = True,
|
||||
) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
state_and_perms = super(CRL, self).check(self.module, perms_required)
|
||||
@@ -743,10 +792,13 @@ class CRL(OpenSSLObject):
|
||||
]
|
||||
is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer]
|
||||
if not self.issuer_ordered:
|
||||
want_issuer = set(want_issuer)
|
||||
is_issuer = set(is_issuer)
|
||||
if want_issuer != is_issuer:
|
||||
return False
|
||||
want_issuer_set = set(want_issuer)
|
||||
is_issuer_set = set(is_issuer)
|
||||
if want_issuer_set != is_issuer_set:
|
||||
return False
|
||||
else:
|
||||
if want_issuer != is_issuer:
|
||||
return False
|
||||
|
||||
old_entries = [
|
||||
self._compress_entry(cryptography_decode_revoked_certificate(cert))
|
||||
@@ -769,7 +821,7 @@ class CRL(OpenSSLObject):
|
||||
|
||||
return True
|
||||
|
||||
def _generate_crl(self):
|
||||
def _generate_crl(self) -> bytes:
|
||||
crl = CertificateRevocationListBuilder()
|
||||
|
||||
try:
|
||||
@@ -787,7 +839,8 @@ class CRL(OpenSSLObject):
|
||||
raise CRLError(e)
|
||||
|
||||
crl = set_last_update(crl, self.last_update)
|
||||
crl = set_next_update(crl, self.next_update)
|
||||
if self.next_update is not None:
|
||||
crl = set_next_update(crl, self.next_update)
|
||||
|
||||
if self.update and self.crl:
|
||||
new_entries = set(
|
||||
@@ -799,22 +852,26 @@ class CRL(OpenSSLObject):
|
||||
)
|
||||
if decoded_entry not in new_entries:
|
||||
crl = crl.add_revoked_certificate(entry)
|
||||
for entry in self.revoked_certificates:
|
||||
for revoked_entry in self.revoked_certificates:
|
||||
revoked_cert = RevokedCertificateBuilder()
|
||||
revoked_cert = revoked_cert.serial_number(entry["serial_number"])
|
||||
revoked_cert = set_revocation_date(revoked_cert, entry["revocation_date"])
|
||||
if entry["issuer"] is not None:
|
||||
revoked_cert = revoked_cert.serial_number(revoked_entry["serial_number"])
|
||||
revoked_cert = set_revocation_date(
|
||||
revoked_cert, revoked_entry["revocation_date"]
|
||||
)
|
||||
if revoked_entry["issuer"] is not None:
|
||||
revoked_cert = revoked_cert.add_extension(
|
||||
x509.CertificateIssuer(entry["issuer"]), entry["issuer_critical"]
|
||||
x509.CertificateIssuer(revoked_entry["issuer"]),
|
||||
revoked_entry["issuer_critical"],
|
||||
)
|
||||
if entry["reason"] is not None:
|
||||
if revoked_entry["reason"] is not None:
|
||||
revoked_cert = revoked_cert.add_extension(
|
||||
x509.CRLReason(entry["reason"]), entry["reason_critical"]
|
||||
x509.CRLReason(revoked_entry["reason"]),
|
||||
revoked_entry["reason_critical"],
|
||||
)
|
||||
if entry["invalidity_date"] is not None:
|
||||
if revoked_entry["invalidity_date"] is not None:
|
||||
revoked_cert = revoked_cert.add_extension(
|
||||
x509.InvalidityDate(entry["invalidity_date"]),
|
||||
entry["invalidity_date_critical"],
|
||||
x509.InvalidityDate(revoked_entry["invalidity_date"]),
|
||||
revoked_entry["invalidity_date_critical"],
|
||||
)
|
||||
crl = crl.add_revoked_certificate(revoked_cert.build())
|
||||
|
||||
@@ -827,7 +884,7 @@ class CRL(OpenSSLObject):
|
||||
else:
|
||||
return self.crl.public_bytes(Encoding.DER)
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
result = None
|
||||
if (
|
||||
not self.check(self.module, perms_required=False, ignore_conversion=True)
|
||||
@@ -861,7 +918,7 @@ class CRL(OpenSSLObject):
|
||||
elif self.module.set_fs_attributes_if_different(file_args, False):
|
||||
self.changed = True
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = {
|
||||
"changed": self.changed,
|
||||
"filename": self.path,
|
||||
@@ -879,37 +936,53 @@ class CRL(OpenSSLObject):
|
||||
|
||||
if check_mode:
|
||||
result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = self.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = (
|
||||
self.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
if self.next_update is not None
|
||||
else None
|
||||
)
|
||||
# result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid)
|
||||
result["digest"] = self.module.params["digest"]
|
||||
result["issuer_ordered"] = self.issuer
|
||||
result["issuer"] = {}
|
||||
issuer: dict[str, str | bytes] = {}
|
||||
result["issuer"] = issuer
|
||||
for k, v in self.issuer:
|
||||
result["issuer"][k] = v
|
||||
result["revoked_certificates"] = []
|
||||
issuer[k] = v
|
||||
revoked_certificates: list[dict[str, t.Any]] = []
|
||||
result["revoked_certificates"] = revoked_certificates
|
||||
for entry in self.revoked_certificates:
|
||||
result["revoked_certificates"].append(
|
||||
revoked_certificates.append(
|
||||
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
|
||||
)
|
||||
elif self.crl:
|
||||
result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = get_next_update(self.crl).strftime(TIMESTAMP_FORMAT)
|
||||
next_update = get_next_update(self.crl)
|
||||
result["next_update"] = (
|
||||
next_update.strftime(TIMESTAMP_FORMAT)
|
||||
if next_update is not None
|
||||
else None
|
||||
)
|
||||
result["digest"] = cryptography_oid_to_name(
|
||||
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
|
||||
)
|
||||
issuer = []
|
||||
issuer_list: list[list[str]] = []
|
||||
for attribute in self.crl.issuer:
|
||||
issuer.append(
|
||||
[cryptography_oid_to_name(attribute.oid), attribute.value]
|
||||
issuer_list.append(
|
||||
[
|
||||
cryptography_oid_to_name(attribute.oid),
|
||||
to_text(attribute.value),
|
||||
]
|
||||
)
|
||||
result["issuer_ordered"] = issuer
|
||||
result["issuer"] = {}
|
||||
for k, v in issuer:
|
||||
result["issuer"][k] = v
|
||||
result["revoked_certificates"] = []
|
||||
result["issuer_ordered"] = issuer_list
|
||||
issuer = {}
|
||||
result["issuer"] = issuer
|
||||
for k, v in issuer_list:
|
||||
issuer[k] = v
|
||||
revoked_certificates = []
|
||||
result["revoked_certificates"] = revoked_certificates
|
||||
for cert in self.crl:
|
||||
entry = cryptography_decode_revoked_certificate(cert)
|
||||
result["revoked_certificates"].append(
|
||||
revoked_certificates.append(
|
||||
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
|
||||
)
|
||||
|
||||
@@ -923,7 +996,7 @@ class CRL(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
state=dict(type="str", default="present", choices=["present", "absent"]),
|
||||
@@ -1015,14 +1088,14 @@ def main():
|
||||
)
|
||||
module.exit_json(**result)
|
||||
|
||||
crl.generate()
|
||||
crl.generate(module)
|
||||
else:
|
||||
if module.check_mode:
|
||||
result = crl.dump(check_mode=True)
|
||||
result["changed"] = os.path.exists(module.params["path"])
|
||||
module.exit_json(**result)
|
||||
|
||||
crl.remove()
|
||||
crl.remove(module)
|
||||
|
||||
result = crl.dump()
|
||||
module.exit_json(**result)
|
||||
|
||||
@@ -100,7 +100,9 @@ last_update:
|
||||
type: str
|
||||
sample: '20190413202428Z'
|
||||
next_update:
|
||||
description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
|
||||
description:
|
||||
- The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
|
||||
- Will be C(none) if no such timestamp is present.
|
||||
returned: success
|
||||
type: str
|
||||
sample: '20190413202428Z'
|
||||
@@ -172,6 +174,7 @@ revoked_certificates:
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -185,7 +188,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -200,25 +203,30 @@ def main():
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
if module.params["content"] is None:
|
||||
content: str | None = module.params["content"]
|
||||
path: str | None = module.params["path"]
|
||||
if content is None:
|
||||
if path is None:
|
||||
module.fail_json(msg="One of content and path must be provided")
|
||||
try:
|
||||
with open(module.params["path"], "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(msg=f"Error while reading CRL file from disk: {e}")
|
||||
else:
|
||||
data = module.params["content"].encode("utf-8")
|
||||
data = content.encode("utf-8")
|
||||
if not identify_pem_format(data):
|
||||
try:
|
||||
data = base64.b64decode(module.params["content"])
|
||||
data = base64.b64decode(content)
|
||||
except (binascii.Error, TypeError) as e:
|
||||
module.fail_json(msg=f"Error while Base64 decoding content: {e}")
|
||||
|
||||
list_revoked_certificates: bool = module.params["list_revoked_certificates"]
|
||||
try:
|
||||
result = get_crl_info(
|
||||
module,
|
||||
data,
|
||||
list_revoked_certificates=module.params["list_revoked_certificates"],
|
||||
list_revoked_certificates=list_revoked_certificates,
|
||||
)
|
||||
module.exit_json(**result)
|
||||
except OpenSSLObjectError as e:
|
||||
|
||||
Reference in New Issue
Block a user