Code refactoring (#889)

* Add __all__ to all module and plugin utils.

* Convert quite a few positional args to keyword args.

* Avoid Python 3.8+ syntax.
This commit is contained in:
Felix Fontein
2025-05-16 06:55:57 +02:00
committed by GitHub
parent a5a4e022ba
commit 44bcc8cebc
101 changed files with 1510 additions and 748 deletions

View File

@@ -93,7 +93,12 @@ def serialize_asn1_string_as_der(value: str) -> bytes:
# We should only do a universal type tag if not IMPLICITLY tagged or the tag class is not universal.
if not tag_type or (tag_type == "EXPLICIT" and tag_class != "U"):
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
b_value = pack_asn1(
tag_class=TagClass.universal,
constructed=False,
tag_number=TagNumber.utf8_string,
b_data=b_value,
)
if tag_type:
tag_class_enum = {
@@ -105,13 +110,22 @@ def serialize_asn1_string_as_der(value: str) -> bytes:
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
b_value = pack_asn1(
tag_class=tag_class_enum,
constructed=constructed,
tag_number=int(tag_number),
b_data=b_value,
)
return b_value
def pack_asn1(
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
*,
tag_class: TagClass,
constructed: bool,
tag_number: TagNumber | int,
b_data: bytes,
) -> bytes:
"""Pack the value into an ASN.1 data structure.
@@ -159,3 +173,6 @@ def pack_asn1(
b_asn1_data.extend(length_octets)
return bytes(b_asn1_data) + b_data
__all__ = ("TagClass", "TagNumber", "serialize_asn1_string_as_der", "pack_asn1")

View File

@@ -58,3 +58,6 @@ def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
buf = openssl_ffi.new("char[]", buf_len)
res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
return openssl_ffi.buffer(buf, res)[:].decode()
__all__ = ("obj2txt",)

View File

@@ -33,3 +33,6 @@ for alias, original in [("userID", "userId")]:
NORMALIZE_NAMES[alias] = original
NORMALIZE_NAMES_SHORT[alias] = NORMALIZE_NAMES_SHORT[original]
OID_LOOKUP[alias] = OID_LOOKUP[original]
__all__ = ("OID_LOOKUP", "NORMALIZE_NAMES", "NORMALIZE_NAMES_SHORT")

View File

@@ -1175,3 +1175,6 @@ OID_MAP = {
"2.23.43.1.4.11": ("wap-wsg-idm-ecid-wtls11",),
"2.23.43.1.4.12": ("wap-wsg-idm-ecid-wtls12",),
}
__all__ = ("OID_MAP",)

View File

@@ -24,3 +24,6 @@ class OpenSSLObjectError(Exception):
class OpenSSLBadPassphraseError(OpenSSLObjectError):
pass
__all__ = ("HAS_CRYPTOGRAPHY", "OpenSSLObjectError", "OpenSSLBadPassphraseError")

View File

@@ -107,6 +107,7 @@ def cryptography_decode_revoked_certificate(
def cryptography_dump_revoked(
entry: dict[str, t.Any],
*,
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> dict[str, t.Any]:
return {
@@ -174,18 +175,34 @@ def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
def set_next_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
builder: x509.CertificateRevocationListBuilder, *, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.next_update(value)
def set_last_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
builder: x509.CertificateRevocationListBuilder, *, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.last_update(value)
def set_revocation_date(
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
builder: x509.RevokedCertificateBuilder, *, value: datetime.datetime
) -> x509.RevokedCertificateBuilder:
return builder.revocation_date(value)
__all__ = (
"REVOCATION_REASON_MAP",
"REVOCATION_REASON_MAP_INVERSE",
"cryptography_decode_revoked_certificate",
"cryptography_dump_revoked",
"cryptography_get_signature_algorithm_oid_from_crl",
"get_next_update",
"get_last_update",
"get_revocation_date",
"get_invalidity_date",
"set_next_update",
"set_last_update",
"set_revocation_date",
)

View File

@@ -263,7 +263,7 @@ def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
def cryptography_oid_to_name(
oid: x509.oid.ObjectIdentifier, short: bool = False
oid: x509.oid.ObjectIdentifier, *, short: bool = False
) -> str:
dotted_string = oid.dotted_string
names = OID_MAP.get(dotted_string)
@@ -315,7 +315,7 @@ def _int_to_byte(value: int) -> bytes:
def _parse_dn_component(
name: bytes, sep: bytes = b",", decode_remainder: bool = True
name: bytes, *, sep: bytes = b",", decode_remainder: bool = True
) -> tuple[x509.NameAttribute, bytes]:
m = DN_COMPONENT_START_RE.match(name)
if not m:
@@ -428,7 +428,9 @@ def _is_ascii(value: str) -> bool:
return False
def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
def _adjust_idn(
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
if idn_rewrite == "ignore" or not value:
return value
if idn_rewrite == "idna" and _is_ascii(value):
@@ -472,19 +474,19 @@ def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"])
def _adjust_idn_email(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
idx = value.find("@")
if idx < 0:
return value
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite=idn_rewrite)}"
def _adjust_idn_url(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
url = urlparse(value)
host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
host = _adjust_idn(url.hostname, idn_rewrite=idn_rewrite) if url.hostname else None
if url.username is not None and url.password is not None:
host = f"{url.username}:{url.password}@{host}"
elif url.username is not None:
@@ -504,7 +506,7 @@ def _adjust_idn_url(
def cryptography_get_name(
name: str, what: str = "Subject Alternative Name"
name: str, *, what: str = "Subject Alternative Name"
) -> x509.GeneralName:
"""
Given a name string, returns a cryptography x509.GeneralName object.
@@ -512,17 +514,19 @@ def cryptography_get_name(
"""
try:
if name.startswith("DNS:"):
return x509.DNSName(_adjust_idn(to_text(name[4:]), "idna"))
return x509.DNSName(_adjust_idn(to_text(name[4:]), idn_rewrite="idna"))
if name.startswith("IP:"):
address = to_text(name[3:])
if "/" in address:
return x509.IPAddress(ipaddress.ip_network(address))
return x509.IPAddress(ipaddress.ip_address(address))
if name.startswith("email:"):
return x509.RFC822Name(_adjust_idn_email(to_text(name[6:]), "idna"))
return x509.RFC822Name(
_adjust_idn_email(to_text(name[6:]), idn_rewrite="idna")
)
if name.startswith("URI:"):
return x509.UniformResourceIdentifier(
_adjust_idn_url(to_text(name[4:]), "idna")
_adjust_idn_url(to_text(name[4:]), idn_rewrite="idna")
)
if name.startswith("RID:"):
m = re.match(r"^([0-9]+(?:\.[0-9]+)*)$", to_text(name[4:]))
@@ -585,6 +589,7 @@ def _dn_escape_value(value: str) -> str:
def cryptography_decode_name(
name: x509.GeneralName,
*,
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> str:
"""
@@ -596,15 +601,15 @@ def cryptography_decode_name(
'idn_rewrite must be one of "ignore", "idna", or "unicode"'
)
if isinstance(name, x509.DNSName):
return f"DNS:{_adjust_idn(name.value, idn_rewrite)}"
return f"DNS:{_adjust_idn(name.value, idn_rewrite=idn_rewrite)}"
if isinstance(name, x509.IPAddress):
if isinstance(name.value, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
return f"IP:{name.value.network_address.compressed}/{name.value.prefixlen}"
return f"IP:{name.value.compressed}"
if isinstance(name, x509.RFC822Name):
return f"email:{_adjust_idn_email(name.value, idn_rewrite)}"
return f"email:{_adjust_idn_email(name.value, idn_rewrite=idn_rewrite)}"
if isinstance(name, x509.UniformResourceIdentifier):
return f"URI:{_adjust_idn_url(name.value, idn_rewrite)}"
return f"URI:{_adjust_idn_url(name.value, idn_rewrite=idn_rewrite)}"
if isinstance(name, x509.DirectoryName):
# According to https://datatracker.ietf.org/doc/html/rfc4514.html#section-2.1 the
# list needs to be reversed, and joined by commas
@@ -718,7 +723,7 @@ def cryptography_key_needs_digest_for_signing(
def _compare_public_keys(
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
key1: PublicKeyTypes, key2: PublicKeyTypes, *, clazz: type[PublicKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz)
b = isinstance(key2, clazz)
@@ -745,24 +750,24 @@ def cryptography_compare_public_keys(
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
clazz=cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
)
if res is not None:
return res
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
clazz=cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
)
if res is not None:
return res
res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
)
if res is not None:
return res
res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
)
if res is not None:
return res
@@ -773,7 +778,7 @@ def cryptography_compare_public_keys(
def _compare_private_keys(
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
key1: PrivateKeyTypes, key2: PrivateKeyTypes, *, clazz: type[PrivateKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz)
b = isinstance(key2, clazz)
@@ -805,24 +810,26 @@ def cryptography_compare_private_keys(
res = _compare_private_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
clazz=cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
)
if res is not None:
return res
res = _compare_private_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
clazz=cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
)
if res is not None:
return res
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey
key1,
key2,
clazz=cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
)
if res is not None:
return res
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
)
if res is not None:
return res
@@ -832,7 +839,9 @@ def cryptography_compare_private_keys(
)
def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
def parse_pkcs12(
pkcs12_bytes: bytes, *, passphrase: bytes | str | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -845,15 +854,17 @@ def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) ->
# Main code for cryptography 36.0.0 and forward
if _load_pkcs12 is not None:
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase=passphrase_bytes)
def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
def _parse_pkcs12_36_0_0(
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -871,7 +882,9 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
def _parse_pkcs12_35_0_0(
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -918,7 +931,9 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
def _parse_pkcs12_legacy(
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -940,6 +955,7 @@ def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
def cryptography_verify_signature(
*,
signature: bytes,
data: bytes,
hash_algorithm: hashes.HashAlgorithm | None,
@@ -999,16 +1015,16 @@ def cryptography_verify_signature(
def cryptography_verify_certificate_signature(
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
*, certificate: x509.Certificate, signer_public_key: PublicKeyTypes
) -> bool:
"""
Check whether the given X509 certificate object was signed by the given public key object.
"""
return cryptography_verify_signature(
certificate.signature,
certificate.tbs_certificate_bytes,
certificate.signature_hash_algorithm,
signer_public_key,
signature=certificate.signature,
data=certificate.tbs_certificate_bytes,
hash_algorithm=certificate.signature_hash_algorithm,
signer_public_key=signer_public_key,
)
@@ -1074,3 +1090,31 @@ def is_potential_certificate_issuer_public_key(
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
),
)
__all__ = (
"CRYPTOGRAPHY_TIMEZONE",
"cryptography_get_extensions_from_cert",
"cryptography_get_extensions_from_csr",
"cryptography_name_to_oid",
"cryptography_oid_to_name",
"cryptography_parse_relative_distinguished_name",
"cryptography_get_name",
"cryptography_decode_name",
"cryptography_parse_key_usage_params",
"cryptography_get_basic_constraints",
"cryptography_key_needs_digest_for_signing",
"cryptography_compare_public_keys",
"cryptography_compare_private_keys",
"parse_pkcs12",
"cryptography_verify_signature",
"cryptography_verify_certificate_signature",
"get_not_valid_after",
"get_not_valid_before",
"set_not_valid_after",
"set_not_valid_before",
"is_potential_certificate_private_key",
"is_potential_certificate_issuer_private_key",
"is_potential_certificate_public_key",
"is_potential_certificate_issuer_public_key",
)

View File

@@ -8,7 +8,7 @@
from __future__ import annotations
def binary_exp_mod(f: int, e: int, m: int) -> int:
def binary_exp_mod(f: int, e: int, *, m: int) -> int:
"""Computes f^e mod m in O(log e) multiplications modulo m."""
# Compute len_e = floor(log_2(e))
len_e = -1
@@ -120,7 +120,7 @@ def count_bits(no: int) -> int:
return no.bit_length()
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
def convert_int_to_bytes(no: int, *, count: int | None = None) -> bytes:
"""
Convert the absolute value of an integer to a byte string in network byte order.
@@ -136,7 +136,7 @@ def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
return no.to_bytes(count, byteorder="big")
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
def convert_int_to_hex(no: int, *, digits: int | None = None) -> str:
"""
Convert the absolute value of an integer to a string of hexadecimal digits.
@@ -156,3 +156,15 @@ def convert_bytes_to_int(data: bytes) -> int:
Convert a byte string to an unsigned integer in network byte order.
"""
return int.from_bytes(data, byteorder="big", signed=False)
__all__ = (
"binary_exp_mod",
"simple_gcd",
"quick_is_not_prime",
"count_bytes",
"count_bits",
"convert_int_to_bytes",
"convert_int_to_hex",
"convert_bytes_to_int",
)

View File

@@ -63,7 +63,7 @@ class CertificateError(OpenSSLObjectError):
class CertificateBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.force: bool = module.params["force"]
@@ -104,7 +104,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return {}
try:
result = get_certificate_info(
self.module, data, prefer_one_fingerprint=True
module=self.module, content=data, prefer_one_fingerprint=True
)
result["can_parse_certificate"] = True
return result
@@ -289,6 +289,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -330,7 +331,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
@@ -372,7 +373,7 @@ class CertificateProvider(metaclass=abc.ABCMeta):
def select_backend(
module: AnsibleModule, provider: CertificateProvider
*, module: AnsibleModule, provider: CertificateProvider
) -> CertificateBackend:
provider.validate_module_args(module)
@@ -415,3 +416,11 @@ def get_certificate_argument_spec() -> ArgumentSpec:
["privatekey_path", "privatekey_content"],
],
)
__all__ = (
"CertificateError",
"CertificateBackend",
"CertificateProvider",
"get_certificate_argument_spec",
)

View File

@@ -29,8 +29,8 @@ if t.TYPE_CHECKING:
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module=module)
self.accountkey_path: str = module.params["acme_accountkey_path"]
self.challenge_path: str = module.params["acme_challenge_path"]
self.use_chain: bool = module.params["acme_chain"]
@@ -102,8 +102,10 @@ class AcmeCertificateBackend(CertificateBackend):
raise AssertionError("Contract violation: cert_bytes is None")
return self.cert_bytes
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(include_certificate)
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(
include_certificate=include_certificate
)
result["accountkey"] = self.accountkey_path
return result
@@ -123,7 +125,7 @@ class AcmeCertificateProvider(CertificateProvider):
return False
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
return AcmeCertificateBackend(module)
return AcmeCertificateBackend(module=module)
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -138,3 +140,10 @@ def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
),
)
)
__all__ = (
"AcmeCertificateBackend",
"AcmeCertificateProvider",
"add_acme_provider_to_argument_spec",
)

View File

@@ -50,12 +50,12 @@ except ImportError:
class EntrustCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module=module)
self.trackingId = None
self.notAfter = get_relative_time_option(
module.params["entrust_not_after"],
"entrust_not_after",
input_name="entrust_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
@@ -159,6 +159,7 @@ class EntrustCertificateBackend(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -229,7 +230,7 @@ class EntrustCertificateProvider(CertificateProvider):
return False
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
return EntrustCertificateBackend(module)
return EntrustCertificateBackend(module=module)
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -281,3 +282,10 @@ def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
],
)
)
__all__ = (
"EntrustCertificateBackend",
"EntrustCertificateProvider",
"add_entrust_provider_to_argument_spec",
)

View File

@@ -70,7 +70,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@@ -158,11 +158,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
pass
def get_info(
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
None,
content=self.content,
der_support_enabled=der_support_enabled,
)
@@ -204,7 +203,7 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
@@ -249,8 +248,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(
module=module, content=content
)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self) -> bytes:
@@ -465,16 +466,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def get_certificate_info(
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
*,
module: GeneralAnsibleModule,
content: bytes,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module, content)
info = CertificateInfoRetrievalCryptography(module=module, content=content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes
*, module: GeneralAnsibleModule, content: bytes
) -> CertificateInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateInfoRetrievalCryptography(module, content)
return CertificateInfoRetrievalCryptography(module=module, content=content)
__all__ = ("CertificateInfoRetrieval", "get_certificate_info", "select_backend")

View File

@@ -62,8 +62,8 @@ except ImportError:
class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module=module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
@@ -73,12 +73,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
]
self.notBefore = get_relative_time_option(
module.params["ownca_not_before"],
"ownca_not_before",
input_name="ownca_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["ownca_not_after"],
"ownca_not_after",
input_name="ownca_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["ownca_digest"])
@@ -220,6 +220,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -233,7 +234,8 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
# Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.ca_cert.public_key()
certificate=self.existing_certificate,
signer_public_key=self.ca_cert.public_key(),
):
return True
@@ -270,9 +272,9 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate
include_certificate=include_certificate
)
result.update(
{
@@ -339,7 +341,7 @@ class OwnCACertificateProvider(CertificateProvider):
def create_backend(
self, module: AnsibleModule
) -> OwnCACertificateBackendCryptography:
return OwnCACertificateBackendCryptography(module)
return OwnCACertificateBackendCryptography(module=module)
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -369,3 +371,10 @@ def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
["ownca_privatekey_path", "ownca_privatekey_content"],
]
)
__all__ = (
"OwnCACertificateBackendCryptography",
"OwnCACertificateProvider",
"add_ownca_provider_to_argument_spec",
)

View File

@@ -58,20 +58,20 @@ except ImportError:
class SelfSignedCertificateBackendCryptography(CertificateBackend):
privatekey: CertificateIssuerPrivateKeyTypes
def __init__(self, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module=module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["selfsigned_create_subject_key_identifier"]
self.notBefore = get_relative_time_option(
module.params["selfsigned_not_before"],
"selfsigned_not_before",
input_name="selfsigned_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["selfsigned_not_after"],
"selfsigned_not_after",
input_name="selfsigned_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["selfsigned_digest"])
@@ -162,6 +162,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -177,15 +178,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
# Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.privatekey.public_key()
certificate=self.existing_certificate,
signer_public_key=self.privatekey.public_key(),
):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate
include_certificate=include_certificate
)
if self.module.check_mode:
@@ -239,7 +241,7 @@ class SelfSignedCertificateProvider(CertificateProvider):
def create_backend(
self, module: AnsibleModule
) -> SelfSignedCertificateBackendCryptography:
return SelfSignedCertificateBackendCryptography(module)
return SelfSignedCertificateBackendCryptography(module=module)
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -261,3 +263,10 @@ def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> Non
),
)
)
__all__ = (
"SelfSignedCertificateBackendCryptography",
"SelfSignedCertificateProvider",
"add_selfsigned_provider_to_argument_spec",
)

View File

@@ -55,6 +55,7 @@ except ImportError:
class CRLInfoRetrieval:
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
@@ -113,12 +114,20 @@ class CRLInfoRetrieval:
def get_crl_info(
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
*,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
) -> dict[str, t.Any]:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
info = CRLInfoRetrieval(
module, content, list_revoked_certificates=list_revoked_certificates
module=module,
content=content,
list_revoked_certificates=list_revoked_certificates,
)
return info.get_info()
__all__ = ("CRLInfoRetrieval", "get_crl_info")

View File

@@ -87,7 +87,7 @@ class CertificateSigningRequestError(OpenSSLObjectError):
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.digest: str = module.params["digest"]
self.privatekey_path: str | None = module.params["privatekey_path"]
@@ -158,7 +158,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
try:
if module.params["subject"]:
self.subject = self.subject + parse_name_field(
module.params["subject"], "subject"
module.params["subject"], name_field_name="subject"
)
if module.params["subject_ordered"]:
if self.subject:
@@ -166,7 +166,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
"subject_ordered cannot be combined with any other subject field"
)
self.subject = parse_ordered_name_field(
module.params["subject_ordered"], "subject_ordered"
module.params["subject_ordered"], name_field_name="subject_ordered"
)
self.ordered_subject = True
except ValueError as exc:
@@ -205,16 +205,16 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
self.existing_csr_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
try:
result = get_csr_info(
self.module,
data,
module=self.module,
content=data,
validate_signature=False,
prefer_one_fingerprint=True,
)
@@ -231,10 +231,12 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
def set_existing(self, csr_bytes: bytes | None) -> None:
def set_existing(self, *, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
self.diff_after = self.diff_before = self._get_info(
data=self.existing_csr_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
@@ -263,7 +265,6 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return True
try:
self.existing_csr = load_certificate_request(
None,
content=self.existing_csr_bytes,
)
except Exception:
@@ -271,7 +272,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, include_csr: bool) -> dict[str, t.Any]:
def dump(self, *, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
@@ -288,7 +289,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
csr_bytes = self.existing_csr_bytes
if self.csr is not None:
csr_bytes = self.get_csr_data()
self.diff_after = self._get_info(csr_bytes)
self.diff_after = self._get_info(data=csr_bytes)
if include_csr:
# Store result
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
@@ -301,7 +302,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def parse_crl_distribution_points(
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
*, module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
@@ -314,7 +315,7 @@ def parse_crl_distribution_points(
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
full_name = [
cryptography_get_name(name, "full name")
cryptography_get_name(name, what="full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
@@ -327,7 +328,7 @@ def parse_crl_distribution_points(
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
crl_issuer = [
cryptography_get_name(name, "CRL issuer")
cryptography_get_name(name, what="CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
@@ -352,8 +353,10 @@ def parse_crl_distribution_points(
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(
module=module
)
if self.version != 1:
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
@@ -364,7 +367,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module, crl_distribution_points
module=module, crl_distribution_points=crl_distribution_points
)
def generate_csr(self) -> None:
@@ -431,12 +434,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr = csr.add_extension(
cryptography.x509.NameConstraints(
[
cryptography_get_name(name, "name constraints permitted")
cryptography_get_name(
name, what="name constraints permitted"
)
for name in self.name_constraints_permitted
]
or None,
[
cryptography_get_name(name, "name constraints excluded")
cryptography_get_name(
name, what="name constraints excluded"
)
for name in self.name_constraints_excluded
]
or None,
@@ -473,7 +480,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
issuers = None
if self.authority_cert_issuer is not None:
issuers = [
cryptography_get_name(n, "authority cert issuer")
cryptography_get_name(n, what="authority cert issuer")
for n in self.authority_cert_issuer
]
csr = csr.add_extension(
@@ -679,11 +686,15 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else []
)
nc_perm = [
to_text(cryptography_get_name(altname, "name constraints permitted"))
to_text(
cryptography_get_name(altname, what="name constraints permitted")
)
for altname in self.name_constraints_permitted
]
nc_excl = [
to_text(cryptography_get_name(altname, "name constraints excluded"))
to_text(
cryptography_get_name(altname, what="name constraints excluded")
)
for altname in self.name_constraints_excluded
]
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(
@@ -731,7 +742,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr_aci = None
if self.authority_cert_issuer is not None:
aci = [
to_text(cryptography_get_name(n, "authority cert issuer"))
to_text(cryptography_get_name(n, what="authority cert issuer"))
for n in self.authority_cert_issuer
]
if ext.value.authority_cert_issuer is not None:
@@ -798,7 +809,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateSigningRequestCryptographyBackend(module)
return CertificateSigningRequestCryptographyBackend(module=module)
def get_csr_argument_spec() -> ArgumentSpec:
@@ -903,3 +914,11 @@ def get_csr_argument_spec() -> ArgumentSpec:
["privatekey_path", "privatekey_content"],
],
)
__all__ = (
"CertificateSigningRequestError",
"CertificateSigningRequestBackend",
"select_backend",
"get_csr_argument_spec",
)

View File

@@ -62,7 +62,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
self.module = module
self.content = content
@@ -122,10 +122,9 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def _is_signature_valid(self) -> bool:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
None,
content=self.content,
)
@@ -156,7 +155,7 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
@@ -196,10 +195,10 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super(CSRInfoRetrievalCryptography, self).__init__(
module, content, validate_signature
module=module, content=content, validate_signature=validate_signature
)
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
@@ -372,23 +371,27 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def get_csr_info(
*,
module: GeneralAnsibleModule,
content: bytes,
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
module=module, content=content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
*, module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
) -> CSRInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
module=module, content=content, validate_signature=validate_signature
)
__all__ = ("CSRInfoRetrieval", "get_csr_info", "select_backend")

View File

@@ -80,7 +80,7 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyBackend(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule) -> None:
def __init__(self, *, module: GeneralAnsibleModule) -> None:
self.module = module
self.type: t.Literal[
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
@@ -104,18 +104,18 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self.existing_private_key: PrivateKeyTypes | None = None
self.existing_private_key_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
result: dict[str, t.Any] = {"can_parse_key": False}
try:
result.update(
get_privatekey_info(
self.module,
data,
module=self.module,
content=data,
passphrase=self.passphrase,
return_private_key_data=False,
prefer_one_fingerprint=True,
@@ -148,11 +148,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key."""
def set_existing(self, privatekey_bytes: bytes | None) -> None:
def set_existing(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes
data=self.existing_private_key_bytes
)
def has_existing(self) -> bool:
@@ -235,7 +235,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return get_fingerprint_of_privatekey(self.existing_private_key)
return None
def dump(self, include_key: bool) -> dict[str, t.Any]:
def dump(self, *, include_key: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
if not self.private_key:
@@ -255,7 +255,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
pk_bytes = self.existing_private_key_bytes
if self.private_key is not None:
pk_bytes = self.get_private_key_data()
self.diff_after = self._get_info(pk_bytes)
self.diff_after = self._get_info(data=pk_bytes)
if include_key:
# Store result
if pk_bytes:
@@ -276,6 +276,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
class _Curve:
def __init__(
self,
*,
name: str,
ectype: str,
deprecated: bool,
@@ -285,7 +286,7 @@ class _Curve:
self.deprecated = deprecated
def _get_ec_class(
self, module: GeneralAnsibleModule
self, *, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
if ecclass is None:
@@ -295,17 +296,18 @@ class _Curve:
return ecclass
def create(
self, size: int, module: GeneralAnsibleModule
self, *, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module)
ecclass = self._get_ec_class(module=module)
return ecclass()
def verify(
self,
*,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module)
ecclass = self._get_ec_class(module=module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
@@ -316,6 +318,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
self,
name: str,
ectype: str,
*,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
@@ -575,7 +578,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
if self.curve not in self.curves:
return False
return self.curves[self.curve].verify(
self.existing_private_key, module=self.module
privatekey=self.existing_private_key, module=self.module
)
return False
@@ -596,7 +599,7 @@ def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyCryptographyBackend(module)
return PrivateKeyCryptographyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:
@@ -661,3 +664,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
("type", "ECC", ["curve"]),
],
)
__all__ = (
"PrivateKeyError",
"PrivateKeyBackend",
"select_backend",
"get_privatekey_argument_spec",
)

View File

@@ -69,7 +69,7 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
@@ -79,7 +79,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
self.src_private_key: PrivateKeyTypes | None = None
if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module)
self.src_private_key_bytes = load_file(path=self.src_path, module=module)
else:
if self.src_content is None:
raise AssertionError("src_content is None")
@@ -93,7 +93,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
"""Return bytes for self.src_private_key in output format."""
pass
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
def set_existing_destination(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
@@ -104,6 +104,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
@@ -113,7 +114,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase
data=self.src_private_key_bytes, passphrase=self.src_passphrase
)
if not self.has_existing_destination():
@@ -122,8 +123,8 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
try:
format, self.dest_private_key = self._load_private_key(
self.dest_private_key_bytes,
self.dest_passphrase,
data=self.dest_private_key_bytes,
passphrase=self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
@@ -140,7 +141,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
def get_private_key_data(self) -> bytes:
@@ -201,6 +202,7 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
@@ -276,7 +278,7 @@ def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyConvertCryptographyBackend(module)
return PrivateKeyConvertCryptographyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:
@@ -295,3 +297,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
["src_path", "src_content"],
],
)
__all__ = (
"PrivateKeyError",
"PrivateKeyConvertBackend",
"select_backend",
"get_privatekey_argument_spec",
)

View File

@@ -60,7 +60,7 @@ SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
key: PrivateKeyTypes, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data: dict[str, t.Any] = {}
@@ -84,7 +84,7 @@ def _get_cryptography_private_key_info(
def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
*, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters
p: int | None = key_public_data.get("p")
@@ -112,10 +112,10 @@ def _check_dsa_consistency(
if (p - 1) % q != 0:
return False
# Check that g**q mod p == 1
if binary_exp_mod(g, q, p) != 1:
if binary_exp_mod(g, q, m=p) != 1:
return False
# Check whether g**x mod p == y
if binary_exp_mod(g, x, p) != y:
if binary_exp_mod(g, x, m=p) != y:
return False
# Check (quickly) whether p or q are not primes
if quick_is_not_prime(q) or quick_is_not_prime(p):
@@ -125,6 +125,7 @@ def _check_dsa_consistency(
def _is_cryptography_key_consistent(
key: PrivateKeyTypes,
*,
key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
@@ -135,7 +136,9 @@ def _is_cryptography_key_consistent(
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data)
result = _check_dsa_consistency(
key_public_data=key_public_data, key_private_data=key_private_data
)
if result is not None:
return result
signature = key.sign(
@@ -191,14 +194,14 @@ def _is_cryptography_key_consistent(
class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -207,6 +210,7 @@ class PrivateKeyParseError(OpenSSLObjectError):
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -220,22 +224,22 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
@abc.abstractmethod
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"can_parse_key": False,
"key_is_consistent": None,
@@ -253,7 +257,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
)
result["can_parse_key"] = True
except OpenSSLObjectError as exc:
raise PrivateKeyParseError(str(exc), result)
raise PrivateKeyParseError(str(exc), result=result)
result["public_key"] = to_native(self._get_public_key(binary=False))
pk = self._get_public_key(binary=True)
@@ -273,7 +277,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
if self.check_consistency:
result["key_is_consistent"] = self._is_key_consistent(
key_public_data, key_private_data
key_public_data=key_public_data, key_private_data=key_private_data
)
if result["key_is_consistent"] is False:
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
@@ -281,40 +285,46 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
"Private key is not consistent! (See "
"https://blog.hboeck.de/archives/888-How-I-tricked-Symantec-with-a-Fake-Private-Key.html)"
)
raise PrivateKeyConsistencyError(msg, result)
raise PrivateKeyConsistencyError(msg, result=result)
return result
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
def __init__(
self, *, module: GeneralAnsibleModule, content: bytes, **kwargs
) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs
module=module, content=content, **kwargs
)
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
self.key,
key_public_data=key_public_data,
key_private_data=key_private_data,
warn_func=self.module.warn,
)
def get_privatekey_info(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -322,8 +332,8 @@ def get_privatekey_info(
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
)
@@ -331,6 +341,7 @@ def get_privatekey_info(
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -341,9 +352,18 @@ def select_backend(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
check_consistency=check_consistency,
)
__all__ = (
"PrivateKeyConsistencyError",
"PrivateKeyParseError",
"PrivateKeyInfoRetrieval",
"get_privatekey_info",
"select_backend",
)

View File

@@ -99,7 +99,7 @@ def _get_cryptography_public_key_info(
class PublicKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PublicKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -108,6 +108,7 @@ class PublicKeyParseError(OpenSSLObjectError):
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
@@ -125,13 +126,13 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if self.key is None:
try:
self.key = load_publickey(content=self.content)
except OpenSSLObjectError as e:
raise PublicKeyParseError(str(e), {})
raise PublicKeyParseError(str(e), result={})
pk = self._get_public_key(binary=True)
result["fingerprints"] = (
@@ -151,12 +152,13 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, content=content, key=key
module=module, content=content, key=key
)
def _get_public_key(self, binary: bool) -> bytes:
@@ -174,16 +176,18 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
def get_publickey_info(
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
info = PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
@@ -191,4 +195,12 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
__all__ = (
"PublicKeyParseError",
"PublicKeyInfoRetrieval",
"get_publickey_info",
"select_backend",
)

View File

@@ -17,7 +17,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
def identify_pem_format(content: bytes, *, encoding: str = "utf-8") -> bool:
"""Given the contents of a binary file, tests whether this could be a PEM file."""
try:
first_pem = extract_first_pem(content.decode(encoding))
@@ -36,7 +36,7 @@ def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
def identify_private_key_format(
content: bytes, encoding: str = "utf-8"
content: bytes, *, encoding: str = "utf-8"
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
"""Given the contents of a private key file, identifies its format."""
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
@@ -66,7 +66,7 @@ def identify_private_key_format(
return "raw"
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
def split_pem_list(text: str, *, keep_inbetween: bool = False) -> list[str]:
"""
Split concatenated PEM objects into a list of strings, where each is one PEM object.
"""
@@ -94,7 +94,7 @@ def extract_first_pem(text: str) -> str | None:
return all_pems[0]
def _extract_type(line: str, start: str = PEM_START) -> str | None:
def _extract_type(line: str, *, start: str = PEM_START) -> str | None:
if not line.startswith(start):
return None
if not line.endswith(PEM_END):
@@ -102,7 +102,7 @@ def _extract_type(line: str, start: str = PEM_START) -> str | None:
return line[len(start) : -len(PEM_END)]
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
def extract_pem(content: str, *, strict: bool = False) -> tuple[str, str]:
lines = content.splitlines()
if len(lines) < 3:
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
@@ -125,3 +125,17 @@ def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
)
return header_type, "".join(lines[1:-1])
__all__ = (
"PEM_START",
"PEM_END_START",
"PEM_END",
"PKCS8_PRIVATEKEY_NAMES",
"PKCS1_PRIVATEKEY_SUFFIX",
"identify_pem_format",
"identify_private_key_format",
"split_pem_list",
"extract_first_pem",
"extract_pem",
)

View File

@@ -63,7 +63,9 @@ PREFERRED_FINGERPRINTS = (
)
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
def get_fingerprint_of_bytes(
source: bytes, *, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the given bytes."""
fingerprint = {}
@@ -107,7 +109,7 @@ def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[st
def get_fingerprint_of_privatekey(
privatekey: PrivateKeyTypes, prefer_one: bool = False
privatekey: PrivateKeyTypes, *, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the public key."""
@@ -119,6 +121,7 @@ def get_fingerprint_of_privatekey(
def get_fingerprint(
*,
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
content: bytes | None = None,
@@ -137,6 +140,7 @@ def get_fingerprint(
def load_privatekey(
*,
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
@@ -219,7 +223,7 @@ def load_certificate_issuer_privatekey(
def load_publickey(
path: os.PathLike | str | None = None, content: bytes | None = None
*, path: os.PathLike | str | None = None, content: bytes | None = None
) -> PublicKeyTypes:
if content is None:
if path is None:
@@ -237,6 +241,7 @@ def load_publickey(
def load_certificate(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
der_support_enabled: bool = False,
@@ -266,7 +271,7 @@ def load_certificate(
def load_certificate_request(
path: os.PathLike | str | None = None, content: bytes | None = None
*, path: os.PathLike | str | None = None, content: bytes | None = None
) -> x509.CertificateSigningRequest:
"""Load the specified certificate signing request."""
try:
@@ -287,6 +292,7 @@ def load_certificate_request(
def parse_name_field(
input_dict: dict[str, list[str | bytes] | str | bytes],
*,
name_field_name: str | None = None,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
@@ -321,7 +327,9 @@ def parse_name_field(
def parse_ordered_name_field(
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
input_list: list[dict[str, list[str | bytes] | str | bytes]],
*,
name_field_name: str,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
@@ -372,7 +380,7 @@ def select_message_digest(
class OpenSSLObject(metaclass=abc.ABCMeta):
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
def __init__(self, *, path: str, state: str, force: bool, check_mode: bool) -> None:
self.path = path
self.state = state
self.force = force
@@ -380,7 +388,7 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
self.changed = False
self.check_mode = check_mode
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
def check(self, module: AnsibleModule, *, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
def _check_state() -> bool:
@@ -420,3 +428,20 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
raise OpenSSLObjectError(exc)
else:
pass
__all__ = (
"get_fingerprint_of_bytes",
"get_fingerprint_of_privatekey",
"get_fingerprint",
"load_privatekey",
"load_certificate_privatekey",
"load_certificate_issuer_privatekey",
"load_publickey",
"load_certificate",
"load_certificate_request",
"parse_name_field",
"parse_ordered_name_field",
"select_message_digest",
"OpenSSLObject",
)