Remove no longer needed backend abstractions. (#912)

This commit is contained in:
Felix Fontein
2025-06-01 09:07:06 +02:00
committed by GitHub
parent d1a8137d91
commit 576a06b5b2
9 changed files with 759 additions and 1062 deletions

View File

@@ -9,7 +9,6 @@
from __future__ import annotations
import abc
import binascii
import typing as t
@@ -69,189 +68,13 @@ except ImportError:
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrieval:
cert: x509.Certificate
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@abc.abstractmethod
def _get_der_bytes(self) -> bytes:
pass
@abc.abstractmethod
def _get_signature_algorithm(self) -> str:
pass
@abc.abstractmethod
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_issuer_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_version(self) -> int | str:
pass
@abc.abstractmethod
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def get_not_before(self) -> datetime.datetime:
pass
@abc.abstractmethod
def get_not_after(self) -> datetime.datetime:
pass
@abc.abstractmethod
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self) -> PublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_serial_number(self) -> int:
pass
@abc.abstractmethod
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _get_ocsp_uri(self) -> str | None:
pass
@abc.abstractmethod
def _get_issuer_uri(self) -> str | None:
pass
def get_info(
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
content=self.content,
der_support_enabled=der_support_enabled,
)
result["signature_algorithm"] = self._get_signature_algorithm()
subject = self._get_subject_ordered()
issuer = self._get_issuer_ordered()
result["subject"] = {}
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
result["issuer_ordered"] = issuer
result["version"] = self._get_version()
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
not_before = self.get_not_before()
not_after = self.get_not_after()
result["not_before"] = not_before.strftime(TIMESTAMP_FORMAT)
result["not_after"] = not_after.strftime(TIMESTAMP_FORMAT)
result["expired"] = not_after < get_now_datetime(
with_timezone=CRYPTOGRAPHY_TIMEZONE
)
result["public_key"] = to_text(self._get_public_key_pem())
public_key_info = get_publickey_info(
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
result["fingerprints"] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski_bytes = self._get_subject_key_identifier()
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["serial_number"] = self._get_serial_number()
result["extensions_by_oid"] = self._get_all_extensions()
result["ocsp_uri"] = self._get_ocsp_uri()
result["issuer_uri"] = self._get_issuer_uri()
return result
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
super().__init__(module=module, content=content)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self) -> bytes:
@@ -464,6 +287,93 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
pass
return None
def get_info(
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
content=self.content,
der_support_enabled=der_support_enabled,
)
result["signature_algorithm"] = self._get_signature_algorithm()
subject = self._get_subject_ordered()
issuer = self._get_issuer_ordered()
result["subject"] = {}
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
result["issuer_ordered"] = issuer
result["version"] = self._get_version()
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
not_before = self.get_not_before()
not_after = self.get_not_after()
result["not_before"] = not_before.strftime(TIMESTAMP_FORMAT)
result["not_after"] = not_after.strftime(TIMESTAMP_FORMAT)
result["expired"] = not_after < get_now_datetime(
with_timezone=CRYPTOGRAPHY_TIMEZONE
)
result["public_key"] = to_text(self._get_public_key_pem())
public_key_info = get_publickey_info(
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
result["fingerprints"] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski_bytes = self._get_subject_key_identifier()
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["serial_number"] = self._get_serial_number()
result["extensions_by_oid"] = self._get_all_extensions()
result["ocsp_uri"] = self._get_ocsp_uri()
result["issuer_uri"] = self._get_issuer_uri()
return result
def get_certificate_info(
*,
@@ -471,7 +381,7 @@ def get_certificate_info(
content: bytes,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module=module, content=content)
info = CertificateInfoRetrieval(module=module, content=content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
@@ -481,7 +391,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateInfoRetrievalCryptography(module=module, content=content)
return CertificateInfoRetrieval(module=module, content=content)
__all__ = ("CertificateInfoRetrieval", "get_certificate_info", "select_backend")

View File

@@ -8,7 +8,6 @@
from __future__ import annotations
import abc
import binascii
import typing as t
@@ -86,7 +85,57 @@ class CertificateSigningRequestError(OpenSSLObjectError):
# - module.fail_json(msg: str, **kwargs)
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def parse_crl_distribution_points(
*, module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
try:
full_name = None
relative_name = None
crl_issuer = None
reasons = None
if parse_crl_distribution_point["full_name"] is not None:
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
full_name = [
cryptography_get_name(name, what="full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
if not parse_crl_distribution_point["relative_name"]:
raise OpenSSLObjectError("relative_name must not be empty")
relative_name = cryptography_parse_relative_distinguished_name(
parse_crl_distribution_point["relative_name"]
)
if parse_crl_distribution_point["crl_issuer"] is not None:
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
crl_issuer = [
cryptography_get_name(name, what="CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
reasons_list = []
for reason in parse_crl_distribution_point["reasons"]:
reasons_list.append(REVOCATION_REASON_MAP[reason])
reasons = frozenset(reasons_list)
result.append(
cryptography.x509.DistributionPoint(
full_name=full_name,
relative_name=relative_name,
crl_issuer=crl_issuer,
reasons=reasons,
)
)
except (OpenSSLObjectError, ValueError) as e:
raise OpenSSLObjectError(
f"Error while parsing CRL distribution point #{index}: {e}"
) from e
return result
class CertificateSigningRequestBackend:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.digest: str = module.params["digest"]
@@ -214,6 +263,14 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
"crl_distribution_points"
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module=module, crl_distribution_points=crl_distribution_points
)
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
@@ -229,147 +286,6 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
except Exception:
return {"can_parse_csr": False}
@abc.abstractmethod
def generate_csr(self) -> None:
"""(Re-)Generate CSR."""
@abc.abstractmethod
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
def set_existing(self, *, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(
data=self.existing_csr_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
return self.existing_csr_bytes is not None
def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey."""
if self.privatekey is not None:
return
try:
self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
raise CertificateSigningRequestError(exc) from exc
@abc.abstractmethod
def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
if self.existing_csr_bytes is None:
return True
try:
self.existing_csr = load_certificate_request(
content=self.existing_csr_bytes,
)
except Exception:
return True
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, *, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"subject": self.subject,
"subjectAltName": self.subject_alt_name,
"keyUsage": self.key_usage,
"extendedKeyUsage": self.extended_key_usage,
"basicConstraints": self.basic_constraints,
"ocspMustStaple": self.ocsp_must_staple,
"name_constraints_permitted": self.name_constraints_permitted,
"name_constraints_excluded": self.name_constraints_excluded,
}
# Get hold of CSR bytes
csr_bytes = self.existing_csr_bytes
if self.csr is not None:
csr_bytes = self.get_csr_data()
self.diff_after = self._get_info(data=csr_bytes)
if include_csr:
# Store result
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
result["diff"] = {
"before": self.diff_before,
"after": self.diff_after,
}
return result
def parse_crl_distribution_points(
*, module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
try:
full_name = None
relative_name = None
crl_issuer = None
reasons = None
if parse_crl_distribution_point["full_name"] is not None:
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
full_name = [
cryptography_get_name(name, what="full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
if not parse_crl_distribution_point["relative_name"]:
raise OpenSSLObjectError("relative_name must not be empty")
relative_name = cryptography_parse_relative_distinguished_name(
parse_crl_distribution_point["relative_name"]
)
if parse_crl_distribution_point["crl_issuer"] is not None:
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
crl_issuer = [
cryptography_get_name(name, what="CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
reasons_list = []
for reason in parse_crl_distribution_point["reasons"]:
reasons_list.append(REVOCATION_REASON_MAP[reason])
reasons = frozenset(reasons_list)
result.append(
cryptography.x509.DistributionPoint(
full_name=full_name,
relative_name=relative_name,
crl_issuer=crl_issuer,
reasons=reasons,
)
)
except (OpenSSLObjectError, ValueError) as e:
raise OpenSSLObjectError(
f"Error while parsing CRL distribution point #{index}: {e}"
) from e
return result
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, *, module: AnsibleModule) -> None:
super().__init__(module=module)
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
"crl_distribution_points"
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module=module, crl_distribution_points=crl_distribution_points
)
def generate_csr(self) -> None:
"""(Re-)Generate CSR."""
self._ensure_private_key_loaded()
@@ -542,6 +458,30 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
cryptography.hazmat.primitives.serialization.Encoding.PEM
)
def set_existing(self, *, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(
data=self.existing_csr_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
return self.existing_csr_bytes is not None
def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey."""
if self.privatekey is not None:
return
try:
self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
raise CertificateSigningRequestError(exc) from exc
def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
if self.existing_csr is None:
@@ -795,14 +735,55 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
and _check_signature(self.existing_csr)
)
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
if self.existing_csr_bytes is None:
return True
try:
self.existing_csr = load_certificate_request(
content=self.existing_csr_bytes,
)
except Exception:
return True
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, *, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"subject": self.subject,
"subjectAltName": self.subject_alt_name,
"keyUsage": self.key_usage,
"extendedKeyUsage": self.extended_key_usage,
"basicConstraints": self.basic_constraints,
"ocspMustStaple": self.ocsp_must_staple,
"name_constraints_permitted": self.name_constraints_permitted,
"name_constraints_excluded": self.name_constraints_excluded,
}
# Get hold of CSR bytes
csr_bytes = self.existing_csr_bytes
if self.csr is not None:
csr_bytes = self.get_csr_data()
self.diff_after = self._get_info(data=csr_bytes)
if include_csr:
# Store result
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
result["diff"] = {
"before": self.diff_before,
"after": self.diff_after,
}
return result
def select_backend(
module: AnsibleModule,
) -> CertificateSigningRequestCryptographyBackend:
) -> CertificateSigningRequestBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateSigningRequestCryptographyBackend(module=module)
return CertificateSigningRequestBackend(module=module)
def get_csr_argument_spec() -> ArgumentSpec:

View File

@@ -9,7 +9,6 @@
from __future__ import annotations
import abc
import binascii
import typing as t
@@ -60,7 +59,7 @@ except ImportError:
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
class CSRInfoRetrieval:
csr: x509.CertificateSigningRequest
def __init__(
@@ -69,139 +68,6 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
self.module = module
self.content = content
self.validate_signature = validate_signature
@abc.abstractmethod
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _is_signature_valid(self) -> bool:
pass
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
content=self.content,
)
subject = self._get_subject_ordered()
result["subject"] = {}
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
(
result["name_constraints_permitted"],
result["name_constraints_excluded"],
result["name_constraints_critical"],
) = self._get_name_constraints()
result["public_key"] = to_text(self._get_public_key_pem())
public_key_info = get_publickey_info(
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
ski_bytes = self._get_subject_key_identifier()
ski = None
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
aki = None
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["extensions_by_oid"] = self._get_all_extensions()
result["signature_valid"] = self._is_signature_valid()
if self.validate_signature and not result["signature_valid"]:
self.module.fail_json(msg="CSR signature is invalid!", **result)
return result
class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super().__init__(
module=module, content=content, validate_signature=validate_signature
)
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
)
@@ -371,6 +237,74 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def _is_signature_valid(self) -> bool:
return self.csr.is_signature_valid
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
content=self.content,
)
subject = self._get_subject_ordered()
result["subject"] = {}
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
(
result["name_constraints_permitted"],
result["name_constraints_excluded"],
result["name_constraints_critical"],
) = self._get_name_constraints()
result["public_key"] = to_text(self._get_public_key_pem())
public_key_info = get_publickey_info(
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
ski_bytes = self._get_subject_key_identifier()
ski = None
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
aki = None
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["extensions_by_oid"] = self._get_all_extensions()
result["signature_valid"] = self._is_signature_valid()
if self.validate_signature and not result["signature_valid"]:
self.module.fail_json(msg="CSR signature is invalid!", **result)
return result
def get_csr_info(
*,
@@ -379,7 +313,7 @@ def get_csr_info(
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography(
info = CSRInfoRetrieval(
module=module, content=content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
@@ -391,7 +325,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CSRInfoRetrievalCryptography(
return CSRInfoRetrieval(
module=module, content=content, validate_signature=validate_signature
)

View File

@@ -8,7 +8,6 @@
from __future__ import annotations
import abc
import base64
import traceback
import typing as t
@@ -79,7 +78,56 @@ class PrivateKeyError(OpenSSLObjectError):
# - module.fail_json(msg: str, **kwargs)
class PrivateKeyBackend(metaclass=abc.ABCMeta):
class _Curve:
def __init__(
self,
*,
name: str,
ectype: str,
deprecated: bool,
) -> None:
self.name = name
self.ectype = ectype
self.deprecated = deprecated
def _get_ec_class(
self, *, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass: (
type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve] | None
) = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype)
if ecclass is None:
module.fail_json(
msg=f"Your cryptography version does not support {self.ectype}"
)
return ecclass
def create(
self, *, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module=module)
return ecclass()
def verify(
self,
*,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module=module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
class PrivateKeyBackend:
def _add_curve(
self,
name: str,
ectype: str,
*,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
def __init__(self, *, module: GeneralAnsibleModule) -> None:
self.module = module
self.type: t.Literal[
@@ -107,6 +155,27 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
self.curves: dict[str, _Curve] = {}
self._add_curve("secp224r1", "SECP224R1")
self._add_curve("secp256k1", "SECP256K1")
self._add_curve("secp256r1", "SECP256R1")
self._add_curve("secp384r1", "SECP384R1")
self._add_curve("secp521r1", "SECP521R1")
self._add_curve("secp192r1", "SECP192R1", deprecated=True)
self._add_curve("sect163k1", "SECT163K1", deprecated=True)
self._add_curve("sect163r2", "SECT163R2", deprecated=True)
self._add_curve("sect233k1", "SECT233K1", deprecated=True)
self._add_curve("sect233r1", "SECT233R1", deprecated=True)
self._add_curve("sect283k1", "SECT283K1", deprecated=True)
self._add_curve("sect283r1", "SECT283R1", deprecated=True)
self._add_curve("sect409k1", "SECT409K1", deprecated=True)
self._add_curve("sect409r1", "SECT409R1", deprecated=True)
self._add_curve("sect571k1", "SECT571K1", deprecated=True)
self._add_curve("sect571r1", "SECT571R1", deprecated=True)
self._add_curve("brainpoolP256r1", "BrainpoolP256R1", deprecated=True)
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
@@ -129,9 +198,61 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
pass
return result
@abc.abstractmethod
def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
if self.format not in ("auto", "auto_ignore"):
return self.format # type: ignore
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
return "pkcs8"
return "pkcs1"
def generate_private_key(self) -> None:
"""(Re-)Generate private key."""
try:
if self.type == "RSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, # OpenSSL always uses this
key_size=self.size,
)
)
if self.type == "DSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.dsa.generate_private_key(
key_size=self.size
)
)
if self.type == "X25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.generate()
)
if self.type == "X448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.generate()
)
if self.type == "Ed25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.generate()
)
if self.type == "Ed448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
)
if self.type == "ECC" and self.curve in self.curves:
if self.curves[self.curve].deprecated:
self.module.warn(
f"Elliptic curves of type {self.curve} should not be used for new keys!"
)
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve].create(
size=self.size, module=self.module
),
)
)
except cryptography.exceptions.UnsupportedAlgorithm:
self.module.fail_json(
msg=f"Cryptography backend does not support the algorithm required for {self.type}"
)
def convert_private_key(self) -> None:
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key).
@@ -143,9 +264,68 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self._ensure_existing_private_key_loaded()
self.private_key = self.existing_private_key
@abc.abstractmethod
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key."""
"""Return bytes for self.private_key"""
if self.private_key is None:
raise AssertionError("private_key not set")
# Select export format and encoding
try:
export_format_txt = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format_txt == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif export_format_txt == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif export_format_txt == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
export_encoding = (
cryptography.hazmat.primitives.serialization.Encoding.Raw
)
else:
# pylint does not notice that all possible values for export_format_txt have been covered.
raise AssertionError("Can never be reached") # pragma: no cover
except AttributeError:
self.module.fail_json(
msg=f'Cryptography backend does not support the selected output format "{self.format}"'
)
# Select key encryption
encryption_algorithm: (
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.cipher and self.passphrase:
if self.cipher == "auto":
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
to_bytes(self.passphrase)
)
else:
self.module.fail_json(
msg='Cryptography backend can only use "auto" for cipher option.'
)
# Serialize key
try:
return self.private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm,
)
except ValueError:
self.module.fail_json(
msg=f'Cryptography backend cannot serialize the private key in the required format "{self.format}"'
)
except Exception:
self.module.fail_json(
msg=f'Error while serializing the private key in the required format "{self.format}"',
exception=traceback.format_exc(),
)
def set_existing(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
@@ -158,21 +338,136 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
"""Query whether an existing private key is/has been there."""
return self.existing_private_key_bytes is not None
@abc.abstractmethod
def _check_passphrase(self) -> bool:
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
def _load_privatekey(self) -> PrivateKeyTypes:
data = self.existing_private_key_bytes
if data is None:
raise AssertionError("existing_private_key_bytes not set")
try:
# Interpret bytes depending on format.
key_format = identify_private_key_format(data)
if key_format == "raw":
if len(data) == 56:
return cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(
data
)
if len(data) == 57:
return cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(
data
)
if len(data) == 32:
if self.type == "X25519":
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
if self.type == "Ed25519":
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
try:
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
except Exception:
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
raise PrivateKeyError("Cannot load raw key")
return cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if self.passphrase is None else to_bytes(self.passphrase),
)
except Exception as e:
raise PrivateKeyError(e) from e
@abc.abstractmethod
def _ensure_existing_private_key_loaded(self) -> None:
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
if self.existing_private_key is None and self.has_existing():
self.existing_private_key = self._load_privatekey()
def _check_passphrase(self) -> bool:
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
try:
key_format = identify_private_key_format(self.existing_private_key_bytes)
if key_format == "raw":
# Raw keys cannot be encrypted. To avoid incompatibilities, we try to
# actually load the key (and return False when this fails).
self._load_privatekey()
# Loading the key succeeded. Only return True when no passphrase was
# provided.
return self.passphrase is None
return bool(
cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
)
)
except Exception:
return False
@abc.abstractmethod
def _check_size_and_type(self) -> bool:
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
):
return (
self.type == "RSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey,
):
return (
self.type == "DSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
):
return self.type == "X25519"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
):
return self.type == "X448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
):
return self.type == "Ed25519"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
):
return self.type == "Ed448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
):
if self.type != "ECC":
return False
if self.curve not in self.curves:
return False
return self.curves[self.curve].verify(
privatekey=self.existing_private_key, module=self.module
)
return False
@abc.abstractmethod
def _check_format(self) -> bool:
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
if self.format == "auto_ignore":
return True
try:
key_format = identify_private_key_format(self.existing_private_key_bytes)
return key_format == self._get_wanted_format()
except Exception:
return False
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
@@ -272,333 +567,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return result
class _Curve:
def __init__(
self,
*,
name: str,
ectype: str,
deprecated: bool,
) -> None:
self.name = name
self.ectype = ectype
self.deprecated = deprecated
def _get_ec_class(
self, *, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass: (
type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve] | None
) = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype)
if ecclass is None:
module.fail_json(
msg=f"Your cryptography version does not support {self.ectype}"
)
return ecclass
def create(
self, *, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module=module)
return ecclass()
def verify(
self,
*,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module=module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
# Implementation with using cryptography
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _add_curve(
self,
name: str,
ectype: str,
*,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
def __init__(self, module: GeneralAnsibleModule) -> None:
super().__init__(module=module)
self.curves: dict[str, _Curve] = {}
self._add_curve("secp224r1", "SECP224R1")
self._add_curve("secp256k1", "SECP256K1")
self._add_curve("secp256r1", "SECP256R1")
self._add_curve("secp384r1", "SECP384R1")
self._add_curve("secp521r1", "SECP521R1")
self._add_curve("secp192r1", "SECP192R1", deprecated=True)
self._add_curve("sect163k1", "SECT163K1", deprecated=True)
self._add_curve("sect163r2", "SECT163R2", deprecated=True)
self._add_curve("sect233k1", "SECT233K1", deprecated=True)
self._add_curve("sect233r1", "SECT233R1", deprecated=True)
self._add_curve("sect283k1", "SECT283K1", deprecated=True)
self._add_curve("sect283r1", "SECT283R1", deprecated=True)
self._add_curve("sect409k1", "SECT409K1", deprecated=True)
self._add_curve("sect409r1", "SECT409R1", deprecated=True)
self._add_curve("sect571k1", "SECT571K1", deprecated=True)
self._add_curve("sect571r1", "SECT571R1", deprecated=True)
self._add_curve("brainpoolP256r1", "BrainpoolP256R1", deprecated=True)
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
if self.format not in ("auto", "auto_ignore"):
return self.format # type: ignore
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
return "pkcs8"
return "pkcs1"
def generate_private_key(self) -> None:
"""(Re-)Generate private key."""
try:
if self.type == "RSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, # OpenSSL always uses this
key_size=self.size,
)
)
if self.type == "DSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.dsa.generate_private_key(
key_size=self.size
)
)
if self.type == "X25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.generate()
)
if self.type == "X448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.generate()
)
if self.type == "Ed25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.generate()
)
if self.type == "Ed448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
)
if self.type == "ECC" and self.curve in self.curves:
if self.curves[self.curve].deprecated:
self.module.warn(
f"Elliptic curves of type {self.curve} should not be used for new keys!"
)
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve].create(
size=self.size, module=self.module
),
)
)
except cryptography.exceptions.UnsupportedAlgorithm:
self.module.fail_json(
msg=f"Cryptography backend does not support the algorithm required for {self.type}"
)
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key"""
if self.private_key is None:
raise AssertionError("private_key not set")
# Select export format and encoding
try:
export_format_txt = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format_txt == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif export_format_txt == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif export_format_txt == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
export_encoding = (
cryptography.hazmat.primitives.serialization.Encoding.Raw
)
else:
# pylint does not notice that all possible values for export_format_txt have been covered.
raise AssertionError("Can never be reached") # pragma: no cover
except AttributeError:
self.module.fail_json(
msg=f'Cryptography backend does not support the selected output format "{self.format}"'
)
# Select key encryption
encryption_algorithm: (
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.cipher and self.passphrase:
if self.cipher == "auto":
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
to_bytes(self.passphrase)
)
else:
self.module.fail_json(
msg='Cryptography backend can only use "auto" for cipher option.'
)
# Serialize key
try:
return self.private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm,
)
except ValueError:
self.module.fail_json(
msg=f'Cryptography backend cannot serialize the private key in the required format "{self.format}"'
)
except Exception:
self.module.fail_json(
msg=f'Error while serializing the private key in the required format "{self.format}"',
exception=traceback.format_exc(),
)
def _load_privatekey(self) -> PrivateKeyTypes:
data = self.existing_private_key_bytes
if data is None:
raise AssertionError("existing_private_key_bytes not set")
try:
# Interpret bytes depending on format.
key_format = identify_private_key_format(data)
if key_format == "raw":
if len(data) == 56:
return cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(
data
)
if len(data) == 57:
return cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(
data
)
if len(data) == 32:
if self.type == "X25519":
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
if self.type == "Ed25519":
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
try:
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
except Exception:
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
raise PrivateKeyError("Cannot load raw key")
return cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if self.passphrase is None else to_bytes(self.passphrase),
)
except Exception as e:
raise PrivateKeyError(e) from e
def _ensure_existing_private_key_loaded(self) -> None:
if self.existing_private_key is None and self.has_existing():
self.existing_private_key = self._load_privatekey()
def _check_passphrase(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
try:
key_format = identify_private_key_format(self.existing_private_key_bytes)
if key_format == "raw":
# Raw keys cannot be encrypted. To avoid incompatibilities, we try to
# actually load the key (and return False when this fails).
self._load_privatekey()
# Loading the key succeeded. Only return True when no passphrase was
# provided.
return self.passphrase is None
return bool(
cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
)
)
except Exception:
return False
def _check_size_and_type(self) -> bool:
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
):
return (
self.type == "RSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey,
):
return (
self.type == "DSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
):
return self.type == "X25519"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
):
return self.type == "X448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
):
return self.type == "Ed25519"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
):
return self.type == "Ed448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
):
if self.type != "ECC":
return False
if self.curve not in self.curves:
return False
return self.curves[self.curve].verify(
privatekey=self.existing_private_key, module=self.module
)
return False
def _check_format(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
if self.format == "auto_ignore":
return True
try:
key_format = identify_private_key_format(self.existing_private_key_bytes)
return key_format == self._get_wanted_format()
except Exception:
return False
def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyCryptographyBackend(module=module)
return PrivateKeyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:

View File

@@ -7,7 +7,6 @@
from __future__ import annotations
import abc
import traceback
import typing as t
@@ -68,7 +67,7 @@ class PrivateKeyError(OpenSSLObjectError):
# - module.fail_json(msg: str, **kwargs)
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
class PrivateKeyConvertBackend:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.src_path: str | None = module.params["src_path"]
@@ -88,61 +87,6 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
self.dest_private_key: PrivateKeyTypes | None = None
self.dest_private_key_bytes: bytes | None = None
@abc.abstractmethod
def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format."""
def set_existing_destination(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
def has_existing_destination(self) -> bool:
"""Query whether an existing private key is/has been there."""
return self.dest_private_key_bytes is not None
@abc.abstractmethod
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
data=self.src_private_key_bytes, passphrase=self.src_passphrase
)
if not self.has_existing_destination():
return True
assert self.dest_private_key_bytes is not None
try:
key_format, self.dest_private_key = self._load_private_key(
data=self.dest_private_key_bytes,
passphrase=self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
return True
return key_format != self.format or not cryptography_compare_private_keys(
self.dest_private_key, self.src_private_key
)
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
return {}
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, *, module: AnsibleModule) -> None:
super().__init__(module=module)
def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format"""
if self.src_private_key is None:
@@ -202,6 +146,14 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
exception=traceback.format_exc(),
)
def set_existing_destination(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
def has_existing_destination(self) -> bool:
"""Query whether an existing private key is/has been there."""
return self.dest_private_key_bytes is not None
def _load_private_key(
self,
*,
@@ -209,6 +161,7 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
try:
# Interpret bytes depending on format.
key_format = identify_private_key_format(data)
@@ -275,12 +228,39 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
except Exception as e:
raise PrivateKeyError(e) from e
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
data=self.src_private_key_bytes, passphrase=self.src_passphrase
)
if not self.has_existing_destination():
return True
assert self.dest_private_key_bytes is not None
try:
key_format, self.dest_private_key = self._load_private_key(
data=self.dest_private_key_bytes,
passphrase=self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
return True
return key_format != self.format or not cryptography_compare_private_keys(
self.dest_private_key, self.src_private_key
)
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
return {}
def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyConvertCryptographyBackend(module=module)
return PrivateKeyConvertBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:

View File

@@ -9,7 +9,6 @@
from __future__ import annotations
import abc
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -207,7 +206,7 @@ class PrivateKeyParseError(OpenSSLObjectError):
self.result = result
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
class PrivateKeyInfoRetrieval:
key: PrivateKeyTypes
def __init__(
@@ -225,21 +224,28 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.return_private_key_data = return_private_key_data
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, *, binary: bool) -> bytes:
pass
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
@abc.abstractmethod
def _get_key_info(
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
@abc.abstractmethod
def _is_key_consistent(
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
return _is_cryptography_key_consistent(
self.key,
key_public_data=key_public_data,
key_private_data=key_private_data,
warn_func=self.module.warn,
)
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
@@ -288,38 +294,6 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
return result
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(
self, *, module: GeneralAnsibleModule, content: bytes, **kwargs
) -> None:
super().__init__(module=module, content=content, **kwargs)
def _get_public_key(self, *, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key,
key_public_data=key_public_data,
key_private_data=key_private_data,
warn_func=self.module.warn,
)
def get_privatekey_info(
*,
module: GeneralAnsibleModule,
@@ -328,7 +302,7 @@ def get_privatekey_info(
return_private_key_data: bool = False,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
info = PrivateKeyInfoRetrieval(
module=module,
content=content,
passphrase=passphrase,
@@ -348,7 +322,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyInfoRetrievalCryptography(
return PrivateKeyInfoRetrieval(
module=module,
content=content,
passphrase=passphrase,

View File

@@ -7,7 +7,6 @@
from __future__ import annotations
import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
@@ -105,7 +104,7 @@ class PublicKeyParseError(OpenSSLObjectError):
self.result = result
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
class PublicKeyInfoRetrieval:
def __init__(
self,
*,
@@ -118,13 +117,18 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.content = content
self.key = key
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
pass
if self.key is None:
raise AssertionError("key must be set")
return self.key.public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
@abc.abstractmethod
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass
if self.key is None:
raise AssertionError("key must be set")
return _get_cryptography_public_key_info(self.key)
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
@@ -147,32 +151,6 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
return result
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
"""Validate the supplied public key, using the cryptography backend"""
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super().__init__(module=module, content=content, key=key)
def _get_public_key(self, binary: bool) -> bytes:
if self.key is None:
raise AssertionError("key must be set")
return self.key.public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
if self.key is None:
raise AssertionError("key must be set")
return _get_cryptography_public_key_info(self.key)
def get_publickey_info(
*,
module: GeneralAnsibleModule,
@@ -180,7 +158,7 @@ def get_publickey_info(
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
info = PublicKeyInfoRetrieval(module=module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
@@ -193,7 +171,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
return PublicKeyInfoRetrieval(module=module, content=content, key=key)
__all__ = (

View File

@@ -273,7 +273,6 @@ pkcs12:
version_added: "1.0.0"
"""
import abc
import base64
import itertools
import os
@@ -367,7 +366,7 @@ class PkcsError(OpenSSLObjectError):
class Pkcs(OpenSSLObject):
path: str
def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None:
def __init__(self, module: AnsibleModule, iter_size_default: int = 50000) -> None:
super().__init__(
path=module.params["path"],
state=module.params["state"],
@@ -451,34 +450,119 @@ class Pkcs(OpenSSLObject):
load_certificate(content=to_bytes(other_cert)) for other_cert in certs
]
@abc.abstractmethod
if (
self.encryption_level == "compatibility2022"
and not CRYPTOGRAPHY_HAS_COMPATIBILITY2022
):
module.fail_json(
msg="The installed cryptography version does not support encryption_level = compatibility2022."
" You need cryptography >= 38.0.0 and support for SHA1",
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
)
def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive."""
pkey = None
if self.privatekey_content:
try:
pkey = load_certificate_issuer_privatekey(
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
raise PkcsError(exc) from exc
cert = None
if self.certificate_content:
cert = load_certificate(content=self.certificate_content)
friendly_name = (
to_bytes(self.friendly_name) if self.friendly_name is not None else None
)
# Store fake object which can be used to retrieve the components back
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
encryption: serialization.KeySerializationEncryption
if not self.passphrase:
encryption = serialization.NoEncryption()
elif self.encryption_level == "compatibility2022":
encryption = (
serialization.PrivateFormat.PKCS12.encryption_builder()
.kdf_rounds(self.iter_size)
.key_cert_algorithm(PBES.PBESv1SHA1And3KeyTripleDESCBC)
.hmac_hash(hashes.SHA1())
.build(to_bytes(self.passphrase))
)
else:
encryption = serialization.BestAvailableEncryption(
to_bytes(self.passphrase)
)
return serialize_key_and_certificates(
friendly_name,
pkey,
cert,
self.other_certificates,
encryption,
)
@abc.abstractmethod
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
pass
try:
private_key, certificate, additional_certificates, friendly_name = (
parse_pkcs12(pkcs12_content, passphrase=self.passphrase)
)
pkey = None
if private_key is not None:
pkey = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
crt = None
if certificate is not None:
crt = certificate.public_bytes(serialization.Encoding.PEM)
other_certs = []
if additional_certificates is not None:
other_certs = [
other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in additional_certificates
]
return (pkey, crt, other_certs, friendly_name)
except ValueError as exc:
raise PkcsError(exc) from exc
@abc.abstractmethod
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
pass
return (
pkcs12[0].private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
if pkcs12[0]
else None
)
@abc.abstractmethod
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
pass
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
@abc.abstractmethod
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
pass
return [
other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in pkcs12[2]
]
@abc.abstractmethod
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
pass
return pkcs12[3]
def check(self, module: AnsibleModule, *, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
@@ -628,129 +712,11 @@ class Pkcs(OpenSSLObject):
self.pkcs12_bytes = content
class PkcsCryptography(Pkcs):
def __init__(self, module: AnsibleModule) -> None:
super().__init__(module, iter_size_default=50000)
if (
self.encryption_level == "compatibility2022"
and not CRYPTOGRAPHY_HAS_COMPATIBILITY2022
):
module.fail_json(
msg="The installed cryptography version does not support encryption_level = compatibility2022."
" You need cryptography >= 38.0.0 and support for SHA1",
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
)
def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive."""
pkey = None
if self.privatekey_content:
try:
pkey = load_certificate_issuer_privatekey(
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
raise PkcsError(exc) from exc
cert = None
if self.certificate_content:
cert = load_certificate(content=self.certificate_content)
friendly_name = (
to_bytes(self.friendly_name) if self.friendly_name is not None else None
)
# Store fake object which can be used to retrieve the components back
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
encryption: serialization.KeySerializationEncryption
if not self.passphrase:
encryption = serialization.NoEncryption()
elif self.encryption_level == "compatibility2022":
encryption = (
serialization.PrivateFormat.PKCS12.encryption_builder()
.kdf_rounds(self.iter_size)
.key_cert_algorithm(PBES.PBESv1SHA1And3KeyTripleDESCBC)
.hmac_hash(hashes.SHA1())
.build(to_bytes(self.passphrase))
)
else:
encryption = serialization.BestAvailableEncryption(
to_bytes(self.passphrase)
)
return serialize_key_and_certificates(
friendly_name,
pkey,
cert,
self.other_certificates,
encryption,
)
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
try:
private_key, certificate, additional_certificates, friendly_name = (
parse_pkcs12(pkcs12_content, passphrase=self.passphrase)
)
pkey = None
if private_key is not None:
pkey = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
crt = None
if certificate is not None:
crt = certificate.public_bytes(serialization.Encoding.PEM)
other_certs = []
if additional_certificates is not None:
other_certs = [
other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in additional_certificates
]
return (pkey, crt, other_certs, friendly_name)
except ValueError as exc:
raise PkcsError(exc) from exc
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
return (
pkcs12[0].private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
if pkcs12[0]
else None
)
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
return [
other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in pkcs12[2]
]
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[3]
def select_backend(module: AnsibleModule) -> Pkcs:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PkcsCryptography(module)
return Pkcs(module)
def main() -> t.NoReturn: