Code refactoring (#889)

* Add __all__ to all module and plugin utils.

* Convert quite a few positional args to keyword args.

* Avoid Python 3.8+ syntax.
This commit is contained in:
Felix Fontein
2025-05-16 06:55:57 +02:00
committed by GitHub
parent a5a4e022ba
commit 44bcc8cebc
101 changed files with 1510 additions and 748 deletions

View File

@@ -63,7 +63,7 @@ class CertificateError(OpenSSLObjectError):
class CertificateBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.force: bool = module.params["force"]
@@ -104,7 +104,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return {}
try:
result = get_certificate_info(
self.module, data, prefer_one_fingerprint=True
module=self.module, content=data, prefer_one_fingerprint=True
)
result["can_parse_certificate"] = True
return result
@@ -289,6 +289,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -330,7 +331,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
@@ -372,7 +373,7 @@ class CertificateProvider(metaclass=abc.ABCMeta):
def select_backend(
module: AnsibleModule, provider: CertificateProvider
*, module: AnsibleModule, provider: CertificateProvider
) -> CertificateBackend:
provider.validate_module_args(module)
@@ -415,3 +416,11 @@ def get_certificate_argument_spec() -> ArgumentSpec:
["privatekey_path", "privatekey_content"],
],
)
__all__ = (
"CertificateError",
"CertificateBackend",
"CertificateProvider",
"get_certificate_argument_spec",
)

View File

@@ -29,8 +29,8 @@ if t.TYPE_CHECKING:
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module=module)
self.accountkey_path: str = module.params["acme_accountkey_path"]
self.challenge_path: str = module.params["acme_challenge_path"]
self.use_chain: bool = module.params["acme_chain"]
@@ -102,8 +102,10 @@ class AcmeCertificateBackend(CertificateBackend):
raise AssertionError("Contract violation: cert_bytes is None")
return self.cert_bytes
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(include_certificate)
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(
include_certificate=include_certificate
)
result["accountkey"] = self.accountkey_path
return result
@@ -123,7 +125,7 @@ class AcmeCertificateProvider(CertificateProvider):
return False
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
return AcmeCertificateBackend(module)
return AcmeCertificateBackend(module=module)
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -138,3 +140,10 @@ def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
),
)
)
__all__ = (
"AcmeCertificateBackend",
"AcmeCertificateProvider",
"add_acme_provider_to_argument_spec",
)

View File

@@ -50,12 +50,12 @@ except ImportError:
class EntrustCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module=module)
self.trackingId = None
self.notAfter = get_relative_time_option(
module.params["entrust_not_after"],
"entrust_not_after",
input_name="entrust_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
@@ -159,6 +159,7 @@ class EntrustCertificateBackend(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -229,7 +230,7 @@ class EntrustCertificateProvider(CertificateProvider):
return False
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
return EntrustCertificateBackend(module)
return EntrustCertificateBackend(module=module)
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -281,3 +282,10 @@ def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
],
)
)
__all__ = (
"EntrustCertificateBackend",
"EntrustCertificateProvider",
"add_entrust_provider_to_argument_spec",
)

View File

@@ -70,7 +70,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@@ -158,11 +158,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
pass
def get_info(
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
None,
content=self.content,
der_support_enabled=der_support_enabled,
)
@@ -204,7 +203,7 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
@@ -249,8 +248,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(
module=module, content=content
)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self) -> bytes:
@@ -465,16 +466,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def get_certificate_info(
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
*,
module: GeneralAnsibleModule,
content: bytes,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module, content)
info = CertificateInfoRetrievalCryptography(module=module, content=content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes
*, module: GeneralAnsibleModule, content: bytes
) -> CertificateInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateInfoRetrievalCryptography(module, content)
return CertificateInfoRetrievalCryptography(module=module, content=content)
__all__ = ("CertificateInfoRetrieval", "get_certificate_info", "select_backend")

View File

