Add type hints and type checking (#885)

* Enable basic type checking.

* Fix first errors.

* Add changelog fragment.

* Add types to module_utils and plugin_utils (without module backends).

* Add typing hints for acme_* modules.

* Add typing to X.509 certificate modules, and add more helpers.

* Add typing to remaining module backends.

* Add typing for action, filter, and lookup plugins.

* Bump ansible-core 2.19 beta requirement for typing.

* Add more typing definitions.

* Add typing to some unit tests.
This commit is contained in:
Felix Fontein
2025-05-11 18:00:11 +02:00
committed by GitHub
parent 82f0176773
commit f758d94fba
124 changed files with 4986 additions and 2662 deletions

View File

@@ -165,6 +165,7 @@ account_uri:
"""
import base64
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
@@ -180,7 +181,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec()
argument_spec.update_argspec(
terms_agreed=dict(type="bool", default=False),
@@ -204,24 +205,24 @@ def main():
),
)
argument_spec.update(
mutually_exclusive=(["new_account_key_src", "new_account_key_content"],),
required_if=(
mutually_exclusive=[("new_account_key_src", "new_account_key_content")],
required_if=[
# Make sure that for state == changed_key, one of
# new_account_key_src and new_account_key_content are specified
[
(
"state",
"changed_key",
["new_account_key_src", "new_account_key_content"],
True,
],
),
),
],
)
module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True)
if module.params["external_account_binding"]:
# Make sure padding is there
key = module.params["external_account_binding"]["key"]
key: str = module.params["external_account_binding"]["key"]
if len(key) % 4 != 0:
key = key + ("=" * (4 - (len(key) % 4)))
# Make sure key is Base64 encoded
@@ -237,24 +238,25 @@ def main():
client = ACMEClient(module, backend)
account = ACMEAccount(client)
changed = False
state = module.params.get("state")
diff_before = {}
diff_after = {}
state: t.Literal["present", "absent", "changed_key"] = module.params["state"]
diff_before: dict[str, t.Any] = {}
diff_after: dict[str, t.Any] = {}
if state == "absent":
created, account_data = account.setup_account(allow_creation=False)
if account_data:
diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
if created:
raise AssertionError("Unwanted account creation")
if account_data is not None:
# Account is not yet deactivated
if not module.check_mode:
# Deactivate it
payload = {"status": "deactivated"}
deactivate_payload = {"status": "deactivated"}
result, info = client.send_signed_request(
client.account_uri,
payload,
t.cast(str, client.account_uri),
deactivate_payload,
error_msg="Failed to deactivate account",
expected_status_codes=[200],
)
@@ -278,13 +280,15 @@ def main():
diff_before = {}
else:
diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
updated = False
if not created:
updated, account_data = account.update_account(account_data, contact)
changed = created or updated
diff_after = dict(account_data)
diff_after["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_after["public_account_key"] = client.account_key_data["jwk"]
elif state == "changed_key":
# Parse new account key
try:
@@ -306,7 +310,8 @@ def main():
msg="Account does not exist or is deactivated."
)
diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
# Now we can start the account key rollover
if not module.check_mode:
# Compose inner signed message
@@ -317,12 +322,12 @@ def main():
"jwk": new_key_data["jwk"],
"url": url,
}
payload = {
change_key_payload = {
"account": client.account_uri,
"newKey": new_key_data["jwk"], # specified in draft 12 and older
"oldKey": client.account_jwk, # specified in draft 13 and newer
}
data = client.sign_request(protected, payload, new_key_data)
data = client.sign_request(protected, change_key_payload, new_key_data)
# Send request and verify result
result, info = client.send_signed_request(
url,
@@ -332,8 +337,9 @@ def main():
)
if module._diff:
client.account_key_data = new_key_data
client.account_jws_header["alg"] = new_key_data["alg"]
diff_after = account.get_account_data()
if client.account_jws_header:
client.account_jws_header["alg"] = new_key_data["alg"]
diff_after = account.get_account_data() or {}
elif module._diff:
# Kind of fake diff_after
diff_after = dict(diff_before)

View File

@@ -204,6 +204,8 @@ order_uris:
version_added: 1.5.0
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
)
@@ -220,48 +222,55 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def get_orders_list(module, client, orders_url):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def get_orders_list(
module: AnsibleModule, client: ACMEClient, orders_url: str
) -> list[str]:
"""
Retrieves orders list (handles pagination).
Retrieves order URL list (handles pagination).
"""
orders = []
while orders_url:
orders: list[str] = []
next_orders_url: str | None = orders_url
while next_orders_url:
# Get part of orders list
res, info = client.get_request(
orders_url, parse_json_result=True, fail_on_error=True
next_orders_url, parse_json_result=True, fail_on_error=True
)
if not res.get("orders"):
if orders:
module.warn(
f"When retrieving orders list part {orders_url}, got empty result list"
f"When retrieving orders list part {next_orders_url}, got empty result list"
)
break
# Add order URLs to result list
orders.extend(res["orders"])
# Extract URL of next part of results list
new_orders_url = []
new_orders_url: list[str | None] = []
def f(link, relation):
def f(link: str, relation: str) -> None:
if relation == "next":
new_orders_url.append(link)
process_links(info, f)
new_orders_url.append(None)
previous_orders_url, orders_url = orders_url, new_orders_url.pop(0)
if orders_url == previous_orders_url:
previous_orders_url, next_orders_url = next_orders_url, new_orders_url.pop(0)
if next_orders_url == previous_orders_url:
# Prevent infinite loop
orders_url = None
next_orders_url = None
return orders
def get_order(client, order_url):
def get_order(client: ACMEClient, order_url: str) -> dict[str, t.Any]:
"""
Retrieve order data.
"""
return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0]
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec()
argument_spec.update_argspec(
retrieve_orders=dict(
@@ -282,16 +291,19 @@ def main():
)
if created:
raise AssertionError("Unwanted account creation")
result = {
result: dict[str, t.Any] = {
"changed": False,
"exists": client.account_uri is not None,
"account_uri": client.account_uri,
"exists": False,
"account_uri": None,
}
if client.account_uri is not None:
if client.account_uri is not None and account_data:
result["account_uri"] = client.account_uri
result["exists"] = True
# Make sure promised data is there
if "contact" not in account_data:
account_data["contact"] = []
account_data["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
account_data["public_account_key"] = client.account_key_data["jwk"]
result["account"] = account_data
# Retrieve orders list
if (

View File

@@ -94,6 +94,8 @@ renewal_info:
sample: '2024-04-29T01:17:10.236921+00:00'
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient,
create_backend,
@@ -104,15 +106,15 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_account=False)
argument_spec.update_argspec(
certificate_path=dict(type="path"),
certificate_content=dict(type="str"),
)
argument_spec.update(
required_one_of=(["certificate_path", "certificate_content"],),
mutually_exclusive=(["certificate_path", "certificate_content"],),
required_one_of=[("certificate_path", "certificate_content")],
mutually_exclusive=[("certificate_path", "certificate_content")],
)
module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True)

View File

@@ -562,6 +562,7 @@ all_chains:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
@@ -592,6 +593,17 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
CryptoBackend,
)
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
)
NO_CHALLENGE = "no challenge"
@@ -602,7 +614,7 @@ class ACMECertificateClient:
certificates.
"""
def __init__(self, module, backend):
def __init__(self, module: AnsibleModule, backend: CryptoBackend):
self.module = module
self.version = module.params["acme_version"]
self.challenge = module.params["challenge"]
@@ -618,9 +630,9 @@ class ACMECertificateClient:
self.account = ACMEAccount(self.client)
self.directory = self.client.directory
self.data = module.params["data"]
self.authorizations = None
self.authorizations: dict[str, Authorization] | None = None
self.cert_days = -1
self.order = None
self.order: Order | None = None
self.order_uri = self.data.get("order_uri") if self.data else None
self.all_chains = None
self.select_chain_matcher = []
@@ -662,7 +674,6 @@ class ACMECertificateClient:
contact.append("mailto:" + module.params["account_email"])
created, account_data = self.account.setup_account(
contact,
agreement=module.params.get("agreement"),
terms_agreed=module.params.get("terms_agreed"),
allow_creation=modify_account,
)
@@ -681,7 +692,7 @@ class ACMECertificateClient:
csr_filename=self.csr, csr_content=self.csr_content
)
def is_first_step(self):
def is_first_step(self) -> bool:
"""
Return True if this is the first execution of this module, i.e. if a
sufficient data object from a first run has not been provided.
@@ -692,7 +703,7 @@ class ACMECertificateClient:
# stored in self.order_uri by the constructor).
return self.order_uri is None
def _get_cert_info_or_none(self):
def _get_cert_info_or_none(self) -> CertificateInformation | None:
if self.module.params.get("dest"):
filename = self.module.params["dest"]
else:
@@ -701,7 +712,7 @@ class ACMECertificateClient:
return None
return self.client.backend.get_cert_information(cert_filename=filename)
def start_challenges(self):
def start_challenges(self) -> None:
"""
Create new authorizations for all identifiers of the CSR,
respectively start a new order for ACME v2.
@@ -733,13 +744,16 @@ class ACMECertificateClient:
self.authorizations.update(self.order.authorizations)
self.changed = True
def get_challenges_data(self, first_step):
def get_challenges_data(
self, first_step: bool
) -> tuple[dict[str, t.Any], dict[str, list[str]]]:
"""
Get challenge details for the chosen challenge type.
Return a tuple of generic challenge details, and specialized DNS challenge details.
"""
# Get general challenge data
data = {}
assert self.authorizations is not None
data: dict[str, t.Any] = {}
data_dns: dict[str, list[str]] = {}
for type_identifier, authz in self.authorizations.items():
identifier_type, identifier = split_identifier(type_identifier)
# Skip valid authentications: their challenges are already valid
@@ -747,7 +761,9 @@ class ACMECertificateClient:
if authz.status == "valid":
continue
# We drop the type from the key to preserve backwards compatibility
data[authz.identifier] = authz.get_challenge_data(self.client)
challenges = authz.get_challenge_data(self.client)
assert authz.identifier is not None
data[authz.identifier] = challenges
if (
first_step
and self.challenge is not None
@@ -756,10 +772,7 @@ class ACMECertificateClient:
raise ModuleFailException(
f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!"
)
# Get DNS challenge data
data_dns = {}
if self.challenge == "dns-01":
for identifier, challenges in data.items():
if self.challenge == "dns-01":
if self.challenge in challenges:
values = data_dns.get(challenges[self.challenge]["record"])
if values is None:
@@ -768,7 +781,7 @@ class ACMECertificateClient:
values.append(challenges[self.challenge]["resource_value"])
return data, data_dns
def finish_challenges(self):
def finish_challenges(self) -> None:
"""
Verify challenges for all identifiers of the CSR.
"""
@@ -777,6 +790,7 @@ class ACMECertificateClient:
# Step 1: obtain challenge information
# For ACME v2, we obtain the order object by fetching the
# order URI, and extract the information from there.
assert self.order_uri is not None
self.order = Order.from_url(self.client, self.order_uri)
self.order.load_authorizations(self.client)
self.authorizations.update(self.order.authorizations)
@@ -799,7 +813,9 @@ class ACMECertificateClient:
# Step 3: wait for authzs to validate
wait_for_validation(authzs_to_wait_for, self.client)
def download_alternate_chains(self, cert):
def download_alternate_chains(
self, cert: CertificateChain
) -> list[CertificateChain]:
alternate_chains = []
for alternate in cert.alternates:
try:
@@ -812,7 +828,9 @@ class ACMECertificateClient:
alternate_chains.append(alt_cert)
return alternate_chains
def find_matching_chain(self, chains):
def find_matching_chain(
self, chains: t.Iterable[CertificateChain]
) -> CertificateChain | None:
for criterium_idx, matcher in enumerate(self.select_chain_matcher):
for chain in chains:
if matcher.match(chain):
@@ -822,12 +840,13 @@ class ACMECertificateClient:
return chain
return None
def get_certificate(self):
def get_certificate(self) -> None:
"""
Request a new certificate and write it to the destination file.
First verifies whether all authorizations are valid; if not, aborts
with an error.
"""
assert self.authorizations is not None
for identifier_type, identifier in self.identifiers:
authz = self.authorizations.get(
normalize_combined_identifier(
@@ -844,7 +863,9 @@ class ACMECertificateClient:
module=self.module,
)
assert self.order is not None
self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
assert self.order.certificate_uri is not None
cert = CertificateChain.download(self.client, self.order.certificate_uri)
if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher:
# Retrieve alternate chains
@@ -887,12 +908,13 @@ class ACMECertificateClient:
):
self.changed = True
def deactivate_authzs(self):
def deactivate_authzs(self) -> None:
"""
Deactivates all valid authz's. Does not raise exceptions.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
"""
assert self.authorizations is not None
for authz in self.authorizations.values():
try:
authz.deactivate(self.client)
@@ -905,7 +927,7 @@ class ACMECertificateClient:
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True)
argument_spec.argument_spec["csr"]["aliases"] = ["src"]
argument_spec.update_argspec(
@@ -981,7 +1003,7 @@ def main():
else:
client = ACMECertificateClient(module, backend)
client.cert_days = cert_days
other = dict()
other: dict[str, t.Any] = {}
is_first_step = client.is_first_step()
if is_first_step:
# First run: start challenges / start new order
@@ -998,6 +1020,7 @@ def main():
client.deactivate_authzs()
data, data_dns = client.get_challenges_data(first_step=is_first_step)
auths = dict()
assert client.authorizations is not None
for k, v in client.authorizations.items():
# Remove "type:" from key
auths[v.identifier] = v.to_json()

View File

@@ -51,6 +51,8 @@ EXAMPLES = r"""
RETURN = """#"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
)
@@ -65,7 +67,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec()
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),

