Remove no longer needed backend abstractions. (#912)

This commit is contained in:
Felix Fontein
2025-06-01 09:07:06 +02:00
committed by GitHub
parent d1a8137d91
commit 576a06b5b2
9 changed files with 759 additions and 1062 deletions

View File

@@ -9,7 +9,6 @@
from __future__ import annotations
import abc
import binascii
import typing as t
@@ -69,189 +68,13 @@ except ImportError:
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrieval:
cert: x509.Certificate
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@abc.abstractmethod
def _get_der_bytes(self) -> bytes:
pass
@abc.abstractmethod
def _get_signature_algorithm(self) -> str:
pass
@abc.abstractmethod
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_issuer_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_version(self) -> int | str:
pass
@abc.abstractmethod
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def get_not_before(self) -> datetime.datetime:
pass
@abc.abstractmethod
def get_not_after(self) -> datetime.datetime:
pass
@abc.abstractmethod
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self) -> PublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_serial_number(self) -> int:
pass
@abc.abstractmethod
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _get_ocsp_uri(self) -> str | None:
pass
@abc.abstractmethod
def _get_issuer_uri(self) -> str | None:
pass
def get_info(
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
content=self.content,
der_support_enabled=der_support_enabled,
)
result["signature_algorithm"] = self._get_signature_algorithm()
subject = self._get_subject_ordered()
issuer = self._get_issuer_ordered()
result["subject"] = {}
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
result["issuer_ordered"] = issuer
result["version"] = self._get_version()
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
not_before = self.get_not_before()
not_after = self.get_not_after()
result["not_before"] = not_before.strftime(TIMESTAMP_FORMAT)
result["not_after"] = not_after.strftime(TIMESTAMP_FORMAT)
result["expired"] = not_after < get_now_datetime(
with_timezone=CRYPTOGRAPHY_TIMEZONE
)
result["public_key"] = to_text(self._get_public_key_pem())
public_key_info = get_publickey_info(
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
result["fingerprints"] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski_bytes = self._get_subject_key_identifier()
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["serial_number"] = self._get_serial_number()
result["extensions_by_oid"] = self._get_all_extensions()
result["ocsp_uri"] = self._get_ocsp_uri()
result["issuer_uri"] = self._get_issuer_uri()
return result
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
super().__init__(module=module, content=content)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self) -> bytes:
@@ -464,6 +287,93 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
pass
return None
def get_info(
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
content=self.content,
der_support_enabled=der_support_enabled,
)
result["signature_algorithm"] = self._get_signature_algorithm()
subject = self._get_subject_ordered()
issuer = self._get_issuer_ordered()
result["subject"] = {}
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
result["issuer_ordered"] = issuer
result["version"] = self._get_version()
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
not_before = self.get_not_before()
not_after = self.get_not_after()
result["not_before"] = not_before.strftime(TIMESTAMP_FORMAT)
result["not_after"] = not_after.strftime(TIMESTAMP_FORMAT)
result["expired"] = not_after < get_now_datetime(
with_timezone=CRYPTOGRAPHY_TIMEZONE
)
result["public_key"] = to_text(self._get_public_key_pem())
public_key_info = get_publickey_info(
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
result["fingerprints"] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski_bytes = self._get_subject_key_identifier()
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["serial_number"] = self._get_serial_number()
result["extensions_by_oid"] = self._get_all_extensions()
result["ocsp_uri"] = self._get_ocsp_uri()
result["issuer_uri"] = self._get_issuer_uri()
return result
def get_certificate_info(
*,
@@ -471,7 +381,7 @@ def get_certificate_info(
content: bytes,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module=module, content=content)
info = CertificateInfoRetrieval(module=module, content=content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
@@ -481,7 +391,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateInfoRetrievalCryptography(module=module, content=content)
return CertificateInfoRetrieval(module=module, content=content)
__all__ = ("CertificateInfoRetrieval", "get_certificate_info", "select_backend")