@@ -62,8 +62,8 @@ except ImportError:
class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module=module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
@@ -73,12 +73,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
]
self.notBefore = get_relative_time_option(
module.params["ownca_not_before"],
"ownca_not_before",
input_name="ownca_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["ownca_not_after"],
"ownca_not_after",
input_name="ownca_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["ownca_digest"])
@@ -220,6 +220,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -233,7 +234,8 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
# Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.ca_cert.public_key()
certificate=self.existing_certificate,
signer_public_key=self.ca_cert.public_key(),
):
return True
@@ -270,9 +272,9 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate
include_certificate=include_certificate
)
result.update(
{
@@ -339,7 +341,7 @@ class OwnCACertificateProvider(CertificateProvider):
def create_backend(
self, module: AnsibleModule
) -> OwnCACertificateBackendCryptography:
return OwnCACertificateBackendCryptography(module)
return OwnCACertificateBackendCryptography(module=module)
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -369,3 +371,10 @@ def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
["ownca_privatekey_path", "ownca_privatekey_content"],
]
)
__all__ = (
"OwnCACertificateBackendCryptography",
"OwnCACertificateProvider",
"add_ownca_provider_to_argument_spec",
)

View File

@@ -58,20 +58,20 @@ except ImportError:
class SelfSignedCertificateBackendCryptography(CertificateBackend):
privatekey: CertificateIssuerPrivateKeyTypes
def __init__(self, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module=module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["selfsigned_create_subject_key_identifier"]
self.notBefore = get_relative_time_option(
module.params["selfsigned_not_before"],
"selfsigned_not_before",
input_name="selfsigned_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["selfsigned_not_after"],
"selfsigned_not_after",
input_name="selfsigned_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["selfsigned_digest"])
@@ -162,6 +162,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -177,15 +178,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
# Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.privatekey.public_key()
certificate=self.existing_certificate,
signer_public_key=self.privatekey.public_key(),
):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate
include_certificate=include_certificate
)
if self.module.check_mode:
@@ -239,7 +241,7 @@ class SelfSignedCertificateProvider(CertificateProvider):
def create_backend(
self, module: AnsibleModule
) -> SelfSignedCertificateBackendCryptography:
return SelfSignedCertificateBackendCryptography(module)
return SelfSignedCertificateBackendCryptography(module=module)
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -261,3 +263,10 @@ def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> Non
),
)
)
__all__ = (
"SelfSignedCertificateBackendCryptography",
"SelfSignedCertificateProvider",
"add_selfsigned_provider_to_argument_spec",
)

View File

@@ -55,6 +55,7 @@ except ImportError:
class CRLInfoRetrieval:
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
@@ -113,12 +114,20 @@ class CRLInfoRetrieval:
def get_crl_info(
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
*,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
) -> dict[str, t.Any]:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
info = CRLInfoRetrieval(
module, content, list_revoked_certificates=list_revoked_certificates
module=module,
content=content,
list_revoked_certificates=list_revoked_certificates,
)
return info.get_info()
__all__ = ("CRLInfoRetrieval", "get_crl_info")

View File