View File

@@ -371,6 +371,8 @@ account_uri:
type: str
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -383,7 +385,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True)
argument_spec.update_argspec(
deactivate_authzs=dict(type="bool", default=True),

View File

@@ -317,6 +317,8 @@ selected_chain:
returned: always
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -329,7 +331,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.certificates import (
CertificateChain,
)
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True)
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),
@@ -375,6 +383,7 @@ def main():
or module.params["retrieve_all_alternates"]
)
changed = False
alternate_chains: list[CertificateChain] | None
if order.status == "valid":
# Step 2 and 3: download certificate(s) and chain(s)
cert, alternate_chains = client.download_certificate(

View File

@@ -357,6 +357,8 @@ authorizations_by_status:
returned: always
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -369,7 +371,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=False)
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),
@@ -381,8 +383,8 @@ def main():
try:
client = ACMECertificateClient(module, backend)
order = client.load_order()
authorizations_by_identifier = dict()
authorizations_by_status = {
authorizations_by_identifier: dict[str, dict[str, t.Any]] = {}
authorizations_by_status: dict[str, list[str]] = {
"pending": [],
"invalid": [],
"valid": [],
@@ -392,7 +394,8 @@ def main():
}
for identifier, authz in order.authorizations.items():
authorizations_by_identifier[identifier] = authz.to_json()
authorizations_by_status[authz.status].append(identifier)
if authz.status is not None:
authorizations_by_status[authz.status].append(identifier)
module.exit_json(
changed=False,
account_uri=client.client.account_uri,

View File

@@ -229,6 +229,8 @@ validating_challenges:
returned: always
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -241,7 +243,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
)
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=False)
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),
@@ -271,10 +279,12 @@ def main():
missing_challenge_authzs = [k for k, v in challenges.items() if v is None]
if missing_challenge_authzs:
missing_challenge_authzs = ", ".join(sorted(missing_challenge_authzs))
missing_challenge_authzs_str = ", ".join(
sorted(missing_challenge_authzs)
)
raise ModuleFailException(
"The challenge parameter must be supplied if there are pending authorizations."
f" The following authorizations are pending: {missing_challenge_authzs}"
f" The following authorizations are pending: {missing_challenge_authzs_str}"
)
bad_challenge_authzs = [
@@ -293,11 +303,13 @@ def main():
f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}"
)
def is_pending(authz: Authorization) -> bool:
challenge_name = challenges[authz.combined_identifier]
challenge_obj = authz.find_challenge(challenge_name)
return challenge_obj is not None and challenge_obj.status == "pending"
really_pending_authzs = [
authz
for authz in pending_authzs
if authz.find_challenge(challenges[authz.combined_identifier]).status
== "pending"
authz for authz in pending_authzs if is_pending(authz)
]
# Step 4: validate pending authorizations
@@ -320,7 +332,7 @@ def main():
identifier_type=authz.identifier_type,
authz_url=authz.url,
challenge_type=challenge_type,
challenge_url=challenge.url,
challenge_url=challenge.url if challenge else None,
)
for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for
],

View File

@@ -160,6 +160,7 @@ cert_id:
import os
import random
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient,
@@ -175,7 +176,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_account=False)
argument_spec.update_argspec(
certificate_path=dict(type="path"),
@@ -190,7 +191,7 @@ def main():
treat_parsing_error_as_non_existing=dict(type="bool", default=False),
)
argument_spec.update(
mutually_exclusive=(["certificate_path", "certificate_content"],),
mutually_exclusive=[("certificate_path", "certificate_content")],
)
module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True)
@@ -203,7 +204,7 @@ def main():
supports_ari=False,
)
def complete(should_renew, **kwargs):
def complete(should_renew: bool, **kwargs) -> t.NoReturn:
result["should_renew"] = should_renew
result.update(kwargs)
module.exit_json(**result)

View File

@@ -110,6 +110,8 @@ EXAMPLES = r"""
RETURN = """#"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
)
@@ -129,7 +131,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(require_account_key=False)
argument_spec.update_argspec(
private_key_src=dict(type="path"),
@@ -139,22 +141,22 @@ def main():
revoke_reason=dict(type="int"),
)
argument_spec.update(
required_one_of=(
[
required_one_of=[
(
"account_key_src",
"account_key_content",
"private_key_src",
"private_key_content",
],
),
mutually_exclusive=(
[
),
],
mutually_exclusive=[
(
"account_key_src",
"account_key_content",
"private_key_src",
"private_key_content",
],
),
),
],
)
module = argument_spec.create_ansible_module()
backend = create_backend(module, False)
@@ -164,9 +166,9 @@ def main():
account = ACMEAccount(client)
# Load certificate
certificate = pem_to_der(module.params.get("certificate"))
certificate = nopad_b64(certificate)
certificate_b64 = nopad_b64(certificate)
# Construct payload
payload = {"certificate": certificate}
payload = {"certificate": certificate_b64}
if module.params.get("revoke_reason") is not None:
payload["reason"] = module.params.get("revoke_reason")
endpoint = client.directory["revokeCert"]

View File