@@ -87,7 +87,7 @@ class CertificateSigningRequestError(OpenSSLObjectError):
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.digest: str = module.params["digest"]
self.privatekey_path: str | None = module.params["privatekey_path"]
@@ -158,7 +158,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
try:
if module.params["subject"]:
self.subject = self.subject + parse_name_field(
module.params["subject"], "subject"
module.params["subject"], name_field_name="subject"
)
if module.params["subject_ordered"]:
if self.subject:
@@ -166,7 +166,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
"subject_ordered cannot be combined with any other subject field"
)
self.subject = parse_ordered_name_field(
module.params["subject_ordered"], "subject_ordered"
module.params["subject_ordered"], name_field_name="subject_ordered"
)
self.ordered_subject = True
except ValueError as exc:
@@ -205,16 +205,16 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
self.existing_csr_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
try:
result = get_csr_info(
self.module,
data,
module=self.module,
content=data,
validate_signature=False,
prefer_one_fingerprint=True,
)
@@ -231,10 +231,12 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
def set_existing(self, csr_bytes: bytes | None) -> None:
def set_existing(self, *, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
self.diff_after = self.diff_before = self._get_info(
data=self.existing_csr_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
@@ -263,7 +265,6 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return True
try:
self.existing_csr = load_certificate_request(
None,
content=self.existing_csr_bytes,
)
except Exception:
@@ -271,7 +272,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, include_csr: bool) -> dict[str, t.Any]:
def dump(self, *, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
@@ -288,7 +289,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
csr_bytes = self.existing_csr_bytes
if self.csr is not None:
csr_bytes = self.get_csr_data()
self.diff_after = self._get_info(csr_bytes)
self.diff_after = self._get_info(data=csr_bytes)
if include_csr:
# Store result
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
@@ -301,7 +302,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def parse_crl_distribution_points(
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
*, module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
@@ -314,7 +315,7 @@ def parse_crl_distribution_points(
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
full_name = [
cryptography_get_name(name, "full name")
cryptography_get_name(name, what="full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
@@ -327,7 +328,7 @@ def parse_crl_distribution_points(
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
crl_issuer = [
cryptography_get_name(name, "CRL issuer")
cryptography_get_name(name, what="CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
@@ -352,8 +353,10 @@ def parse_crl_distribution_points(
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(
module=module
)
if self.version != 1:
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
@@ -364,7 +367,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module, crl_distribution_points
module=module, crl_distribution_points=crl_distribution_points
)
def generate_csr(self) -> None:
@@ -431,12 +434,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr = csr.add_extension(
cryptography.x509.NameConstraints(
[
cryptography_get_name(name, "name constraints permitted")
cryptography_get_name(
name, what="name constraints permitted"
)
for name in self.name_constraints_permitted
]
or None,
[
cryptography_get_name(name, "name constraints excluded")
cryptography_get_name(
name, what="name constraints excluded"
)
for name in self.name_constraints_excluded
]
or None,
@@ -473,7 +480,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
issuers = None
if self.authority_cert_issuer is not None:
issuers = [
cryptography_get_name(n, "authority cert issuer")
cryptography_get_name(n, what="authority cert issuer")
for n in self.authority_cert_issuer
]
csr = csr.add_extension(
@@ -679,11 +686,15 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else []
)
nc_perm = [
to_text(cryptography_get_name(altname, "name constraints permitted"))
to_text(
cryptography_get_name(altname, what="name constraints permitted")
)
for altname in self.name_constraints_permitted
]
nc_excl = [
to_text(cryptography_get_name(altname, "name constraints excluded"))
to_text(
cryptography_get_name(altname, what="name constraints excluded")
)
for altname in self.name_constraints_excluded
]
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(
@@ -731,7 +742,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr_aci = None
if self.authority_cert_issuer is not None:
aci = [
to_text(cryptography_get_name(n, "authority cert issuer"))
to_text(cryptography_get_name(n, what="authority cert issuer"))
for n in self.authority_cert_issuer
]
if ext.value.authority_cert_issuer is not None:
@@ -798,7 +809,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateSigningRequestCryptographyBackend(module)
return CertificateSigningRequestCryptographyBackend(module=module)
def get_csr_argument_spec() -> ArgumentSpec:
@@ -903,3 +914,11 @@ def get_csr_argument_spec() -> ArgumentSpec:
["privatekey_path", "privatekey_content"],
],
)
__all__ = (
"CertificateSigningRequestError",
"CertificateSigningRequestBackend",
"select_backend",
"get_csr_argument_spec",
)

View File

@@ -62,7 +62,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
self.module = module
self.content = content
@@ -122,10 +122,9 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def _is_signature_valid(self) -> bool:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
None,
content=self.content,
)
@@ -156,7 +155,7 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
@@ -196,10 +195,10 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super(CSRInfoRetrievalCryptography, self).__init__(
module, content, validate_signature
module=module, content=content, validate_signature=validate_signature
)
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
@@ -372,23 +371,27 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def get_csr_info(
*,
module: GeneralAnsibleModule,
content: bytes,
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
module=module, content=content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
*, module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
) -> CSRInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
module=module, content=content, validate_signature=validate_signature
)
__all__ = ("CSRInfoRetrieval", "get_csr_info", "select_backend")

View File

@@ -80,7 +80,7 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyBackend(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule) -> None:
def __init__(self, *, module: GeneralAnsibleModule) -> None:
self.module = module
self.type: t.Literal[
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
@@ -104,18 +104,18 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self.existing_private_key: PrivateKeyTypes | None = None
self.existing_private_key_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
result: dict[str, t.Any] = {"can_parse_key": False}
try:
result.update(
get_privatekey_info(
self.module,
data,
module=self.module,
content=data,
passphrase=self.passphrase,
return_private_key_data=False,
prefer_one_fingerprint=True,
@@ -148,11 +148,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key."""
def set_existing(self, privatekey_bytes: bytes | None) -> None:
def set_existing(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes
data=self.existing_private_key_bytes
)
def has_existing(self) -> bool:
@@ -235,7 +235,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return get_fingerprint_of_privatekey(self.existing_private_key)
return None
def dump(self, include_key: bool) -> dict[str, t.Any]:
def dump(self, *, include_key: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
if not self.private_key:
@@ -255,7 +255,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
pk_bytes = self.existing_private_key_bytes
if self.private_key is not None:
pk_bytes = self.get_private_key_data()
self.diff_after = self._get_info(pk_bytes)
self.diff_after = self._get_info(data=pk_bytes)
if include_key:
# Store result
if pk_bytes:
@@ -276,6 +276,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
class _Curve:
def __init__(
self,
*,
name: str,
ectype: str,
deprecated: bool,
@@ -285,7 +286,7 @@ class _Curve:
self.deprecated = deprecated
def _get_ec_class(
self, module: GeneralAnsibleModule
self, *, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
if ecclass is None:
@@ -295,17 +296,18 @@ class _Curve:
return ecclass
def create(
self, size: int, module: GeneralAnsibleModule
self, *, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module)
ecclass = self._get_ec_class(module=module)
return ecclass()
def verify(
self,
*,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module)
ecclass = self._get_ec_class(module=module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
@@ -316,6 +318,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
self,
name: str,
ectype: str,
*,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
@@ -575,7 +578,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
if self.curve not in self.curves:
return False
return self.curves[self.curve].verify(
self.existing_private_key, module=self.module
privatekey=self.existing_private_key, module=self.module
)
return False
@@ -596,7 +599,7 @@ def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyCryptographyBackend(module)
return PrivateKeyCryptographyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:
@@ -661,3 +664,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
("type", "ECC", ["curve"]),
],
)
__all__ = (
"PrivateKeyError",
"PrivateKeyBackend",
"select_backend",
"get_privatekey_argument_spec",
)

View File

@@ -69,7 +69,7 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
@@ -79,7 +79,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
self.src_private_key: PrivateKeyTypes | None = None
if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module)
self.src_private_key_bytes = load_file(path=self.src_path, module=module)
else:
if self.src_content is None:
raise AssertionError("src_content is None")
@@ -93,7 +93,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
"""Return bytes for self.src_private_key in output format."""
pass
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
def set_existing_destination(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
@@ -104,6 +104,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
@@ -113,7 +114,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase
data=self.src_private_key_bytes, passphrase=self.src_passphrase
)
if not self.has_existing_destination():
@@ -122,8 +123,8 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
try:
format, self.dest_private_key = self._load_private_key(
self.dest_private_key_bytes,
self.dest_passphrase,
data=self.dest_private_key_bytes,
passphrase=self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
@@ -140,7 +141,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
def get_private_key_data(self) -> bytes:
@@ -201,6 +202,7 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
@@ -276,7 +278,7 @@ def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyConvertCryptographyBackend(module)
return PrivateKeyConvertCryptographyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:
@@ -295,3 +297,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
["src_path", "src_content"],
],
)
__all__ = (
"PrivateKeyError",
"PrivateKeyConvertBackend",
"select_backend",
"get_privatekey_argument_spec",
)

View File

@@ -60,7 +60,7 @@ SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
key: PrivateKeyTypes, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data: dict[str, t.Any] = {}
@@ -84,7 +84,7 @@ def _get_cryptography_private_key_info(
def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
*, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters
p: int | None = key_public_data.get("p")
@@ -112,10 +112,10 @@ def _check_dsa_consistency(
if (p - 1) % q != 0:
return False
# Check that g**q mod p == 1
if binary_exp_mod(g, q, p) != 1:
if binary_exp_mod(g, q, m=p) != 1:
return False
# Check whether g**x mod p == y
if binary_exp_mod(g, x, p) != y:
if binary_exp_mod(g, x, m=p) != y:
return False
# Check (quickly) whether p or q are not primes
if quick_is_not_prime(q) or quick_is_not_prime(p):
@@ -125,6 +125,7 @@ def _check_dsa_consistency(
def _is_cryptography_key_consistent(
key: PrivateKeyTypes,
*,
key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
@@ -135,7 +136,9 @@ def _is_cryptography_key_consistent(
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data)
result = _check_dsa_consistency(
key_public_data=key_public_data, key_private_data=key_private_data
)
if result is not None:
return result
signature = key.sign(
@@ -191,14 +194,14 @@ def _is_cryptography_key_consistent(
class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -207,6 +210,7 @@ class PrivateKeyParseError(OpenSSLObjectError):
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -220,22 +224,22 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
@abc.abstractmethod
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"can_parse_key": False,
"key_is_consistent": None,
@@ -253,7 +257,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
)
result["can_parse_key"] = True
except OpenSSLObjectError as exc:
raise PrivateKeyParseError(str(exc), result)
raise PrivateKeyParseError(str(exc), result=result)
result["public_key"] = to_native(self._get_public_key(binary=False))
pk = self._get_public_key(binary=True)
@@ -273,7 +277,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
if self.check_consistency:
result["key_is_consistent"] = self._is_key_consistent(
key_public_data, key_private_data
key_public_data=key_public_data, key_private_data=key_private_data
)
if result["key_is_consistent"] is False:
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
@@ -281,40 +285,46 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
"Private key is not consistent! (See "
"https://blog.hboeck.de/archives/888-How-I-tricked-Symantec-with-a-Fake-Private-Key.html)"
)
raise PrivateKeyConsistencyError(msg, result)
raise PrivateKeyConsistencyError(msg, result=result)
return result
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
def __init__(
self, *, module: GeneralAnsibleModule, content: bytes, **kwargs
) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs
module=module, content=content, **kwargs
)
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
self.key,
key_public_data=key_public_data,
key_private_data=key_private_data,
warn_func=self.module.warn,
)
def get_privatekey_info(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -322,8 +332,8 @@ def get_privatekey_info(
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
)
@@ -331,6 +341,7 @@ def get_privatekey_info(
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -341,9 +352,18 @@ def select_backend(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
check_consistency=check_consistency,
)
__all__ = (
"PrivateKeyConsistencyError",
"PrivateKeyParseError",
"PrivateKeyInfoRetrieval",
"get_privatekey_info",
"select_backend",
)

View File

@@ -99,7 +99,7 @@ def _get_cryptography_public_key_info(
class PublicKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PublicKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -108,6 +108,7 @@ class PublicKeyParseError(OpenSSLObjectError):
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
@@ -125,13 +126,13 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if self.key is None:
try:
self.key = load_publickey(content=self.content)
except OpenSSLObjectError as e:
raise PublicKeyParseError(str(e), {})
raise PublicKeyParseError(str(e), result={})
pk = self._get_public_key(binary=True)
result["fingerprints"] = (
@@ -151,12 +152,13 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, content=content, key=key
module=module, content=content, key=key
)
def _get_public_key(self, binary: bool) -> bytes:
@@ -174,16 +176,18 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
def get_publickey_info(
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
info = PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
@@ -191,4 +195,12 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
__all__ = (
"PublicKeyParseError",
"PublicKeyInfoRetrieval",
"get_publickey_info",
"select_backend",
)