@@ -149,6 +149,7 @@ regular_certificate:
import base64
import datetime
import ipaddress
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -173,10 +174,13 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
try:
import cryptography
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.dh
import cryptography.hazmat.primitives.asymmetric.ec
import cryptography.hazmat.primitives.asymmetric.padding
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.asymmetric.x448
import cryptography.hazmat.primitives.asymmetric.x25519
import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.serialization
import cryptography.x509
@@ -186,7 +190,7 @@ except ImportError:
# Convert byte string to ASN1 encoded octet string
def encode_octet_string(octet_string):
def encode_octet_string(octet_string: bytes) -> bytes:
if len(octet_string) >= 128:
raise ModuleFailException(
"Cannot handle octet strings with more than 128 bytes"
@@ -194,7 +198,7 @@ def encode_octet_string(octet_string):
return bytes([0x4, len(octet_string)]) + octet_string
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
challenge=dict(type="str", required=True, choices=["tls-alpn-01"]),
@@ -213,16 +217,16 @@ def main():
try:
# Get parameters
challenge = module.params["challenge"]
challenge_data = module.params["challenge_data"]
challenge: t.Literal["tls-alpn-01"] = module.params["challenge"]
challenge_data: dict[str, t.Any] = module.params["challenge_data"]
# Get hold of private key
private_key_content = module.params.get("private_key_content")
private_key_passphrase = module.params.get("private_key_passphrase")
if private_key_content is None:
private_key_content_str: str | None = module.params["private_key_content"]
private_key_passphrase: str | None = module.params["private_key_passphrase"]
if private_key_content_str is None:
private_key_content = read_file(module.params["private_key_src"])
else:
private_key_content = to_bytes(private_key_content)
private_key_content = to_bytes(private_key_content_str)
try:
private_key = (
cryptography.hazmat.primitives.serialization.load_pem_private_key(
@@ -236,6 +240,17 @@ def main():
)
except Exception as e:
raise ModuleFailException(f"Error while loading private key: {e}")
if isinstance(
private_key,
(
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
),
):
raise ModuleFailException(
f"Cannot use private key type {type(private_key)}"
)
# Some common attributes
domain = to_text(challenge_data["resource"])
@@ -246,6 +261,7 @@ def main():
now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
not_valid_before = now
not_valid_after = now + datetime.timedelta(days=10)
san: cryptography.x509.GeneralName
if identifier_type == "dns":
san = cryptography.x509.DNSName(identifier)
elif identifier_type == "ip":

View File

@@ -223,6 +223,8 @@ output_json:
- '...'
"""
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient,
@@ -235,7 +237,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(require_account_key=False)
argument_spec.update_argspec(
url=dict(type="str"),
@@ -246,17 +248,17 @@ def main():
fail_on_acme_error=dict(type="bool", default=True),
)
argument_spec.update(
required_if=(
["method", "get", ["url"]],
["method", "post", ["url", "content"]],
["method", "get", ["account_key_src", "account_key_content"], True],
["method", "post", ["account_key_src", "account_key_content"], True],
),
required_if=[
("method", "get", ["url"]),
("method", "post", ["url", "content"]),
("method", "get", ["account_key_src", "account_key_content"], True),
("method", "post", ["account_key_src", "account_key_content"], True),
],
)
module = argument_spec.create_ansible_module()
backend = create_backend(module, False)
result = dict()
result: dict[str, t.Any] = {}
changed = False
try:
# Get hold of ACMEClient and ACMEAccount objects (includes directory)

View File

@@ -121,6 +121,7 @@ complete_chain:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes
@@ -153,14 +154,18 @@ class Certificate:
Stores PEM with parsed certificate.
"""
def __init__(self, pem, cert):
def __init__(self, pem: str, cert: cryptography.x509.Certificate) -> None:
if not (pem.endswith("\n") or pem.endswith("\r")):
pem = pem + "\n"
self.pem = pem
self.cert = cert
def is_parent(module, cert, potential_parent):
def is_parent(
module: AnsibleModule,
cert: Certificate,
potential_parent: Certificate,
) -> bool:
"""
Tests whether the given certificate has been issued by the potential parent certificate.
"""
@@ -173,6 +178,10 @@ def is_parent(module, cert, potential_parent):
if isinstance(
public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey
):
if cert.cert.signature_hash_algorithm is None:
raise AssertionError(
"signature_hash_algorithm should be present for RSA certificates"
)
public_key.verify(
cert.cert.signature,
cert.cert.tbs_certificate_bytes,
@@ -183,6 +192,10 @@ def is_parent(module, cert, potential_parent):
public_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
):
if cert.cert.signature_hash_algorithm is None:
raise AssertionError(
"signature_hash_algorithm should be present for EC certificates"
)
public_key.verify(
cert.cert.signature,
cert.cert.tbs_certificate_bytes,
@@ -213,11 +226,16 @@ def is_parent(module, cert, potential_parent):
module.fail_json(msg=f"Unknown error on signature validation: {e}")
def parse_PEM_list(module, text, source, fail_on_error=True):
def parse_PEM_list(
module: AnsibleModule,
text: str,
source: str | os.PathLike,
fail_on_error: bool = True,
) -> list[Certificate]:
"""
Parse concatenated PEM certificates. Return list of ``Certificate`` objects.
"""
result = []
result: list[Certificate] = []
for cert_pem in split_pem_list(text):
# Try to load PEM certificate
try:
@@ -232,7 +250,9 @@ def parse_PEM_list(module, text, source, fail_on_error=True):
return result
def load_PEM_list(module, path, fail_on_error=True):
def load_PEM_list(
module: AnsibleModule, path: str | os.PathLike, fail_on_error: bool = True
) -> list[Certificate]:
"""
Load concatenated PEM certificates from file. Return list of ``Certificate`` objects.
"""
@@ -258,13 +278,15 @@ class CertificateSet:
Stores a set of certificates. Allows to search for parent (issuer of a certificate).
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.certificates = set()
self.certificates_by_issuer = dict()
self.certificate_by_cert = dict()
self.certificates: set[Certificate] = set()
self.certificates_by_issuer: dict[cryptography.x509.Name, list[Certificate]] = (
{}
)
self.certificate_by_cert: dict[cryptography.x509.Certificate, Certificate] = {}
def _load_file(self, path):
def _load_file(self, path: str | os.PathLike) -> None:
certs = load_PEM_list(self.module, path, fail_on_error=False)
for cert in certs:
self.certificates.add(cert)
@@ -273,7 +295,7 @@ class CertificateSet:
self.certificates_by_issuer[cert.cert.subject].append(cert)
self.certificate_by_cert[cert.cert] = cert
def load(self, path):
def load(self, path: str | os.PathLike) -> None:
"""
Load lists of PEM certificates from a file or a directory.
"""
@@ -285,7 +307,7 @@ class CertificateSet:
else:
self._load_file(b_path)
def find_parent(self, cert):
def find_parent(self, cert: Certificate) -> Certificate | None:
"""
Search for the parent (issuer) of a certificate. Return ``None`` if none was found.
"""
@@ -296,14 +318,18 @@ class CertificateSet:
return None
def format_cert(cert):
def format_cert(cert: Certificate) -> str:
"""
Return human readable representation of certificate for error messages.
"""
return str(cert.cert)
def check_cycle(module, occured_certificates, next):
def check_cycle(
module: AnsibleModule,
occured_certificates: set[cryptography.x509.Certificate],
next: Certificate,
) -> None:
"""
Make sure that next is not in occured_certificates so far, and add it.
"""
@@ -313,7 +339,7 @@ def check_cycle(module, occured_certificates, next):
occured_certificates.add(next_cert)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
input_chain=dict(type="str", required=True),
@@ -354,10 +380,10 @@ def main():
roots.load(path)
# Try to complete chain
current = chain[-1]
current: Certificate | None = chain[-1]
completed = []
occured_certificates = set([cert.cert for cert in chain])
if current.cert in roots.certificate_by_cert:
if current and current.cert in roots.certificate_by_cert:
# Do not try to complete the chain when it is already ending with a root certificate
current = None
while current:

View File

@@ -152,10 +152,13 @@ openssl:
"""
import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule
CRYPTOGRAPHY_VERSION: str | None
CRYPTOGRAPHY_IMP_ERR: str | None
try:
import cryptography
from cryptography.exceptions import UnsupportedAlgorithm
@@ -165,10 +168,10 @@ try:
# only got added in 0.2, so let's guard the import
from cryptography.exceptions import InternalError as CryptographyInternalError
except ImportError:
CryptographyInternalError = Exception
CryptographyInternalError = Exception # type: ignore
except ImportError:
UnsupportedAlgorithm = Exception
CryptographyInternalError = Exception
UnsupportedAlgorithm = Exception # type: ignore
CryptographyInternalError = Exception # type: ignore
HAS_CRYPTOGRAPHY = False
CRYPTOGRAPHY_VERSION = None
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -201,8 +204,8 @@ CURVES = (
)
def add_crypto_information(module):
result = {}
def add_crypto_information(module: AnsibleModule) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY
if not HAS_CRYPTOGRAPHY:
result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR
@@ -397,9 +400,9 @@ def add_crypto_information(module):
return result
def add_openssl_information(module):
def add_openssl_information(module: AnsibleModule) -> dict[str, t.Any]:
openssl_binary = module.get_bin_path("openssl")
result = {
result: dict[str, t.Any] = {
"openssl_present": openssl_binary is not None,
}
if openssl_binary is None:
@@ -426,9 +429,9 @@ INFO_FUNCTIONS = (
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(argument_spec={}, supports_check_mode=True)
result = {}
result: dict[str, t.Any] = {}
for fn in INFO_FUNCTIONS:
result.update(fn(module))
module.exit_json(**result)

View File

@@ -550,6 +550,7 @@ import datetime
import os
import re
import time
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes
@@ -572,7 +573,7 @@ from ansible_collections.community.crypto.plugins.module_utils.io import write_f
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
def validate_cert_expiry(cert_expiry):
def validate_cert_expiry(cert_expiry: str) -> bool:
search_string_partial = re.compile(
r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z"
)
@@ -587,7 +588,7 @@ def validate_cert_expiry(cert_expiry):
return False
def calculate_cert_days(expires_after):
def calculate_cert_days(expires_after: str | None) -> int:
cert_days = 0
if expires_after:
expires_after_datetime = datetime.datetime.strptime(
@@ -600,7 +601,9 @@ def calculate_cert_days(expires_after):
# Populate the value of body[dict_param_name] with the JSON equivalent of
# module parameter of param_name if that parameter is present, otherwise leave field
# out of resulting dict
def convert_module_param_to_json_bool(module, dict_param_name, param_name):
def convert_module_param_to_json_bool(
module: AnsibleModule, dict_param_name: str, param_name: str
) -> dict[str, str]:
body = {}
if module.params[param_name] is not None:
if module.params[param_name]:
@@ -886,7 +889,7 @@ class EcsCertificate:
return result
def custom_fields_spec():
def custom_fields_spec() -> dict[str, dict[str, str]]:
return dict(
text1=dict(type="str"),
text2=dict(type="str"),
@@ -926,7 +929,7 @@ def custom_fields_spec():
)
def ecs_certificate_argument_spec():
def ecs_certificate_argument_spec() -> dict[str, dict[str, t.Any]]:
return dict(
backup=dict(type="bool", default=False),
force=dict(type="bool", default=False),
@@ -979,7 +982,7 @@ def ecs_certificate_argument_spec():
)
def main():
def main() -> t.NoReturn:
ecs_argument_spec = ecs_client_argument_spec()
ecs_argument_spec.update(ecs_certificate_argument_spec())
module = AnsibleModule(

View File

@@ -218,6 +218,7 @@ ev_days_remaining:
import datetime
import time
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
@@ -228,7 +229,7 @@ from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
)
def calculate_days_remaining(expiry_date):
def calculate_days_remaining(expiry_date: str | None) -> int | None:
days_remaining = None
if expiry_date:
expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ")
@@ -403,8 +404,8 @@ class EcsDomain:
msg=f"Failed to request domain validation from Entrust (ECS) {e.message}"
)
def dump(self):
result = {
def dump(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"changed": self.changed,
"client_id": self.client_id,
"domain_status": self.domain_status,
@@ -436,7 +437,7 @@ class EcsDomain:
return result
def ecs_domain_argument_spec():
def ecs_domain_argument_spec() -> dict[str, dict[str, t.Any]]:
return dict(
client_id=dict(type="int", default=1),
domain_name=dict(type="str", required=True),
@@ -447,7 +448,7 @@ def ecs_domain_argument_spec():
)
def main():
def main() -> t.NoReturn:
ecs_argument_spec = ecs_client_argument_spec()
ecs_argument_spec.update(ecs_domain_argument_spec())
module = AnsibleModule(

View File

@@ -268,6 +268,7 @@ import atexit
import base64
import ssl
import sys
import typing as t
from os.path import isfile
from socket import create_connection, setdefaulttimeout, socket
from ssl import (
@@ -305,7 +306,7 @@ except ImportError:
pass
def send_starttls_packet(sock, server_type):
def send_starttls_packet(sock: socket, server_type: t.Literal["mysql"]) -> None:
if server_type == "mysql":
ssl_request_packet = (
b"\x20\x00\x00\x01\x85\xae\x7f\x00"
@@ -321,7 +322,7 @@ def send_starttls_packet(sock, server_type):
sock.send(ssl_request_packet)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
ca_cert=dict(type="path"),
@@ -342,18 +343,18 @@ def main():
),
)
ca_cert = module.params.get("ca_cert")
host = module.params.get("host")
port = module.params.get("port")
proxy_host = module.params.get("proxy_host")
proxy_port = module.params.get("proxy_port")
timeout = module.params.get("timeout")
server_name = module.params.get("server_name")
start_tls_server_type = module.params.get("starttls")
ciphers = module.params.get("ciphers")
asn1_base64 = module.params["asn1_base64"]
tls_ctx_options = module.params["tls_ctx_options"]
get_certificate_chain = module.params["get_certificate_chain"]
ca_cert: str | None = module.params.get("ca_cert")
host: str = module.params.get("host")
port: int = module.params.get("port")
proxy_host: str | None = module.params.get("proxy_host")
proxy_port: int | None = module.params.get("proxy_port")
timeout: int = module.params.get("timeout")
server_name: str | None = module.params.get("server_name")
start_tls_server_type: t.Literal["mysql"] | None = module.params.get("starttls")
ciphers: list[str] | None = module.params.get("ciphers")
asn1_base64: bool = module.params["asn1_base64"]
tls_ctx_options: list[str | bytes | int] | None = module.params["tls_ctx_options"]
get_certificate_chain: bool = module.params["get_certificate_chain"]
if get_certificate_chain and sys.version_info < (3, 10):
module.fail_json(
@@ -365,9 +366,9 @@ def main():
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
result = dict(
changed=False,
)
result: dict[str, t.Any] = {
"changed": False,
}
if timeout:
setdefaulttimeout(timeout)
@@ -409,7 +410,7 @@ def main():
if tls_ctx_options is not None:
# Clear default ctx options
ctx.options = 0
ctx.options = 0 # type: ignore
# For each item in the tls_ctx_options list
for tls_ctx_option in tls_ctx_options:
@@ -450,8 +451,10 @@ def main():
)
tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host)
cert = tls_sock.getpeercert(True)
cert = DER_cert_to_PEM_cert(cert)
cert_der = tls_sock.getpeercert(True)
if cert_der is None:
raise Exception("Unexpected error: no peer certificate has been returned")
cert: str = DER_cert_to_PEM_cert(cert_der)
if get_certificate_chain:
if sys.version_info < (3, 13):
@@ -474,7 +477,7 @@ def main():
# Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to
# be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves
# if they are not byte strings to work around this.
def _convert_chain(chain):
def _convert_chain(chain: list[bytes]) -> list[bytes]:
return [
(
c
@@ -514,13 +517,13 @@ def main():
result["extensions"] = []
for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items():
oid = cryptography.x509.oid.ObjectIdentifier(dotted_number)
ext = {
ext: dict[str, t.Any] = {
"critical": entry["critical"],
"asn1_data": entry["value"],
"name": cryptography_oid_to_name(oid, short=True),
}
if not asn1_base64:
ext["asn1_data"] = base64.b64decode(ext["asn1_data"])
ext["asn1_data"] = base64.b64decode(entry["value"]) # type: ignore
result["extensions"].append(ext)
result["issuer"] = {}

View File

@@ -420,16 +420,13 @@ name:
import os
import re
import stat
import typing as t
from base64 import b64decode
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native
RETURN_CODE = 0
STDOUT = 1
STDERR = 2
# used to get <luks-name> out of lsblk output in format 'crypt <luks-name>'
# regex takes care of any possible blank characters
LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$")
@@ -456,7 +453,7 @@ LUKS2_HEADER_OFFSETS = [
LUKS2_HEADER2 = b"SKUL\xba\xbe"
def wipe_luks_headers(device):
def wipe_luks_headers(device: str) -> None:
wipe_offsets = []
with open(device, "rb") as f:
# f.seek(0)
@@ -478,12 +475,12 @@ def wipe_luks_headers(device):
class Handler:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self._module = module
self._lsblk_bin = self._module.get_bin_path("lsblk", True)
self._passphrase_encoding = module.params["passphrase_encoding"]
def get_passphrase_from_module_params(self, parameter_name):
def get_passphrase_from_module_params(self, parameter_name: str) -> bytes | None:
passphrase = self._module.params[parameter_name]
if passphrase is None:
return None
@@ -496,88 +493,91 @@ class Handler:
f"Error while base64-decoding '{parameter_name}': {exc}"
)
def _run_command(self, command, data=None):
def _run_command(
self, command: list[str], data: bytes | None = None
) -> tuple[int, str, str]:
return self._module.run_command(command, data=data, binary_data=True)
def get_device_by_uuid(self, uuid):
def get_device_by_uuid(self, uuid: str | None) -> str | None:
"""Returns the device that holds UUID passed by user"""
self._blkid_bin = self._module.get_bin_path("blkid", True)
uuid = self._module.params["uuid"]
if uuid is None:
return None
result = self._run_command([self._blkid_bin, "--uuid", uuid])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command([self._blkid_bin, "--uuid", uuid])
if rc != 0:
return None
return result[STDOUT].strip()
return stdout.strip()
def get_device_by_label(self, label):
def get_device_by_label(self, label: str) -> str | None:
"""Returns the device that holds label passed by user"""
self._blkid_bin = self._module.get_bin_path("blkid", True)
label = self._module.params["label"]
if label is None:
return None
result = self._run_command([self._blkid_bin, "--label", label])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command([self._blkid_bin, "--label", label])
if rc != 0:
return None
return result[STDOUT].strip()
return stdout.strip()
def generate_luks_name(self, device):
def generate_luks_name(self, device: str) -> str:
"""Generate name for luks based on device UUID ('luks-<UUID>').
Raises ValueError when obtaining of UUID fails.
"""
result = self._run_command([self._lsblk_bin, "-n", device, "-o", "UUID"])
rc, stdout, stderr = self._run_command(
[self._lsblk_bin, "-n", device, "-o", "UUID"]
)
if result[RETURN_CODE] != 0:
raise ValueError(
f"Error while generating LUKS name for {device}: {result[STDERR]}"
)
dev_uuid = result[STDOUT].strip()
if rc != 0:
raise ValueError(f"Error while generating LUKS name for {device}: {stderr}")
dev_uuid = stdout.strip()
return f"luks-{dev_uuid}"
class CryptHandler(Handler):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CryptHandler, self).__init__(module)
self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True)
def get_container_name_by_device(self, device):
def get_container_name_by_device(self, device: str) -> str | None:
"""obtain LUKS container name based on the device where it is located
return None if not found
raise ValueError if lsblk command fails
"""
result = self._run_command([self._lsblk_bin, device, "-nlo", "type,name"])
if result[RETURN_CODE] != 0:
raise ValueError(
f"Error while obtaining LUKS name for {device}: {result[STDERR]}"
)
rc, stdout, stderr = self._run_command(
[self._lsblk_bin, device, "-nlo", "type,name"]
)
if rc != 0:
raise ValueError(f"Error while obtaining LUKS name for {device}: {stderr}")
for line in result[STDOUT].splitlines(False):
for line in stdout.splitlines(False):
m = LUKS_NAME_REGEX.match(line)
if m:
return m.group(1)
return None
def get_container_device_by_name(self, name):
def get_container_device_by_name(self, name: str) -> str | None:
"""obtain device name based on the LUKS container name
return None if not found
raise ValueError if lsblk command fails
"""
# apparently each device can have only one LUKS container on it
result = self._run_command([self._cryptsetup_bin, "status", name])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command([self._cryptsetup_bin, "status", name])
if rc != 0:
return None
m = LUKS_DEVICE_REGEX.search(result[STDOUT])
m = LUKS_DEVICE_REGEX.search(stdout)
if not m:
return None
device = m.group(1)
return device
def is_luks(self, device):
def is_luks(self, device: str) -> bool:
"""check if the LUKS container does exist"""
result = self._run_command([self._cryptsetup_bin, "isLuks", device])
return result[RETURN_CODE] == 0
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "isLuks", device])
return rc == 0
def get_luks_type(self, device):
def get_luks_type(self, device: str) -> t.Literal["luks1", "luks2"] | None:
"""get the luks type of a device"""
if self.is_luks(device):
with open(device, "rb") as f:
@@ -589,16 +589,18 @@ class CryptHandler(Handler):
return "luks1"
return None
def is_luks_slot_set(self, device, keyslot):
def is_luks_slot_set(self, device: str, keyslot: int) -> bool:
"""check if a keyslot is set"""
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command(
[self._cryptsetup_bin, "luksDump", device]
)
if rc != 0:
raise ValueError(f"Error while dumping LUKS header from {device}")
result_luks1 = f"Key Slot {keyslot}: ENABLED" in result[STDOUT]
result_luks2 = f" {keyslot}: luks2" in result[STDOUT]
result_luks1 = f"Key Slot {keyslot}: ENABLED" in stdout
result_luks2 = f" {keyslot}: luks2" in stdout
return result_luks1 or result_luks2
def _add_pbkdf_options(self, options, pbkdf):
def _add_pbkdf_options(self, options: list[str], pbkdf: dict[str, t.Any]) -> None:
if pbkdf["iteration_time"] is not None:
options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))])
if pbkdf["iteration_count"] is not None:
@@ -612,16 +614,16 @@ class CryptHandler(Handler):
def run_luks_create(
self,
device,
keyfile,
passphrase,
keyslot,
keysize,
cipher,
hash_,
sector_size,
pbkdf,
):
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None,
keysize: int | None,
cipher: str | None,
hash_: str | None,
sector_size: str | None,
pbkdf: dict[str, t.Any] | None,
) -> None:
# create a new luks container; use batch mode to auto confirm
luks_type = self._module.params["type"]
label = self._module.params["label"]
@@ -653,23 +655,23 @@ class CryptHandler(Handler):
else:
args.append("-")
result = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0:
raise ValueError(f"Error while creating LUKS on {device}: {result[STDERR]}")
rc, dummy, stderr = self._run_command(args, data=passphrase)
if rc != 0:
raise ValueError(f"Error while creating LUKS on {device}: {stderr}")
def run_luks_open(
self,
device,
keyfile,
passphrase,
perf_same_cpu_crypt,
perf_submit_from_crypt_cpus,
perf_no_read_workqueue,
perf_no_write_workqueue,
persistent,
allow_discards,
name,
):
device: str,
keyfile: str | None,
passphrase: bytes | None,
perf_same_cpu_crypt: bool,
perf_submit_from_crypt_cpus: bool,
perf_no_read_workqueue: bool,
perf_no_write_workqueue: bool,
persistent: bool,
allow_discards: bool,
name: str,
) -> None:
args = [self._cryptsetup_bin]
if keyfile:
args.extend(["--key-file", keyfile])
@@ -689,27 +691,27 @@ class CryptHandler(Handler):
args.extend(["--allow-discards"])
args.extend(["open", "--type", "luks", device, name])
result = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0:
rc, dummy, stderr = self._run_command(args, data=passphrase)
if rc != 0:
raise ValueError(
f"Error while opening LUKS container on {device}: {result[STDERR]}"
f"Error while opening LUKS container on {device}: {stderr}"
)
def run_luks_close(self, name):
result = self._run_command([self._cryptsetup_bin, "close", name])
if result[RETURN_CODE] != 0:
def run_luks_close(self, name: str) -> None:
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "close", name])
if rc != 0:
raise ValueError(f"Error while closing LUKS container {name}")
def run_luks_remove(self, device):
def run_luks_remove(self, device: str) -> None:
wipefs_bin = self._module.get_bin_path("wipefs", True)
name = self.get_container_name_by_device(device)
if name is not None:
self.run_luks_close(name)
result = self._run_command([wipefs_bin, "--all", device])
if result[RETURN_CODE] != 0:
rc, dummy, stderr = self._run_command([wipefs_bin, "--all", device])
if rc != 0:
raise ValueError(
f"Error while wiping LUKS container signatures for {device}: {result[STDERR]}"
f"Error while wiping LUKS container signatures for {device}: {stderr}"
)
# For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not**
@@ -724,14 +726,14 @@ class CryptHandler(Handler):
def run_luks_add_key(
self,
device,
keyfile,
passphrase,
new_keyfile,
new_passphrase,
new_keyslot,
pbkdf,
):
device: str,
keyfile: str | None,
passphrase: bytes | None,
new_keyfile: str | None,
new_passphrase: bytes | None,
new_keyslot: int | None,
pbkdf: dict[str, t.Any] | None,
) -> None:
"""Add new key from a keyfile or passphrase to given 'device';
authentication done using 'keyfile' or 'passphrase'.
Raises ValueError when command fails.
@@ -746,36 +748,47 @@ class CryptHandler(Handler):
if keyfile:
args.extend(["--key-file", keyfile])
else:
elif passphrase is not None:
args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))])
data.append(passphrase)
else:
raise ValueError("Need passphrase or keyfile")
if new_keyfile:
args.append(new_keyfile)
else:
elif new_passphrase is not None:
args.append("-")
data.append(new_passphrase)
else:
raise ValueError("Need new passphrase or new keyfile")
result = self._run_command(args, data=b"".join(data) or None)
if result[RETURN_CODE] != 0:
rc, dummy, stderr = self._run_command(args, data=b"".join(data) or None)
if rc != 0:
raise ValueError(
f"Error while adding new LUKS keyslot to {device}: {result[STDERR]}"
f"Error while adding new LUKS keyslot to {device}: {stderr}"
)
def run_luks_remove_key(
self, device, keyfile, passphrase, keyslot, force_remove_last_key=False
):
self,
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None,
force_remove_last_key: bool = False,
) -> None:
"""Remove key from given device
Raises ValueError when command fails
"""
if not force_remove_last_key:
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command(
[self._cryptsetup_bin, "luksDump", device]
)
if rc != 0:
raise ValueError(f"Error while dumping LUKS header from {device}")
keyslot_count = 0
keyslot_area = False
keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED")
for line in result[STDOUT].splitlines():
for line in stdout.splitlines():
if line.startswith("Keyslots:"):
keyslot_area = True
elif line.startswith(" "):
@@ -808,13 +821,17 @@ class CryptHandler(Handler):
# Since we supply -q no passphrase is needed
args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)]
passphrase = None
result = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0:
raise ValueError(
f"Error while removing LUKS key from {device}: {result[STDERR]}"
)
rc, dummy, stderr = self._run_command(args, data=passphrase)
if rc != 0:
raise ValueError(f"Error while removing LUKS key from {device}: {stderr}")
def luks_test_key(self, device, keyfile, passphrase, keyslot=None):
def luks_test_key(
self,
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None = None,
) -> bool:
"""Check whether the keyfile or passphrase works.
Raises ValueError when command fails.
"""
@@ -830,42 +847,37 @@ class CryptHandler(Handler):
if keyslot is not None:
args.extend(["--key-slot", str(keyslot)])
result = self._run_command(args, data=data)
if result[RETURN_CODE] == 0:
rc, stdout, stderr = self._run_command(args, data=data)
if rc == 0:
return True
for output in (STDOUT, STDERR):
if "No key available with this passphrase" in result[output]:
for output in (stdout, stderr):
if "No key available with this passphrase" in output:
return False
if "No usable keyslot is available." in result[output]:
if "No usable keyslot is available." in output:
return False
# This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available'
# when using the --key-slot parameter in combination with --test-passphrase
if (
result[RETURN_CODE] == 1
and keyslot is not None
and result[STDOUT] == ""
and result[STDERR] == ""
):
if rc == 1 and keyslot is not None and stdout == "" and stderr == "":
return False
raise ValueError(
f"Error while testing whether keyslot exists on {device}: {result[STDERR]}"
f"Error while testing whether keyslot exists on {device}: {stderr}"
)
class ConditionsHandler(Handler):
def __init__(self, module, crypthandler):
def __init__(self, module: AnsibleModule, crypthandler: CryptHandler) -> None:
super(ConditionsHandler, self).__init__(module)
self._crypthandler = crypthandler
self.device = self.get_device_name()
def get_device_name(self):
device = self._module.params.get("device")
label = self._module.params.get("label")
uuid = self._module.params.get("uuid")
name = self._module.params.get("name")
def get_device_name(self) -> str | None:
device: str | None = self._module.params.get("device")
label: str | None = self._module.params.get("label")
uuid: str | None = self._module.params.get("uuid")
name: str | None = self._module.params.get("name")
if device is None and label is not None:
device = self.get_device_by_label(label)
@@ -876,7 +888,7 @@ class ConditionsHandler(Handler):
return device
def luks_create(self):
def luks_create(self) -> bool:
return (
self.device is not None
and (
@@ -887,7 +899,7 @@ class ConditionsHandler(Handler):
and not self._crypthandler.is_luks(self.device)
)
def opened_luks_name(self):
def opened_luks_name(self, device: str) -> str | None:
"""If luks is already opened, return its name.
If 'name' parameter is specified and differs
from obtained value, fail.
@@ -897,7 +909,7 @@ class ConditionsHandler(Handler):
return None
# try to obtain luks name - it may be already opened
name = self._crypthandler.get_container_name_by_device(self.device)
name = self._crypthandler.get_container_name_by_device(device)
if name is None:
# container is not open
@@ -917,7 +929,7 @@ class ConditionsHandler(Handler):
# container is opened and the names match
return name
def luks_open(self):
def luks_open(self) -> bool:
if (
(
self._module.params["keyfile"] is None
@@ -929,13 +941,13 @@ class ConditionsHandler(Handler):
# conditions for open not fulfilled
return False
name = self.opened_luks_name()
name = self.opened_luks_name(self.device)
if name is None:
return True
return False
def luks_close(self):
def luks_close(self) -> bool:
if (
self._module.params["name"] is None and self.device is None
) or self._module.params["state"] != "closed":
@@ -948,15 +960,17 @@ class ConditionsHandler(Handler):
luks_is_open = name is not None
if self._module.params["name"] is not None:
self.device = self._crypthandler.get_container_device_by_name(
device = self._crypthandler.get_container_device_by_name(
self._module.params["name"]
)
# successfully getting device based on name means that luks is open
luks_is_open = self.device is not None
luks_is_open = device is not None
if device is not None:
self.device = device
return luks_is_open
def luks_add_key(self):
def luks_add_key(self) -> bool:
if (
self.device is None
or (
@@ -995,7 +1009,7 @@ class ConditionsHandler(Handler):
return not key_present
def luks_remove_key(self):
def luks_remove_key(self) -> bool:
if self.device is None or (
self._module.params["remove_keyfile"] is None
and self._module.params["remove_passphrase"] is None
@@ -1037,14 +1051,16 @@ class ConditionsHandler(Handler):
self.get_passphrase_from_module_params("remove_passphrase"),
)
def luks_remove(self):
def luks_remove(self) -> bool:
return (
self.device is not None
and self._module.params["state"] == "absent"
and self._crypthandler.is_luks(self.device)
)
def validate_keyslot(self, param, luks_type):
def validate_keyslot(
self, param: str, luks_type: t.Literal["luks1", "luks2"] | None
) -> None:
if self._module.params[param] is not None:
if luks_type is None and param == "keyslot":
if 8 <= self._module.params[param] <= 31:
@@ -1066,7 +1082,7 @@ class ConditionsHandler(Handler):
)
def run_module():
def run_module() -> t.NoReturn:
# available arguments/parameters that a user can pass
module_args = dict(
state=dict(
@@ -1122,7 +1138,7 @@ def run_module():
]
# seed the result dict in the object
result = dict(changed=False, name=None)
result: dict[str, t.Any] = {"changed": False, "name": None}
module = AnsibleModule(
argument_spec=module_args,
@@ -1142,19 +1158,26 @@ def run_module():
except Exception as e:
module.fail_json(msg=str(e))
crypt = CryptHandler(module)
conditions = ConditionsHandler(module, crypt)
# conditions not allowed to run
if module.params["label"] is not None and module.params["type"] == "luks1":
module.fail_json(msg="You cannot combine type luks1 with the label option.")
crypt = CryptHandler(module)
try:
conditions = ConditionsHandler(module, crypt)
except ValueError as exc:
module.fail_json(msg=str(exc))
if (
module.params["keyslot"] is not None
or module.params["new_keyslot"] is not None
or module.params["remove_keyslot"] is not None
):
luks_type = crypt.get_luks_type(conditions.get_device_name())
luks_type = (
crypt.get_luks_type(conditions.device)
if conditions.device is not None
else None
)
if luks_type is None and module.params["type"] is not None:
luks_type = module.params["type"]
for param in ["keyslot", "new_keyslot", "remove_keyslot"]:
@@ -1175,6 +1198,7 @@ def run_module():
# luks create
if conditions.luks_create():
assert conditions.device # ensured in conditions.luks_create()
if not module.check_mode:
try:
crypt.run_luks_create(
@@ -1196,11 +1220,13 @@ def run_module():
# luks open
name = conditions.opened_luks_name()
if name is not None:
result["name"] = name
if conditions.device is not None:
name = conditions.opened_luks_name(conditions.device)
if name is not None:
result["name"] = name
if conditions.luks_open():
assert conditions.device # ensured in conditions.luks_open()
name = module.params["name"]
if name is None:
try:
@@ -1237,6 +1263,8 @@ def run_module():
module.fail_json(msg=f"luks_device error: {e}")
else:
name = module.params["name"]
if name is None:
module.fail_json(msg="Cannot determine name to close device")
if not module.check_mode:
try:
crypt.run_luks_close(name)
@@ -1249,6 +1277,7 @@ def run_module():
# luks add key
if conditions.luks_add_key():
assert conditions.device # ensured in conditions.luks_add_key()
if not module.check_mode:
try:
crypt.run_luks_add_key(
@@ -1268,6 +1297,7 @@ def run_module():
# luks remove key
if conditions.luks_remove_key():
assert conditions.device # ensured in conditions.luks_remove_key()
if not module.check_mode:
try:
last_key = module.params["force_remove_last_key"]
@@ -1286,6 +1316,7 @@ def run_module():
# luks remove
if conditions.luks_remove():
assert conditions.device # ensured in conditions.luks_remove()
if not module.check_mode:
try:
crypt.run_luks_remove(conditions.device)
@@ -1299,7 +1330,7 @@ def run_module():
module.exit_json(**result)
def main():
def main() -> t.NoReturn:
run_module()

View File

@@ -284,6 +284,7 @@ info:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
class Certificate(OpensshModule):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(Certificate, self).__init__(module)
self.ssh_keygen = KeygenCommand(self.module)
self.identifier = self.module.params["identifier"] or ""
self.options = self.module.params["options"] or []
self.path = self.module.params["path"]
self.pkcs11_provider = self.module.params["pkcs11_provider"]
self.principals = self.module.params["principals"] or []
self.public_key = self.module.params["public_key"]
self.regenerate = (
self.identifier: str = self.module.params["identifier"] or ""
self.options: list[str] = self.module.params["options"] or []
self.path: str = self.module.params["path"]
self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
self.principals: list[str] = self.module.params["principals"] or []
self.public_key: str | None = self.module.params["public_key"]
self.regenerate: t.Literal[
"never",
"fail",
"partial_idempotence",
"full_idempotence",
"always",
] = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.serial_number = self.module.params["serial_number"]
self.signature_algorithm = self.module.params["signature_algorithm"]
self.signing_key = self.module.params["signing_key"]
self.state = self.module.params["state"]
self.type = self.module.params["type"]
self.use_agent = self.module.params["use_agent"]
self.valid_at = self.module.params["valid_at"]
self.ignore_timestamps = self.module.params["ignore_timestamps"]
self.serial_number: int | None = self.module.params["serial_number"]
self.signature_algorithm: (
t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
) = self.module.params["signature_algorithm"]
self.signing_key: str | None = self.module.params["signing_key"]
self.state: t.Literal["absent", "present"] = self.module.params["state"]
self.type: t.Literal["host", "user"] | None = self.module.params["type"]
self.use_agent: bool = self.module.params["use_agent"]
self.valid_at: str | None = self.module.params["valid_at"]
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
self._check_if_base_dir(self.path)
if self.state == "present":
self._validate_parameters()
self.data = None
self.original_data = None
self.data: OpensshCertificate | None = None
self.original_data: OpensshCertificate | None = None
if self._exists():
self._load_certificate()
self.time_parameters = None
self.time_parameters: OpensshCertificateTimeParameters | None = None
if self.state == "present":
self._set_time_parameters()
def _validate_parameters(self):
def _validate_parameters(self) -> None:
for path in (self.public_key, self.signing_key):
self._check_if_base_dir(path)
if (
path is not None
): # should never be None, but the type checker doesn't know
self._check_if_base_dir(path)
if self.options and self.type == "host":
self.module.fail_json(
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
if self.use_agent:
self._use_agent_available()
def _use_agent_available(self):
def _use_agent_available(self) -> None:
ssh_version = self._get_ssh_version()
if not ssh_version:
self.module.fail_json(msg="Failed to determine ssh version")
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
+ f" Your version is: {ssh_version}"
)
def _exists(self):
def _exists(self) -> bool:
return os.path.exists(self.path)
def _load_certificate(self):
def _load_certificate(self) -> None:
try:
self.original_data = OpensshCertificate.load(self.path)
except (TypeError, ValueError) as e:
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
self.module.warn(f"Unable to read existing certificate: {e}")
def _set_time_parameters(self):
def _set_time_parameters(self) -> None:
try:
self.time_parameters = OpensshCertificateTimeParameters(
valid_from=self.module.params["valid_from"],
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
except ValueError as e:
self.module.fail_json(msg=str(e))
def _execute(self):
def _execute(self) -> None:
if self.state == "present":
if self._should_generate():
self._generate()
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
if self._exists():
self._remove()
def _should_generate(self):
def _should_generate(self) -> bool:
if self.regenerate == "never":
return self.original_data is None
elif self.regenerate == "fail":
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
else:
return True
def _is_fully_valid(self):
def _is_fully_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
return self._is_partially_valid() and all(
[
self._compare_options() if self.original_data.type == "user" else True,
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
]
)
def _is_partially_valid(self):
def _is_partially_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
return all(
[
set(self.original_data.principals) == set(self.principals),
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
]
)
def _compare_time_parameters(self):
def _compare_time_parameters(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try:
original_time_parameters = OpensshCertificateTimeParameters(
valid_from=self.original_data.valid_after,
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
]
)
def _compare_options(self):
def _compare_options(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try:
critical_options, extensions = parse_option_list(self.options)
except ValueError as e:
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
]
)
def _get_key_fingerprint(self, path):
def _get_key_fingerprint(self, path: str) -> str:
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
return PrivateKey.from_string(private_key_content).fingerprint
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _generate(self):
def _generate(self) -> None:
try:
temp_certificate = self._generate_temp_certificate()
self._safe_secure_move([(temp_certificate, self.path)])
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
except (TypeError, ValueError) as e:
self.module.fail_json(msg=f"Unable to read new certificate: {e}")
def _generate_temp_certificate(self):
def _generate_temp_certificate(self) -> str:
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
if self.time_parameters is None:
raise AssertionError("Contract violation time_parameters not provided")
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
try:
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _remove(self):
def _remove(self) -> None:
try:
os.remove(self.path)
except OSError as e:
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
@property
def _result(self):
def _result(self) -> dict[str, t.Any]:
if self.state != "present":
return {}
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
}
@property
def diff(self):
def diff(self) -> dict[str, t.Any]:
return {
"before": get_cert_dict(self.original_data),
"after": get_cert_dict(self.data),
}
def format_cert_info(cert_info):
def format_cert_info(cert_info: str) -> list[str]:
result = []
string = ""
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
return result
def get_cert_dict(data):
def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
if data is None:
return {}
@@ -590,7 +621,7 @@ def get_cert_dict(data):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
force=dict(type="bool", default=False),

View File

@@ -198,13 +198,15 @@ comment:
sample: test@comment
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import (
select_backend,
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(

View File

@@ -239,6 +239,7 @@ csr:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -256,9 +257,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
CertificateSigningRequestBackend,
)
class CertificateSigningRequestModule(OpenSSLObject):
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
) -> None:
super(CertificateSigningRequestModule, self).__init__(
module.params["path"],
module.params["state"],
@@ -269,11 +279,11 @@ class CertificateSigningRequestModule(OpenSSLObject):
self.return_content = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.backup_file: str | None = None
self.module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate the certificate signing request."""
if self.force or self.module_backend.needs_regeneration():
if not self.check_mode:
@@ -292,13 +302,13 @@ class CertificateSigningRequestModule(OpenSSLObject):
file_args, self.changed
)
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
self.module_backend.set_existing(None)
if self.backup and not self.check_mode:
self.backup_file = module.backup_local(self.path)
super(CertificateSigningRequestModule, self).remove(module)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(include_csr=self.return_content)
result.update(
@@ -312,7 +322,7 @@ class CertificateSigningRequestModule(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_csr_argument_spec()
argument_spec.argument_spec.update(
dict(

View File

@@ -308,6 +308,7 @@ authority_cert_serial_number:
sample: 12345
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -318,7 +319,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -335,11 +336,15 @@ def main():
supports_check_mode=True,
)
if module.params["content"] is not None:
data = module.params["content"].encode("utf-8")
content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is not None:
data = content.encode("utf-8")
else:
if path is None:
module.fail_json(msg="One of content and path must be provided")
try:
with open(module.params["path"], "rb") as f:
with open(path, "rb") as f:
data = f.read()
except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading CSR file from disk: {e}")

View File

@@ -127,6 +127,8 @@ csr:
type: str
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -136,8 +138,17 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
CertificateSigningRequestBackend,
)
class CertificateSigningRequestModule:
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
) -> None:
self.check_mode = module.check_mode
self.module = module
self.module_backend = module_backend
@@ -145,13 +156,13 @@ class CertificateSigningRequestModule:
if module.params["content"] is not None:
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate the certificate signing request."""
if self.module_backend.needs_regeneration():
self.module_backend.generate_csr()
self.changed = True
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(include_csr=True)
result.update(
@@ -162,7 +173,7 @@ class CertificateSigningRequestModule:
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_csr_argument_spec()
argument_spec.argument_spec.update(
dict(

View File

@@ -132,6 +132,7 @@ import abc
import os
import re
import tempfile
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_native
@@ -173,23 +174,23 @@ class DHParameterError(Exception):
class DHParameterBase:
def __init__(self, module):
self.state = module.params["state"]
self.path = module.params["path"]
self.size = module.params["size"]
self.force = module.params["force"]
def __init__(self, module: AnsibleModule) -> None:
self.state: t.Literal["absent", "present"] = module.params["state"]
self.path: str = module.params["path"]
self.size: int = module.params["size"]
self.force: bool = module.params["force"]
self.changed = False
self.return_content = module.params["return_content"]
self.return_content: bool = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
@abc.abstractmethod
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
pass
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate DH params."""
changed = False
@@ -206,7 +207,7 @@ class DHParameterBase:
self.changed = changed
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
try:
@@ -215,28 +216,27 @@ class DHParameterBase:
except OSError as exc:
module.fail_json(msg=str(exc))
def check(self, module):
def check(self, module: AnsibleModule) -> bool:
"""Ensure the resource is in its desired state."""
if self.force:
return False
return self._check_params_valid(module) and self._check_fs_attributes(module)
@abc.abstractmethod
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
pass
def _check_fs_attributes(self, module):
def _check_fs_attributes(self, module: AnsibleModule) -> bool:
"""Checks (and changes if not in check mode!) fs attributes"""
file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args["path"]):
return False
return not module.set_fs_attributes_if_different(file_args, False)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"size": self.size,
"filename": self.path,
"changed": self.changed,
@@ -252,25 +252,24 @@ class DHParameterBase:
class DHParameterAbsent(DHParameterBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(DHParameterAbsent, self).__init__(module)
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
pass
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
pass
return False
class DHParameterOpenSSL(DHParameterBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(DHParameterOpenSSL, self).__init__(module)
self.openssl_bin = module.get_bin_path("openssl", True)
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
# create a tempfile
fd, tmpsrc = tempfile.mkstemp()
@@ -288,7 +287,7 @@ class DHParameterOpenSSL(DHParameterBase):
except Exception as e:
module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}")
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
command = [
self.openssl_bin,
@@ -321,10 +320,10 @@ class DHParameterOpenSSL(DHParameterBase):
class DHParameterCryptography(DHParameterBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(DHParameterCryptography, self).__init__(module)
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
# Generate parameters
params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters(
@@ -341,7 +340,7 @@ class DHParameterCryptography(DHParameterBase):
self.backup_file = module.backup_local(self.path)
write_file(module, result)
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
# Load parameters
try:
@@ -357,7 +356,7 @@ class DHParameterCryptography(DHParameterBase):
return bits == self.size
def main():
def main() -> t.NoReturn:
"""Main function"""
module = AnsibleModule(
@@ -383,6 +382,7 @@ def main():
msg=f"The directory '{base_dir}' does not exist or the file is not a directory",
)
dhparam: DHParameterOpenSSL | DHParameterCryptography | DHParameterAbsent
if module.params["state"] == "present":
backend = module.params["select_crypto_backend"]
if backend == "auto":

View File

@@ -280,6 +280,7 @@ import itertools
import os
import stat
import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native
@@ -296,7 +297,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
OpenSSLObject,
load_certificate,
load_privatekey,
load_certificate_issuer_privatekey,
)
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -320,6 +321,7 @@ except ImportError:
CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None
try:
import cryptography.x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization.pkcs12 import PBES
@@ -333,8 +335,22 @@ except Exception:
else:
CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True
if t.TYPE_CHECKING:
from ..module_utils.crypto.cryptography_support import (
CertificateIssuerPrivateKeyTypes,
)
def load_certificate_set(filename):
PKCS12 = tuple[
t.Union[CertificateIssuerPrivateKeyTypes, None],
t.Union[cryptography.x509.Certificate, None],
list[cryptography.x509.Certificate],
t.Union[bytes, None],
]
def load_certificate_set(
filename: str | os.PathLike,
) -> list[cryptography.x509.Certificate]:
"""
Load list of concatenated PEM files, and return a list of parsed certificates.
"""
@@ -351,70 +367,80 @@ class PkcsError(OpenSSLObjectError):
class Pkcs(OpenSSLObject):
def __init__(self, module, iter_size_default=2048):
path: str
def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None:
super(Pkcs, self).__init__(
module.params["path"],
module.params["state"],
module.params["force"],
module.check_mode,
)
self.action = module.params["action"]
self.other_certificates = module.params["other_certificates"]
self.other_certificates_parse_all = module.params[
self.action: t.Literal["export", "parse"] = module.params["action"]
self.other_certificates: list[cryptography.x509.Certificate] = []
self.other_certificates_str: list[str] | None = module.params[
"other_certificates"
]
self.other_certificates_parse_all: bool = module.params[
"other_certificates_parse_all"
]
self.other_certificates_content = module.params["other_certificates_content"]
self.certificate_path = module.params["certificate_path"]
self.certificate_content = module.params["certificate_content"]
self.friendly_name = module.params["friendly_name"]
self.iter_size = module.params["iter_size"] or iter_size_default
self.maciter_size = module.params["maciter_size"] or 1
self.encryption_level = module.params["encryption_level"]
self.other_certificates_content: list[str] | None = module.params[
"other_certificates_content"
]
self.certificate_path: str | None = module.params["certificate_path"]
certificate_content: str | None = module.params["certificate_content"]
self.friendly_name: str | None = module.params["friendly_name"]
self.iter_size: int = module.params["iter_size"] or iter_size_default
self.maciter_size: int = module.params["maciter_size"] or 1
self.encryption_level: t.Literal["auto", "compatibility2022"] = module.params[
"encryption_level"
]
self.passphrase = module.params["passphrase"]
self.pkcs12 = None
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
self.pkcs12_bytes = None
self.return_content = module.params["return_content"]
self.src = module.params["src"]
self.pkcs12: PKCS12 | None = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
self.pkcs12_bytes: bytes | None = None
self.return_content: bool = module.params["return_content"]
self.src: str | None = module.params["src"]
if module.params["mode"] is None:
module.params["mode"] = "0400"
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
self.certificate_content: bytes | None = None
if self.certificate_path is not None:
try:
with open(self.certificate_path, "rb") as fh:
self.certificate_content = fh.read()
except (IOError, OSError) as exc:
raise PkcsError(exc)
elif self.certificate_content is not None:
self.certificate_content = to_bytes(self.certificate_content)
elif certificate_content is not None:
self.certificate_content = to_bytes(certificate_content)
self.privatekey_content: bytes | None = None
if self.privatekey_path is not None:
try:
with open(self.privatekey_path, "rb") as fh:
self.privatekey_content = fh.read()
except (IOError, OSError) as exc:
raise PkcsError(exc)
elif self.privatekey_content is not None:
self.privatekey_content = to_bytes(self.privatekey_content)
elif privatekey_content is not None:
self.privatekey_content = to_bytes(privatekey_content)
if self.other_certificates:
if self.other_certificates_str:
if self.other_certificates_parse_all:
filenames = list(self.other_certificates)
self.other_certificates = []
for other_cert_bundle in filenames:
for other_cert_bundle in self.other_certificates_str:
self.other_certificates.extend(
load_certificate_set(other_cert_bundle)
)
else:
self.other_certificates = [
load_certificate(other_cert)
for other_cert in self.other_certificates
for other_cert in self.other_certificates_str
]
elif self.other_certificates_content:
certs = self.other_certificates_content
@@ -430,40 +456,42 @@ class Pkcs(OpenSSLObject):
]
@abc.abstractmethod
def generate_bytes(self, module):
def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive."""
@abc.abstractmethod
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
pass
@abc.abstractmethod
def parse_bytes(self, pkcs12_content):
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
pass
@abc.abstractmethod
def _dump_privatekey(self, pkcs12):
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
pass
@abc.abstractmethod
def _dump_certificate(self, pkcs12):
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
pass
@abc.abstractmethod
def _dump_other_certificates(self, pkcs12):
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
pass
@abc.abstractmethod
def _get_friendly_name(self, pkcs12):
pass
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
state_and_perms = super(Pkcs, self).check(module, perms_required)
def _check_pkey_passphrase():
def _check_pkey_passphrase() -> bool:
if self.privatekey_passphrase:
try:
load_privatekey(
None,
load_certificate_issuer_privatekey(
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
@@ -476,6 +504,7 @@ class Pkcs(OpenSSLObject):
if os.path.exists(self.path) and module.params["action"] == "export":
self.generate_bytes(module) # ignore result
assert self.pkcs12 is not None
self.src = self.path
try:
(
@@ -524,7 +553,7 @@ class Pkcs(OpenSSLObject):
return False
elif (
module.params["action"] == "parse"
and os.path.exists(self.src)
and os.path.exists(self.src or "")
and os.path.exists(self.path)
):
try:
@@ -548,10 +577,10 @@ class Pkcs(OpenSSLObject):
return _check_pkey_passphrase()
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"filename": self.path,
}
if self.privatekey_path:
@@ -567,13 +596,20 @@ class Pkcs(OpenSSLObject):
return result
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
super(Pkcs, self).remove(module)
def parse(self):
def parse(self) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
"""Read PKCS#12 file."""
if self.src is None:
raise AssertionError("Contract violation: src is None")
try:
with open(self.src, "rb") as pkcs12_fh:
@@ -582,10 +618,13 @@ class Pkcs(OpenSSLObject):
except IOError as exc:
raise PkcsError(exc)
def generate(self):
def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass
def write(self, module, content, mode=None):
def write(
self, module: AnsibleModule, content: bytes, mode: int | str | None = None
) -> None:
"""Write the PKCS#12 file."""
if self.backup:
self.backup_file = module.backup_local(self.path)
@@ -595,7 +634,7 @@ class Pkcs(OpenSSLObject):
class PkcsCryptography(Pkcs):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(PkcsCryptography, self).__init__(module, iter_size_default=50000)
if (
self.encryption_level == "compatibility2022"
@@ -607,13 +646,12 @@ class PkcsCryptography(Pkcs):
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
)
def generate_bytes(self, module):
def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive."""
pkey = None
if self.privatekey_content:
try:
pkey = load_privatekey(
None,
pkey = load_certificate_issuer_privatekey(
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
@@ -631,6 +669,7 @@ class PkcsCryptography(Pkcs):
# Store fake object which can be used to retrieve the components back
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
encryption: serialization.KeySerializationEncryption
if not self.passphrase:
encryption = serialization.NoEncryption()
elif self.encryption_level == "compatibility2022":
@@ -654,7 +693,12 @@ class PkcsCryptography(Pkcs):
encryption,
)
def parse_bytes(self, pkcs12_content):
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
try:
private_key, certificate, additional_certificates, friendly_name = (
parse_pkcs12(pkcs12_content, self.passphrase)
@@ -683,11 +727,7 @@ class PkcsCryptography(Pkcs):
except ValueError as exc:
raise PkcsError(exc)
# The following methods will get self.pkcs12 passed, which is computed as:
#
# self.pkcs12 = (pkey, cert, self.other_certificates, self.friendly_name)
def _dump_privatekey(self, pkcs12):
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
return (
pkcs12[0].private_bytes(
encoding=serialization.Encoding.PEM,
@@ -698,27 +738,27 @@ class PkcsCryptography(Pkcs):
else None
)
def _dump_certificate(self, pkcs12):
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
def _dump_other_certificates(self, pkcs12):
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
return [
other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in pkcs12[2]
]
def _get_friendly_name(self, pkcs12):
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[3]
def select_backend(module):
def select_backend(module: AnsibleModule) -> Pkcs:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PkcsCryptography(module)
def main():
def main() -> t.NoReturn:
argument_spec = dict(
action=dict(type="str", default="export", choices=["export", "parse"]),
other_certificates=dict(

View File

@@ -155,6 +155,7 @@ privatekey:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -172,9 +173,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
PrivateKeyBackend,
)
class PrivateKeyModule(OpenSSLObject):
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: PrivateKeyBackend
) -> None:
super(PrivateKeyModule, self).__init__(
module.params["path"],
module.params["state"],
@@ -182,19 +192,19 @@ class PrivateKeyModule(OpenSSLObject):
module.check_mode,
)
self.module_backend = module_backend
self.return_content = module.params["return_content"]
self.return_content: bool = module.params["return_content"]
if self.force:
module_backend.regenerate = "always"
self.backup = module.params["backup"]
self.backup_file = None
self.backup: str | None = module.params["backup"]
self.backup_file: str | None = None
if module.params["mode"] is None:
module.params["mode"] = "0600"
module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate a keypair."""
if self.module_backend.needs_regeneration():
@@ -228,13 +238,13 @@ class PrivateKeyModule(OpenSSLObject):
file_args, self.changed
)
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
self.module_backend.set_existing(None)
if self.backup and not self.check_mode:
self.backup_file = module.backup_local(self.path)
super(PrivateKeyModule, self).remove(module)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(include_key=self.return_content)
@@ -246,7 +256,7 @@ class PrivateKeyModule(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update(

View File

@@ -60,6 +60,7 @@ backup_file:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -77,8 +78,17 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey_convert import (
PrivateKeyConvertBackend,
)
class PrivateKeyConvertModule(OpenSSLObject):
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: PrivateKeyConvertBackend
) -> None:
super(PrivateKeyConvertModule, self).__init__(
module.params["dest_path"],
"present",
@@ -87,8 +97,8 @@ class PrivateKeyConvertModule(OpenSSLObject):
)
self.module_backend = module_backend
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
module.params["path"] = module.params["dest_path"]
if module.params["mode"] is None:
@@ -96,12 +106,14 @@ class PrivateKeyConvertModule(OpenSSLObject):
module_backend.set_existing_destination(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Do conversion."""
if self.module_backend.needs_conversion():
# Convert
privatekey_data = self.module_backend.get_private_key_data()
if privatekey_data is None:
raise AssertionError("Contract violation: privatekey_data is None")
if not self.check_mode:
if self.backup:
self.backup_file = module.backup_local(self.path)
@@ -116,7 +128,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
file_args, self.changed
)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump()
@@ -127,7 +139,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update(

View File

@@ -200,6 +200,7 @@ private_data:
type: dict
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -212,7 +213,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -243,7 +244,7 @@ def main():
data = f.read()
except (IOError, OSError) as e:
module.fail_json(
msg=f"Error while reading private key file from disk: {e}", **result
msg=f"Error while reading private key file from disk: {e}", **result # type: ignore
)
result["can_load_key"] = True
@@ -261,10 +262,10 @@ def main():
module.exit_json(**result)
except PrivateKeyParseError as exc:
result.update(exc.result)
module.fail_json(msg=exc.error_message, **result)
module.fail_json(msg=exc.error_message, **result) # type: ignore
except PrivateKeyConsistencyError as exc:
result.update(exc.result)
module.fail_json(msg=exc.error_message, **result)
module.fail_json(msg=exc.error_message, **result) # type: ignore
except OpenSSLObjectError as exc:
module.fail_json(msg=str(exc))

View File

@@ -186,6 +186,7 @@ publickey:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -218,6 +219,12 @@ try:
except ImportError:
pass
if t.TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)
class PublicKeyError(OpenSSLObjectError):
pass
@@ -225,7 +232,7 @@ class PublicKeyError(OpenSSLObjectError):
class PublicKey(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(PublicKey, self).__init__(
module.params["path"],
module.params["state"],
@@ -233,27 +240,29 @@ class PublicKey(OpenSSLObject):
module.check_mode,
)
self.module = module
self.format = module.params["format"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.privatekey = None
self.publickey_bytes = None
self.return_content = module.params["return_content"]
self.fingerprint = {}
self.format: t.Literal["OpenSSH", "PEM"] = module.params["format"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.privatekey: PrivateKeyTypes | None = None
self.publickey_bytes: bytes | None = None
self.return_content: bool = module.params["return_content"]
self.fingerprint: dict[str, str] = {}
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
result = dict(can_parse_key=False)
return {}
result = {"can_parse_key": False}
try:
result.update(
get_publickey_info(
@@ -267,7 +276,7 @@ class PublicKey(OpenSSLObject):
pass
return result
def _create_publickey(self, module):
def _create_publickey(self, module: AnsibleModule) -> bytes:
self.privatekey = load_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
@@ -284,10 +293,12 @@ class PublicKey(OpenSSLObject):
crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
)
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate the public key."""
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
raise PublicKeyError(
f"The private key {self.privatekey_path} does not exist"
)
@@ -320,17 +331,18 @@ class PublicKey(OpenSSLObject):
elif module.set_fs_attributes_if_different(file_args, False):
self.changed = True
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
state_and_perms = super(PublicKey, self).check(module, perms_required)
def _check_privatekey():
if self.privatekey_content is None and not os.path.exists(
def _check_privatekey() -> bool:
if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
return False
current_publickey: PublicKeyTypes
try:
with open(self.path, "rb") as public_key_fh:
publickey_content = public_key_fh.read()
@@ -369,15 +381,15 @@ class PublicKey(OpenSSLObject):
return _check_privatekey()
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
super(PublicKey, self).remove(module)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"filename": self.path,
"format": self.format,
@@ -403,7 +415,7 @@ class PublicKey(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(

View File

@@ -152,6 +152,7 @@ public_data:
returned: When RV(type=DSA) or RV(type=ECC)
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -163,7 +164,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -191,7 +192,7 @@ def main():
data = f.read()
except (IOError, OSError) as e:
module.fail_json(
msg=f"Error while reading public key file from disk: {e}", **result
msg=f"Error while reading public key file from disk: {e}", **result # type: ignore
)
module_backend = select_backend(module, data)
@@ -201,7 +202,7 @@ def main():
module.exit_json(**result)
except PublicKeyParseError as exc:
result.update(exc.result)
module.fail_json(msg=exc.error_message, **result)
module.fail_json(msg=exc.error_message, **result) # type: ignore
except OpenSSLObjectError as exc:
module.fail_json(msg=str(exc))

View File

@@ -99,6 +99,7 @@ signature:
import base64
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -132,7 +133,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
class SignatureBase(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureBase, self).__init__(
path=module.params["path"],
state="present",
@@ -140,32 +141,35 @@ class SignatureBase(OpenSSLObject):
check_mode=module.check_mode,
)
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.module = module
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
def generate(self):
def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass
def dump(self):
def dump(self) -> dict[str, t.Any]:
# Empty method because OpenSSLObject wants this
pass
return {}
# Implementation with using cryptography
class SignatureCryptography(SignatureBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureCryptography, self).__init__(module)
def run(self):
def run(self) -> dict[str, t.Any]:
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
_hash = cryptography.hazmat.primitives.hashes.SHA256()
result = dict()
result: dict[str, t.Any] = {}
try:
with open(self.path, "rb") as f:
@@ -223,7 +227,7 @@ class SignatureCryptography(SignatureBase):
raise OpenSSLObjectError(e)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
privatekey_path=dict(type="path"),

View File

@@ -88,6 +88,7 @@ valid:
import base64
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -121,7 +122,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
class SignatureInfoBase(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureInfoBase, self).__init__(
path=module.params["path"],
state="present",
@@ -129,32 +130,35 @@ class SignatureInfoBase(OpenSSLObject):
check_mode=module.check_mode,
)
self.signature = module.params["signature"]
self.certificate_path = module.params["certificate_path"]
self.certificate_content = module.params["certificate_content"]
if self.certificate_content is not None:
self.certificate_content = self.certificate_content.encode("utf-8")
self.module = module
self.signature: str = module.params["signature"]
self.certificate_path: str | None = module.params["certificate_path"]
certificate_content: str | None = module.params["certificate_content"]
if certificate_content is not None:
self.certificate_content: bytes | None = certificate_content.encode("utf-8")
else:
self.certificate_content = None
def generate(self):
def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass
def dump(self):
def dump(self) -> dict[str, t.Any]:
# Empty method because OpenSSLObject wants this
pass
return {}
# Implementation with using cryptography
class SignatureInfoCryptography(SignatureInfoBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureInfoCryptography, self).__init__(module)
def run(self):
def run(self) -> dict[str, t.Any]:
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
_hash = cryptography.hazmat.primitives.hashes.SHA256()
result = dict()
result: dict[str, t.Any] = {}
try:
with open(self.path, "rb") as f:
@@ -228,7 +232,7 @@ class SignatureInfoCryptography(SignatureInfoBase):
raise OpenSSLObjectError(e)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
certificate_path=dict(type="path"),

View File

@@ -224,6 +224,7 @@ certificate:
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -257,8 +258,15 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
CertificateBackend,
)
class CertificateAbsent(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateAbsent, self).__init__(
module.params["path"],
module.params["state"],
@@ -266,19 +274,19 @@ class CertificateAbsent(OpenSSLObject):
module.check_mode,
)
self.module = module
self.return_content = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.return_content: bool = module.params["return_content"]
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
pass
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
super(CertificateAbsent, self).remove(module)
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = {
"changed": self.changed,
"filename": self.path,
@@ -296,7 +304,7 @@ class CertificateAbsent(OpenSSLObject):
class GenericCertificate(OpenSSLObject):
"""Retrieve a certificate using the given module backend."""
def __init__(self, module, module_backend):
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
super(GenericCertificate, self).__init__(
module.params["path"],
module.params["state"],
@@ -311,7 +319,7 @@ class GenericCertificate(OpenSSLObject):
self.module_backend = module_backend
self.module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
if self.module_backend.needs_regeneration():
if not self.check_mode:
self.module_backend.generate_certificate()
@@ -329,14 +337,14 @@ class GenericCertificate(OpenSSLObject):
file_args, self.changed
)
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
return (
super(GenericCertificate, self).check(module, perms_required)
and not self.module_backend.needs_regeneration()
)
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = self.module_backend.dump(include_certificate=self.return_content)
result.update(
{
@@ -349,7 +357,7 @@ class GenericCertificate(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_certificate_argument_spec()
add_acme_provider_to_argument_spec(argument_spec)
add_entrust_provider_to_argument_spec(argument_spec)
@@ -363,13 +371,14 @@ def main():
return_content=dict(type="bool", default=False),
)
)
argument_spec.required_if.append(["state", "present", ["provider"]])
argument_spec.required_if.append(("state", "present", ["provider"]))
module = argument_spec.create_ansible_module(
add_file_common_args=True,
supports_check_mode=True,
)
try:
certificate: GenericCertificate | CertificateAbsent
if module.params["state"] == "absent":
certificate = CertificateAbsent(module)
@@ -389,7 +398,13 @@ def main():
)
provider = module.params["provider"]
provider_map = {
provider_map: dict[
str,
type[AcmeCertificateProvider]
| type[EntrustCertificateProvider]
| type[OwnCACertificateProvider]
| type[SelfSignedCertificateProvider],
] = {
"acme": AcmeCertificateProvider,
"entrust": EntrustCertificateProvider,
"ownca": OwnCACertificateProvider,

View File

@@ -106,6 +106,7 @@ backup_file:
import base64
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -142,8 +143,12 @@ except ImportError:
pass
def parse_certificate(input, strict=False):
input_format = "pem" if identify_pem_format(input) else "der"
def parse_certificate(
input: bytes, strict: bool = False
) -> tuple[bytes, t.Literal["pem", "der"], str | None]:
input_format: t.Literal["pem", "der"] = (
"pem" if identify_pem_format(input) else "der"
)
if input_format == "pem":
pems = split_pem_list(to_text(input))
if len(pems) > 1 and strict:
@@ -162,7 +167,7 @@ def parse_certificate(input, strict=False):
class X509CertificateConvertModule(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(X509CertificateConvertModule, self).__init__(
module.params["dest_path"],
"present",
@@ -170,9 +175,9 @@ class X509CertificateConvertModule(OpenSSLObject):
module.check_mode,
)
self.src_path = module.params["src_path"]
self.src_content = module.params["src_content"]
self.src_content_base64 = module.params["src_content_base64"]
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
self.src_content_base64: bool = module.params["src_content_base64"]
if self.src_content is not None:
self.input = to_bytes(self.src_content)
if self.src_content_base64:
@@ -181,6 +186,8 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception as exc:
module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}")
else:
if self.src_path is None:
module.fail_json(msg="One of src_path and src_content must be provided")
try:
with open(self.src_path, "rb") as f:
self.input = f.read()
@@ -189,8 +196,8 @@ class X509CertificateConvertModule(OpenSSLObject):
msg=f"Failure while reading file {self.src_path}: {exc}"
)
self.format = module.params["format"]
self.strict = module.params["strict"]
self.format: t.Literal["pem", "der"] = module.params["format"]
self.strict: bool = module.params["strict"]
self.wanted_pem_type = "CERTIFICATE"
try:
@@ -203,8 +210,8 @@ class X509CertificateConvertModule(OpenSSLObject):
if module.params["verify_cert_parsable"]:
self.verify_cert_parsable(module)
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
module.params["path"] = self.path
@@ -221,7 +228,7 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception:
pass
def verify_cert_parsable(self, module):
def verify_cert_parsable(self, module: AnsibleModule) -> None:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
@@ -230,7 +237,7 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception as exc:
module.fail_json(msg=f"Error while parsing certificate: {exc}")
def needs_conversion(self):
def needs_conversion(self) -> bool:
if self.dest_content is None or self.dest_content_format is None:
return True
if self.dest_content_format != self.format:
@@ -241,7 +248,7 @@ class X509CertificateConvertModule(OpenSSLObject):
return True
return False
def get_dest_certificate(self):
def get_dest_certificate(self) -> bytes:
if self.format == "der":
return self.input
data = to_bytes(base64.b64encode(self.input))
@@ -250,7 +257,7 @@ class X509CertificateConvertModule(OpenSSLObject):
lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n"))
return b"\n".join(lines)
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Do conversion."""
if self.needs_conversion():
# Convert
@@ -269,18 +276,18 @@ class X509CertificateConvertModule(OpenSSLObject):
file_args, self.changed
)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = dict(
changed=self.changed,
)
result: dict[str, t.Any] = {
"changed": self.changed,
}
if self.backup_file:
result["backup_file"] = self.backup_file
return result
def main():
def main() -> t.NoReturn:
argument_spec = dict(
src_path=dict(type="path"),
src_content=dict(type="str"),

View File

@@ -390,8 +390,10 @@ issuer_uri:
version_added: 2.9.0
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -406,7 +408,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -424,18 +426,22 @@ def main():
supports_check_mode=True,
)
if module.params["content"] is not None:
data = module.params["content"].encode("utf-8")
content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is not None:
data = content.encode("utf-8")
else:
if path is None:
module.fail_json(msg="One of path and content must be provided")
try:
with open(module.params["path"], "rb") as f:
with open(path, "rb") as f:
data = f.read()
except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading certificate file from disk: {e}")
module_backend = select_backend(module, data)
valid_at = module.params["valid_at"]
valid_at: dict[str, t.Any] = module.params["valid_at"]
if valid_at:
for k, v in valid_at.items():
if not isinstance(v, (str, bytes)):
@@ -443,13 +449,11 @@ def main():
msg=f"The value for valid_at.{k} must be of type string (got {type(v)})"
)
valid_at[k] = get_relative_time_option(
v, f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
to_text(v), f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
)
try:
result = module_backend.get_info(
der_support_enabled=module.params["content"] is None
)
result = module_backend.get_info(der_support_enabled=content is None)
not_before = module_backend.get_not_before()
not_after = module_backend.get_not_after()

View File

@@ -118,6 +118,8 @@ certificate:
type: str
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -139,23 +141,31 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
CertificateBackend,
)
class GenericCertificate:
"""Retrieve a certificate using the given module backend."""
def __init__(self, module, module_backend):
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
self.check_mode = module.check_mode
self.module = module
self.module_backend = module_backend
self.changed = False
if module.params["content"] is not None:
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
content: str | None = module.params["content"]
if content is not None:
self.module_backend.set_existing(content.encode("utf-8"))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
if self.module_backend.needs_regeneration():
self.module_backend.generate_certificate()
self.changed = True
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = self.module_backend.dump(include_certificate=True)
result.update(
{
@@ -165,7 +175,7 @@ class GenericCertificate:
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_certificate_argument_spec()
argument_spec.argument_spec["provider"]["required"] = True
add_entrust_provider_to_argument_spec(argument_spec)
@@ -182,7 +192,12 @@ def main():
try:
provider = module.params["provider"]
provider_map = {
provider_map: dict[
str,
type[EntrustCertificateProvider]
| type[OwnCACertificateProvider]
| type[SelfSignedCertificateProvider],
] = {
"entrust": EntrustCertificateProvider,
"ownca": OwnCACertificateProvider,
"selfsigned": SelfSignedCertificateProvider,

View File

@@ -425,6 +425,7 @@ crl:
import base64
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text
@@ -463,7 +464,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
OpenSSLObject,
load_certificate,
load_privatekey,
load_certificate_issuer_privatekey,
parse_name_field,
parse_ordered_name_field,
select_message_digest,
@@ -495,6 +496,9 @@ try:
except ImportError:
pass
if t.TYPE_CHECKING:
import datetime
class CRLError(OpenSSLObjectError):
pass
@@ -502,7 +506,7 @@ class CRLError(OpenSSLObjectError):
class CRL(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CRL, self).__init__(
module.params["path"],
module.params["state"],
@@ -510,53 +514,69 @@ class CRL(OpenSSLObject):
module.check_mode,
)
self.format = module.params["format"]
self.format: t.Literal["pem", "der"] = module.params["format"]
self.update = module.params["crl_mode"] == "update"
self.ignore_timestamps = module.params["ignore_timestamps"]
self.return_content = module.params["return_content"]
self.name_encoding = module.params["name_encoding"]
self.serial_numbers_format = module.params["serial_numbers"]
self.crl_content = None
self.update: bool = module.params["crl_mode"] == "update"
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
self.return_content: bool = module.params["return_content"]
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params[
"name_encoding"
]
self.serial_numbers_format: t.Literal["integer", "hex-octets"] = module.params[
"serial_numbers"
]
self.crl_content: bytes | None = None
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
try:
if module.params["issuer_ordered"]:
issuer_ordered: list[dict[str, t.Any]] | None = module.params[
"issuer_ordered"
]
issuer: dict[str, list[str | bytes] | bytes | str] | None = module.params[
"issuer"
]
if issuer_ordered:
self.issuer_ordered = True
self.issuer = parse_ordered_name_field(
module.params["issuer_ordered"], "issuer_ordered"
)
self.issuer = parse_ordered_name_field(issuer_ordered, "issuer_ordered")
else:
self.issuer_ordered = False
self.issuer = parse_name_field(module.params["issuer"], "issuer")
self.issuer = (
parse_name_field(issuer, "issuer") if issuer is not None else []
)
except (TypeError, ValueError) as exc:
module.fail_json(msg=str(exc))
self.last_update = get_relative_time_option(
self.last_update: datetime.datetime = get_relative_time_option(
module.params["last_update"],
"last_update",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.next_update = get_relative_time_option(
self.next_update: datetime.datetime | None = get_relative_time_option(
module.params["next_update"],
"next_update",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["digest"])
if self.digest is None:
digest = select_message_digest(module.params["digest"])
if digest is None:
raise CRLError(f'The digest "{module.params["digest"]}" is not supported')
self.digest = digest
self.module = module
self.revoked_certificates = []
for i, rc in enumerate(module.params["revoked_certificates"]):
result = {
revoked_certificates: list[dict[str, t.Any]] = module.params[
"revoked_certificates"
]
for i, rc in enumerate(revoked_certificates):
result: dict[str, t.Any] = {
"serial_number": None,
"revocation_date": None,
"issuer": None,
@@ -567,21 +587,25 @@ class CRL(OpenSSLObject):
"invalidity_date_critical": False,
}
path_prefix = f"revoked_certificates[{i}]."
if rc["path"] is not None or rc["content"] is not None:
path: str | None = rc["path"]
content_str: str | None = rc["content"]
if path is not None or content_str is not None:
# Load certificate from file or content
try:
if rc["content"] is not None:
rc["content"] = rc["content"].encode("utf-8")
cert = load_certificate(rc["path"], content=rc["content"])
content: bytes | None = None
if content_str is not None:
content = content_str.encode("utf-8")
rc["content"] = content
cert = load_certificate(path, content=content)
result["serial_number"] = cert.serial_number
except OpenSSLObjectError as e:
if rc["content"] is not None:
if content_str is not None:
module.fail_json(
msg=f"Cannot parse certificate from {path_prefix}content: {e}"
)
else:
module.fail_json(
msg=f'Cannot read certificate "{rc["path"]}" from {path_prefix}path: {e}'
msg=f'Cannot read certificate "{path}" from {path_prefix}path: {e}'
)
else:
# Specify serial_number (and potentially issuer) directly
@@ -611,11 +635,11 @@ class CRL(OpenSSLObject):
result["invalidity_date_critical"] = rc["invalidity_date_critical"]
self.revoked_certificates.append(result)
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
try:
self.privatekey = load_privatekey(
self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
@@ -643,7 +667,7 @@ class CRL(OpenSSLObject):
self.diff_after = self.diff_before = self._get_info(data)
def _parse_serial_number(self, value, index):
def _parse_serial_number(self, value: t.Any, index: int) -> int:
if self.serial_numbers_format == "integer":
try:
return check_type_int(value)
@@ -662,22 +686,42 @@ class CRL(OpenSSLObject):
f"Unexpected value {self.serial_numbers_format} of serial_numbers"
)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
return {}
try:
result = get_crl_info(self.module, data)
result["can_parse_crl"] = True
return result
except Exception:
return dict(can_parse_crl=False)
return {"can_parse_crl": False}
def remove(self):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = self.module.backup_local(self.path)
super(CRL, self).remove(self.module)
def _compress_entry(self, entry):
def _compress_entry(self, entry: dict[str, t.Any]) -> (
tuple[
int | None,
tuple[str, ...] | None,
bool,
int | None,
bool,
datetime.datetime | None,
bool,
]
| tuple[
int | None,
datetime.datetime | None,
tuple[str, ...] | None,
bool,
int | None,
bool,
datetime.datetime | None,
bool,
]
):
issuer = None
if entry["issuer"] is not None:
# Normalize to IDNA. If this is used-provided, it was already converted to
@@ -713,7 +757,12 @@ class CRL(OpenSSLObject):
entry["invalidity_date_critical"],
)
def check(self, module, perms_required=True, ignore_conversion=True):
def check(
self,
module: AnsibleModule,
perms_required: bool = True,
ignore_conversion: bool = True,
) -> bool:
"""Ensure the resource is in its desired state."""
state_and_perms = super(CRL, self).check(self.module, perms_required)
@@ -743,10 +792,13 @@ class CRL(OpenSSLObject):
]
is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer]
if not self.issuer_ordered:
want_issuer = set(want_issuer)
is_issuer = set(is_issuer)
if want_issuer != is_issuer:
return False
want_issuer_set = set(want_issuer)
is_issuer_set = set(is_issuer)
if want_issuer_set != is_issuer_set:
return False
else:
if want_issuer != is_issuer:
return False
old_entries = [
self._compress_entry(cryptography_decode_revoked_certificate(cert))
@@ -769,7 +821,7 @@ class CRL(OpenSSLObject):
return True
def _generate_crl(self):
def _generate_crl(self) -> bytes:
crl = CertificateRevocationListBuilder()
try:
@@ -787,7 +839,8 @@ class CRL(OpenSSLObject):
raise CRLError(e)
crl = set_last_update(crl, self.last_update)
crl = set_next_update(crl, self.next_update)
if self.next_update is not None:
crl = set_next_update(crl, self.next_update)
if self.update and self.crl:
new_entries = set(
@@ -799,22 +852,26 @@ class CRL(OpenSSLObject):
)
if decoded_entry not in new_entries:
crl = crl.add_revoked_certificate(entry)
for entry in self.revoked_certificates:
for revoked_entry in self.revoked_certificates:
revoked_cert = RevokedCertificateBuilder()
revoked_cert = revoked_cert.serial_number(entry["serial_number"])
revoked_cert = set_revocation_date(revoked_cert, entry["revocation_date"])
if entry["issuer"] is not None:
revoked_cert = revoked_cert.serial_number(revoked_entry["serial_number"])
revoked_cert = set_revocation_date(
revoked_cert, revoked_entry["revocation_date"]
)
if revoked_entry["issuer"] is not None:
revoked_cert = revoked_cert.add_extension(
x509.CertificateIssuer(entry["issuer"]), entry["issuer_critical"]
x509.CertificateIssuer(revoked_entry["issuer"]),
revoked_entry["issuer_critical"],
)
if entry["reason"] is not None:
if revoked_entry["reason"] is not None:
revoked_cert = revoked_cert.add_extension(
x509.CRLReason(entry["reason"]), entry["reason_critical"]
x509.CRLReason(revoked_entry["reason"]),
revoked_entry["reason_critical"],
)
if entry["invalidity_date"] is not None:
if revoked_entry["invalidity_date"] is not None:
revoked_cert = revoked_cert.add_extension(
x509.InvalidityDate(entry["invalidity_date"]),
entry["invalidity_date_critical"],
x509.InvalidityDate(revoked_entry["invalidity_date"]),
revoked_entry["invalidity_date_critical"],
)
crl = crl.add_revoked_certificate(revoked_cert.build())
@@ -827,7 +884,7 @@ class CRL(OpenSSLObject):
else:
return self.crl.public_bytes(Encoding.DER)
def generate(self):
def generate(self, module: AnsibleModule) -> None:
result = None
if (
not self.check(self.module, perms_required=False, ignore_conversion=True)
@@ -861,7 +918,7 @@ class CRL(OpenSSLObject):
elif self.module.set_fs_attributes_if_different(file_args, False):
self.changed = True
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = {
"changed": self.changed,
"filename": self.path,
@@ -879,37 +936,53 @@ class CRL(OpenSSLObject):
if check_mode:
result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = self.next_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = (
self.next_update.strftime(TIMESTAMP_FORMAT)
if self.next_update is not None
else None
)
# result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid)
result["digest"] = self.module.params["digest"]
result["issuer_ordered"] = self.issuer
result["issuer"] = {}
issuer: dict[str, str | bytes] = {}
result["issuer"] = issuer
for k, v in self.issuer:
result["issuer"][k] = v
result["revoked_certificates"] = []
issuer[k] = v
revoked_certificates: list[dict[str, t.Any]] = []
result["revoked_certificates"] = revoked_certificates
for entry in self.revoked_certificates:
result["revoked_certificates"].append(
revoked_certificates.append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
)
elif self.crl:
result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT)
result["next_update"] = get_next_update(self.crl).strftime(TIMESTAMP_FORMAT)
next_update = get_next_update(self.crl)
result["next_update"] = (
next_update.strftime(TIMESTAMP_FORMAT)
if next_update is not None
else None
)
result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
)
issuer = []
issuer_list: list[list[str]] = []
for attribute in self.crl.issuer:
issuer.append(
[cryptography_oid_to_name(attribute.oid), attribute.value]
issuer_list.append(
[
cryptography_oid_to_name(attribute.oid),
to_text(attribute.value),
]
)
result["issuer_ordered"] = issuer
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
result["revoked_certificates"] = []
result["issuer_ordered"] = issuer_list
issuer = {}
result["issuer"] = issuer
for k, v in issuer_list:
issuer[k] = v
revoked_certificates = []
result["revoked_certificates"] = revoked_certificates
for cert in self.crl:
entry = cryptography_decode_revoked_certificate(cert)
result["revoked_certificates"].append(
revoked_certificates.append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
)
@@ -923,7 +996,7 @@ class CRL(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
state=dict(type="str", default="present", choices=["present", "absent"]),
@@ -1015,14 +1088,14 @@ def main():
)
module.exit_json(**result)
crl.generate()
crl.generate(module)
else:
if module.check_mode:
result = crl.dump(check_mode=True)
result["changed"] = os.path.exists(module.params["path"])
module.exit_json(**result)
crl.remove()
crl.remove(module)
result = crl.dump()
module.exit_json(**result)

View File

@@ -100,7 +100,9 @@ last_update:
type: str
sample: '20190413202428Z'
next_update:
description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
description:
- The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
- Will be C(none) if no such timestamp is present.
returned: success
type: str
sample: '20190413202428Z'
@@ -172,6 +174,7 @@ revoked_certificates:
import base64
import binascii
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -185,7 +188,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -200,25 +203,30 @@ def main():
supports_check_mode=True,
)
if module.params["content"] is None:
content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is None:
if path is None:
module.fail_json(msg="One of content and path must be provided")
try:
with open(module.params["path"], "rb") as f:
with open(path, "rb") as f:
data = f.read()
except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading CRL file from disk: {e}")
else:
data = module.params["content"].encode("utf-8")
data = content.encode("utf-8")
if not identify_pem_format(data):
try:
data = base64.b64decode(module.params["content"])
data = base64.b64decode(content)
except (binascii.Error, TypeError) as e:
module.fail_json(msg=f"Error while Base64 decoding content: {e}")
list_revoked_certificates: bool = module.params["list_revoked_certificates"]
try:
result = get_crl_info(
module,
data,
list_revoked_certificates=module.params["list_revoked_certificates"],
list_revoked_certificates=list_revoked_certificates,
)
module.exit_json(**result)
except OpenSSLObjectError as e: