Code refactoring (#889)

* Add __all__ to all module and plugin utils.

* Convert quite a few positional args to keyword args.

* Avoid Python 3.8+ syntax.
This commit is contained in:
Felix Fontein
2025-05-16 06:55:57 +02:00
committed by GitHub
parent a5a4e022ba
commit 44bcc8cebc
101 changed files with 1510 additions and 748 deletions

View File

@@ -29,7 +29,7 @@ class ACMEAccount:
retrieve account data.
"""
def __init__(self, client: ACMEClient) -> None:
def __init__(self, *, client: ACMEClient) -> None:
# Set to true to enable logging of all signed requests
self._debug: bool = False
@@ -37,6 +37,7 @@ class ACMEAccount:
def _new_reg(
self,
*,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
@@ -82,14 +83,15 @@ class ACMEAccount:
url = self.client.directory["newAccount"]
if external_account_binding is not None:
new_reg["externalAccountBinding"] = self.client.sign_request(
{
protected={
"alg": external_account_binding["alg"],
"kid": external_account_binding["kid"],
"url": url,
},
self.client.account_jwk,
self.client.backend.create_mac_key(
external_account_binding["alg"], external_account_binding["key"]
payload=self.client.account_jwk,
key_data=self.client.backend.create_mac_key(
alg=external_account_binding["alg"],
key=external_account_binding["key"],
),
)
elif (
@@ -106,7 +108,7 @@ class ACMEAccount:
)
if not isinstance(result, Mapping):
raise ACMEProtocolException(
self.client.module,
module=self.client.module,
msg="Invalid account creation reply from ACME server",
info=info,
content_json=result,
@@ -156,7 +158,7 @@ class ACMEAccount:
raise ModuleFailException("Account is deactivated")
else:
raise ACMEProtocolException(
self.client.module,
module=self.client.module,
msg="Registering ACME account failed",
info=info,
content_json=result,
@@ -187,7 +189,7 @@ class ACMEAccount:
)
if not isinstance(result, Mapping):
raise ACMEProtocolException(
self.client.module,
module=self.client.module,
msg="Invalid account data retrieved from ACME server",
info=info,
content_json=result,
@@ -206,7 +208,7 @@ class ACMEAccount:
return None
if info["status"] < 200 or info["status"] >= 300:
raise ACMEProtocolException(
self.client.module,
module=self.client.module,
msg="Error retrieving account data",
info=info,
content_json=result,
@@ -216,6 +218,7 @@ class ACMEAccount:
@t.overload
def setup_account(
self,
*,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: t.Literal[True] = True,
@@ -226,6 +229,7 @@ class ACMEAccount:
@t.overload
def setup_account(
self,
*,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
@@ -235,6 +239,7 @@ class ACMEAccount:
def setup_account(
self,
*,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
@@ -281,7 +286,7 @@ class ACMEAccount:
)
else:
created, account_data = self._new_reg(
contact,
contact=contact,
terms_agreed=terms_agreed,
allow_creation=allow_creation and not self.client.module.check_mode,
external_account_binding=external_account_binding,
@@ -296,7 +301,7 @@ class ACMEAccount:
return created, account_data
def update_account(
self, account_data: dict[str, t.Any], contact: list[str] | None = None
self, *, account_data: dict[str, t.Any], contact: list[str] | None = None
) -> tuple[bool, dict[str, t.Any]]:
"""
Update an account on the ACME server. Check mode is fully respected.
@@ -332,10 +337,13 @@ class ACMEAccount:
)
if not isinstance(account_data, Mapping):
raise ACMEProtocolException(
self.client.module,
module=self.client.module,
msg="Invalid account updating reply from ACME server",
info=info,
content_json=account_data,
)
return True, account_data
__all__ = ("ACMEAccount",)

View File

@@ -65,14 +65,14 @@ RETRY_COUNT = 10
def _decode_retry(
module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
*, module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
) -> bool:
if info["status"] not in RETRY_STATUS_CODES:
return False
if retry_count >= RETRY_COUNT:
raise ACMEProtocolException(
module,
module=module,
msg=f"Giving up after {RETRY_COUNT} retries",
info=info,
response=response,
@@ -93,6 +93,7 @@ def _decode_retry(
def _assert_fetch_url_success(
*,
module: AnsibleModule,
response: t.Any,
info: dict[str, t.Any],
@@ -108,11 +109,11 @@ def _assert_fetch_url_success(
or (400 <= info["status"] < 500 and not allow_client_error)
or (info["status"] >= 500 and not allow_server_error)
):
raise ACMEProtocolException(module, info=info, response=response)
raise ACMEProtocolException(module=module, info=info, response=response)
def _is_failed(
info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
*, info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
) -> bool:
if info["status"] < 200 or info["status"] >= 400:
return True
@@ -133,7 +134,7 @@ class ACMEDirectory:
https://tools.ietf.org/html/rfc8555#section-7.1.1
"""
def __init__(self, module: AnsibleModule, client: ACMEClient) -> None:
def __init__(self, *, module: AnsibleModule, client: ACMEClient) -> None:
self.module = module
self.directory_root = module.params["acme_directory"]
self.version = module.params["acme_version"]
@@ -171,7 +172,12 @@ class ACMEDirectory:
response, info = fetch_url(
self.module, url, method="HEAD", timeout=self.request_timeout
)
if _decode_retry(self.module, response, info, retry_count):
if _decode_retry(
module=self.module,
response=response,
info=info,
retry_count=retry_count,
):
retry_count += 1
continue
if info["status"] not in (200, 204):
@@ -185,7 +191,7 @@ class ACMEDirectory:
)
if retry_count >= 5:
raise ACMEProtocolException(
self.module,
module=self.module,
msg="Was not able to obtain nonce, giving up after 5 retries",
info=info,
response=response,
@@ -202,7 +208,7 @@ class ACMEClient:
ACME server.
"""
def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None:
def __init__(self, *, module: AnsibleModule, backend: CryptoBackend) -> None:
# Set to true to enable logging of all signed requests
self._debug = False
@@ -241,7 +247,7 @@ class ACMEClient:
# Make sure self.account_jws_header is updated
self.set_account_uri(self.account_uri)
self.directory = ACMEDirectory(module, self)
self.directory = ACMEDirectory(module=module, client=self)
def set_account_uri(self, uri: str) -> None:
"""
@@ -255,6 +261,7 @@ class ACMEClient:
def parse_key(
self,
*,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
@@ -265,10 +272,13 @@ class ACMEClient:
"""
if key_file is None and key_content is None:
raise AssertionError("One of key_file and key_content must be specified!")
return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
return self.backend.parse_key(
key_file=key_file, key_content=key_content, passphrase=passphrase
)
def sign_request(
self,
*,
protected: dict[str, t.Any],
payload: str | dict[str, t.Any] | None,
key_data: dict[str, t.Any],
@@ -292,9 +302,11 @@ class ACMEClient:
f"Failed to encode payload / headers as JSON: {e}"
)
return self.backend.sign(payload64, protected64, key_data)
return self.backend.sign(
payload64=payload64, protected64=protected64, key_data=key_data
)
def _log(self, msg: str, data: t.Any = None) -> None:
def _log(self, msg: str, *, data: t.Any = None) -> None:
"""
Write arguments to acme.log when logging is enabled.
"""
@@ -373,13 +385,16 @@ class ACMEClient:
protected["nonce"] = self.directory.get_nonce()
protected["url"] = url
self._log("URL", url)
self._log("protected", protected)
self._log("payload", payload)
self._log("URL", data=url)
self._log("protected", data=protected)
self._log("payload", data=payload)
data = self.sign_request(
protected, payload, key_data, encode_payload=encode_payload
protected=protected,
payload=payload,
key_data=key_data,
encode_payload=encode_payload,
)
self._log("signed request", data)
self._log("signed request", data=data)
data = self.module.jsonify(data)
headers = {
@@ -393,10 +408,12 @@ class ACMEClient:
method="POST",
timeout=self.request_timeout,
)
if _decode_retry(self.module, resp, info, failed_tries):
if _decode_retry(
module=self.module, response=resp, info=info, retry_count=failed_tries
):
failed_tries += 1
continue
_assert_fetch_url_success(self.module, resp, info)
_assert_fetch_url_success(module=self.module, response=resp, info=info)
result = {}
try:
@@ -415,7 +432,7 @@ class ACMEClient:
) or 400 <= info["status"] < 600:
try:
decoded_result = self.module.from_json(content.decode("utf8"))
self._log("parsed result", decoded_result)
self._log("parsed result", data=decoded_result)
# In case of badNonce error, try again (up to 5 times)
# (https://tools.ietf.org/html/rfc8555#section-6.7)
if all(
@@ -440,10 +457,10 @@ class ACMEClient:
result = content
if fail_on_error and _is_failed(
info, expected_status_codes=expected_status_codes
info=info, expected_status_codes=expected_status_codes
):
raise ACMEProtocolException(
self.module,
module=self.module,
msg=error_msg,
info=info,
content=content,
@@ -515,11 +532,16 @@ class ACMEClient:
headers=headers,
timeout=self.request_timeout,
)
if not _decode_retry(self.module, resp, info, retry_count):
if not _decode_retry(
module=self.module,
response=resp,
info=info,
retry_count=retry_count,
):
break
retry_count += 1
_assert_fetch_url_success(self.module, resp, info)
_assert_fetch_url_success(module=self.module, response=resp, info=info)
try:
# In Python 2, reading from a closed response yields a TypeError.
@@ -550,10 +572,10 @@ class ACMEClient:
result = content
if fail_on_error and _is_failed(
info, expected_status_codes=expected_status_codes
info=info, expected_status_codes=expected_status_codes
):
raise ACMEProtocolException(
self.module,
module=self.module,
msg=error_msg,
info=info,
content=content,
@@ -565,6 +587,7 @@ class ACMEClient:
def get_renewal_info(
self,
*,
cert_id: str | None = None,
cert_info: CertificateInformation | None = None,
cert_filename: str | os.PathLike | None = None,
@@ -579,7 +602,7 @@ class ACMEClient:
if cert_id is None:
cert_id = compute_cert_id(
self.backend,
backend=self.backend,
cert_info=cert_info,
cert_filename=cert_filename,
cert_content=cert_content,
@@ -603,6 +626,7 @@ class ACMEClient:
def create_default_argspec(
*,
with_account: bool = True,
require_account_key: bool = True,
with_certificate: bool = False,
@@ -643,7 +667,9 @@ def create_default_argspec(
return result
def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend:
def create_backend(
module: AnsibleModule, *, needs_acme_v2: bool = True
) -> CryptoBackend:
backend = module.params["select_crypto_backend"]
# Backend autodetect
@@ -671,10 +697,10 @@ def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoB
module.debug(
f"Using cryptography backend (library version {CRYPTOGRAPHY_VERSION})"
)
module_backend = CryptographyBackend(module)
module_backend = CryptographyBackend(module=module)
elif backend == "openssl":
module.debug("Using OpenSSL binary backend")
module_backend = OpenSSLCLIBackend(module)
module_backend = OpenSSLCLIBackend(module=module)
else:
module.fail_json(msg=f'Unknown crypto backend "{backend}"!')
@@ -691,3 +717,11 @@ def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoB
locale.setlocale(locale.LC_ALL, "C")
return module_backend
__all__ = (
"ACMEDirectory",
"ACMEClient",
"create_default_argspec",
"create_backend",
)

View File

@@ -92,7 +92,11 @@ if t.TYPE_CHECKING:
class CryptographyChainMatcher(ChainMatcher):
@staticmethod
def _parse_key_identifier(
key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule
*,
key_identifier: str | None,
name: str,
criterium_idx: int,
module: AnsibleModule,
) -> bytes | None:
if key_identifier:
try:
@@ -109,7 +113,7 @@ class CryptographyChainMatcher(ChainMatcher):
)
return None
def __init__(self, criterium: Criterium, module: AnsibleModule) -> None:
def __init__(self, *, criterium: Criterium, module: AnsibleModule) -> None:
self.criterium = criterium
self.test_certificates = criterium.test_certificates
self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
@@ -117,29 +121,32 @@ class CryptographyChainMatcher(ChainMatcher):
if criterium.subject:
self.subject = [
(cryptography_name_to_oid(k), to_native(v))
for k, v in parse_name_field(criterium.subject, "subject")
for k, v in parse_name_field(
criterium.subject, name_field_name="subject"
)
]
if criterium.issuer:
self.issuer = [
(cryptography_name_to_oid(k), to_native(v))
for k, v in parse_name_field(criterium.issuer, "issuer")
for k, v in parse_name_field(criterium.issuer, name_field_name="issuer")
]
self.subject_key_identifier = CryptographyChainMatcher._parse_key_identifier(
criterium.subject_key_identifier,
"subject_key_identifier",
criterium.index,
module,
key_identifier=criterium.subject_key_identifier,
name="subject_key_identifier",
criterium_idx=criterium.index,
module=module,
)
self.authority_key_identifier = CryptographyChainMatcher._parse_key_identifier(
criterium.authority_key_identifier,
"authority_key_identifier",
criterium.index,
module,
key_identifier=criterium.authority_key_identifier,
name="authority_key_identifier",
criterium_idx=criterium.index,
module=module,
)
self.module = module
def _match_subject(
self,
*,
x509_subject: cryptography.x509.Name,
match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]],
) -> bool:
@@ -153,7 +160,7 @@ class CryptographyChainMatcher(ChainMatcher):
return False
return True
def match(self, certificate: CertificateChain) -> bool:
def match(self, *, certificate: CertificateChain) -> bool:
"""
Check whether an alternate chain matches the specified criterium.
"""
@@ -166,9 +173,13 @@ class CryptographyChainMatcher(ChainMatcher):
try:
x509 = cryptography.x509.load_pem_x509_certificate(to_bytes(cert))
matches = True
if not self._match_subject(x509.subject, self.subject):
if not self._match_subject(
x509_subject=x509.subject, match_subject=self.subject
):
matches = False
if not self._match_subject(x509.issuer, self.issuer):
if not self._match_subject(
x509_subject=x509.issuer, match_subject=self.issuer
):
matches = False
if self.subject_key_identifier:
try:
@@ -199,13 +210,14 @@ class CryptographyChainMatcher(ChainMatcher):
class CryptographyBackend(CryptoBackend):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
super(CryptographyBackend, self).__init__(
module, with_timezone=CRYPTOGRAPHY_TIMEZONE
module=module, with_timezone=CRYPTOGRAPHY_TIMEZONE
)
def parse_key(
self,
*,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
@@ -288,7 +300,7 @@ class CryptographyBackend(CryptoBackend):
raise KeyParsingError(f'unknown key type "{type(key)}"')
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
self, *, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
sign_payload = f"{protected64}.{payload64}".encode("utf8")
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
@@ -317,8 +329,8 @@ class CryptographyBackend(CryptoBackend):
r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(
key_data["key_obj"].sign(sign_payload, ecdsa)
)
rr = convert_int_to_hex(r, 2 * key_data["point_size"])
ss = convert_int_to_hex(s, 2 * key_data["point_size"])
rr = convert_int_to_hex(r, digits=2 * key_data["point_size"])
ss = convert_int_to_hex(s, digits=2 * key_data["point_size"])
signature = binascii.unhexlify(rr) + binascii.unhexlify(ss)
return {
@@ -327,7 +339,7 @@ class CryptographyBackend(CryptoBackend):
"signature": nopad_b64(signature),
}
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
def create_mac_key(self, *, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key."""
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
if alg == "HS256":
@@ -362,6 +374,7 @@ class CryptographyBackend(CryptoBackend):
def get_ordered_csr_identifiers(
self,
*,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
@@ -413,6 +426,7 @@ class CryptographyBackend(CryptoBackend):
def get_csr_identifiers(
self,
*,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | bytes | None = None,
) -> set[tuple[str, str]]:
@@ -429,6 +443,7 @@ class CryptographyBackend(CryptoBackend):
def get_cert_days(
self,
*,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
@@ -466,14 +481,15 @@ class CryptographyBackend(CryptoBackend):
now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE)
return (get_not_valid_after(cert) - now).days
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
def create_chain_matcher(self, *, criterium: Criterium) -> ChainMatcher:
"""
Given a Criterium object, creates a ChainMatcher object.
"""
return CryptographyChainMatcher(criterium, self.module)
return CryptographyChainMatcher(criterium=criterium, module=self.module)
def get_cert_information(
self,
*,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
@@ -520,3 +536,12 @@ class CryptographyBackend(CryptoBackend):
subject_key_identifier=ski,
authority_key_identifier=aki,
)
__all__ = (
"CRYPTOGRAPHY_MINIMAL_VERSION",
"CRYPTOGRAPHY_ERROR",
"CRYPTOGRAPHY_VERSION",
"CRYPTOGRAPHY_ERROR",
"CryptographyBackend",
)

View File

@@ -49,7 +49,7 @@ _OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTY
def _extract_date(
out_text: str, name: str, cert_filename_suffix: str = ""
out_text: str, *, name: str, cert_filename_suffix: str = ""
) -> datetime.datetime:
matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text)
if matcher is None:
@@ -76,6 +76,7 @@ def _decode_octets(octets_text: str) -> bytes:
@t.overload
def _extract_octets(
out_text: str,
*,
name: str,
required: t.Literal[False],
potential_prefixes: t.Iterable[str] | None = None,
@@ -85,6 +86,7 @@ def _extract_octets(
@t.overload
def _extract_octets(
out_text: str,
*,
name: str,
required: t.Literal[True],
potential_prefixes: t.Iterable[str] | None = None,
@@ -93,6 +95,7 @@ def _extract_octets(
def _extract_octets(
out_text: str,
*,
name: str,
required: bool = True,
potential_prefixes: t.Iterable[str] | None = None,
@@ -113,15 +116,16 @@ def _extract_octets(
class OpenSSLCLIBackend(CryptoBackend):
def __init__(
self, module: AnsibleModule, openssl_binary: str | None = None
self, *, module: AnsibleModule, openssl_binary: str | None = None
) -> None:
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
super(OpenSSLCLIBackend, self).__init__(module=module, with_timezone=True)
if openssl_binary is None:
openssl_binary = module.get_bin_path("openssl", True)
self.openssl_binary = openssl_binary
def parse_key(
self,
*,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
@@ -282,7 +286,7 @@ class OpenSSLCLIBackend(CryptoBackend):
)
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
self, *, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
sign_payload = f"{protected64}.{payload64}".encode("utf8")
if key_data["type"] == "hmac":
@@ -343,7 +347,7 @@ class OpenSSLCLIBackend(CryptoBackend):
"signature": nopad_b64(to_bytes(out)),
}
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
def create_mac_key(self, *, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key."""
if alg == "HS256":
hashalg = "sha256"
@@ -383,6 +387,7 @@ class OpenSSLCLIBackend(CryptoBackend):
def get_ordered_csr_identifiers(
self,
*,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
@@ -454,6 +459,7 @@ class OpenSSLCLIBackend(CryptoBackend):
def get_csr_identifiers(
self,
*,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> set[tuple[str, str]]:
@@ -470,6 +476,7 @@ class OpenSSLCLIBackend(CryptoBackend):
def get_cert_days(
self,
*,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
@@ -516,7 +523,7 @@ class OpenSSLCLIBackend(CryptoBackend):
out_text = to_text(out, errors="surrogate_or_strict")
not_after = _extract_date(
out_text, "Not After", cert_filename_suffix=cert_filename_suffix
out_text, name="Not After", cert_filename_suffix=cert_filename_suffix
)
if now is None:
now = self.get_now()
@@ -524,7 +531,7 @@ class OpenSSLCLIBackend(CryptoBackend):
now = ensure_utc_timezone(now)
return (not_after - now).days
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
def create_chain_matcher(self, *, criterium: Criterium) -> t.NoReturn:
"""
Given a Criterium object, creates a ChainMatcher object.
"""
@@ -534,6 +541,7 @@ class OpenSSLCLIBackend(CryptoBackend):
def get_cert_information(
self,
*,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
@@ -572,10 +580,10 @@ class OpenSSLCLIBackend(CryptoBackend):
out_text = to_text(out, errors="surrogate_or_strict")
not_after = _extract_date(
out_text, "Not After", cert_filename_suffix=cert_filename_suffix
out_text, name="Not After", cert_filename_suffix=cert_filename_suffix
)
not_before = _extract_date(
out_text, "Not Before", cert_filename_suffix=cert_filename_suffix
out_text, name="Not Before", cert_filename_suffix=cert_filename_suffix
)
sn = re.search(
@@ -587,13 +595,15 @@ class OpenSSLCLIBackend(CryptoBackend):
serial = int(sn.group(1))
else:
serial = convert_bytes_to_int(
_extract_octets(out_text, "Serial Number", required=True)
_extract_octets(out_text, name="Serial Number", required=True)
)
ski = _extract_octets(out_text, "X509v3 Subject Key Identifier", required=False)
ski = _extract_octets(
out_text, name="X509v3 Subject Key Identifier", required=False
)
aki = _extract_octets(
out_text,
"X509v3 Authority Key Identifier",
name="X509v3 Authority Key Identifier",
required=False,
potential_prefixes=["keyid:", ""],
)
@@ -605,3 +615,6 @@ class OpenSSLCLIBackend(CryptoBackend):
subject_key_identifier=ski,
authority_key_identifier=aki,
)
__all__ = ("OpenSSLCLIBackend",)

View File

@@ -69,7 +69,9 @@ def _reduce_fractional_digits(timestamp_str: str) -> str:
return f"{timestamp}{fractional}{timezone}"
def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime:
def _parse_acme_timestamp(
timestamp_str: str, *, with_timezone: bool
) -> datetime.datetime:
"""
Parses a RFC 3339 timestamp.
"""
@@ -95,7 +97,7 @@ def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.d
class CryptoBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None:
def __init__(self, *, module: AnsibleModule, with_timezone: bool = False) -> None:
self.module = module
self._with_timezone = with_timezone
@@ -106,10 +108,10 @@ class CryptoBackend(metaclass=abc.ABCMeta):
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone)
def parse_module_parameter(self, value: str, name: str) -> datetime.datetime:
def parse_module_parameter(self, *, value: str, name: str) -> datetime.datetime:
try:
result = get_relative_time_option(
value, name, with_timezone=self._with_timezone
value, input_name=name, with_timezone=self._with_timezone
)
if result is None:
raise BackendException(f"Invalid value for {name}: {value!r}")
@@ -121,6 +123,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
self,
timestamp_start: datetime.datetime,
timestamp_end: datetime.datetime,
*,
percentage: float,
) -> datetime.datetime:
start = get_epoch_seconds(timestamp_start)
@@ -141,6 +144,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def parse_key(
self,
*,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
@@ -152,17 +156,18 @@ class CryptoBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
self, *, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
pass
@abc.abstractmethod
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
def create_mac_key(self, *, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key."""
@abc.abstractmethod
def get_ordered_csr_identifiers(
self,
*,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
@@ -178,6 +183,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_csr_identifiers(
self,
*,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> set[tuple[str, str]]:
@@ -190,6 +196,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_cert_days(
self,
*,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
@@ -203,7 +210,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
def create_chain_matcher(self, *, criterium: Criterium) -> ChainMatcher:
"""
Given a Criterium object, creates a ChainMatcher object.
"""
@@ -211,9 +218,13 @@ class CryptoBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_cert_information(
self,
*,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
"""
Return some information on a X.509 certificate as a CertificateInformation object.
"""
__all__ = ("CertificateInformation", "CryptoBackend")

View File

@@ -58,6 +58,7 @@ class ACMECertificateClient:
def __init__(
self,
*,
module: AnsibleModule,
backend: CryptoBackend,
client: ACMEClient | None = None,
@@ -68,10 +69,10 @@ class ACMECertificateClient:
self.csr = module.params.get("csr")
self.csr_content = module.params.get("csr_content")
if client is None:
client = ACMEClient(module, backend)
client = ACMEClient(module=module, backend=backend)
self.client = client
if account is None:
account = ACMEAccount(self.client)
account = ACMEAccount(client=self.client)
self.account = account
self.order_uri = module.params.get("order_uri")
self.order_creation_error_strategy = module.params.get(
@@ -108,7 +109,9 @@ class ACMECertificateClient:
try:
select_chain_matcher.append(
self.client.backend.create_chain_matcher(
Criterium(criterium, index=criterium_idx)
criterium=Criterium(
criterium=criterium, index=criterium_idx
)
)
)
except ValueError as exc:
@@ -120,12 +123,12 @@ class ACMECertificateClient:
def load_order(self) -> Order:
if not self.order_uri:
raise ModuleFailException("The order URI has not been provided")
order = Order.from_url(self.client, self.order_uri)
order.load_authorizations(self.client)
order = Order.from_url(client=self.client, url=self.order_uri)
order.load_authorizations(client=self.client)
return order
def create_order(
self, replaces_cert_id: str | None = None, profile: str | None = None
self, *, replaces_cert_id: str | None = None, profile: str | None = None
) -> Order:
"""
Create a new order.
@@ -133,8 +136,8 @@ class ACMECertificateClient:
if self.identifiers is None:
raise ModuleFailException("No identifiers have been provided")
order = Order.create_with_error_handling(
self.client,
self.identifiers,
client=self.client,
identifiers=self.identifiers,
error_strategy=self.order_creation_error_strategy,
error_max_retries=self.order_creation_max_retries,
replaces_cert_id=replaces_cert_id,
@@ -142,7 +145,7 @@ class ACMECertificateClient:
message_callback=self.module.warn,
)
self.order_uri = order.url
order.load_authorizations(self.client)
order.load_authorizations(client=self.client)
return order
def get_challenges_data(
@@ -161,7 +164,7 @@ class ACMECertificateClient:
# and do not need to be returned
if authz.status == "valid":
continue
challenge_data = authz.get_challenge_data(self.client)
challenge_data = authz.get_challenge_data(client=self.client)
data.append(
dict(
identifier=authz.identifier,
@@ -209,20 +212,27 @@ class ACMECertificateClient:
def call_validate(
self,
pending_authzs: list[Authorization],
*,
get_challenge: t.Callable[[Authorization], str],
wait: bool = True,
) -> list[tuple[Authorization, str, Challenge | None]]:
authzs_with_challenges_to_wait_for = []
for authz in pending_authzs:
challenge_type = get_challenge(authz)
authz.call_validate(self.client, challenge_type, wait=wait)
authz.call_validate(
client=self.client, challenge_type=challenge_type, wait=wait
)
authzs_with_challenges_to_wait_for.append(
(authz, challenge_type, authz.find_challenge(challenge_type))
(
authz,
challenge_type,
authz.find_challenge(challenge_type=challenge_type),
)
)
return authzs_with_challenges_to_wait_for
def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None:
wait_for_validation(authzs_to_wait_for, self.client)
wait_for_validation(authzs=authzs_to_wait_for, client=self.client)
def _download_alternate_chains(
self, cert: CertificateChain
@@ -230,7 +240,7 @@ class ACMECertificateClient:
alternate_chains = []
for alternate in cert.alternates:
try:
alt_cert = CertificateChain.download(self.client, alternate)
alt_cert = CertificateChain.download(client=self.client, url=alternate)
except ModuleFailException as e:
self.module.warn(
f"Error while downloading alternative certificate {alternate}: {e}"
@@ -275,7 +285,7 @@ class ACMECertificateClient:
f"Order's crtificate URL {order.certificate_uri!r} is empty!"
)
cert = CertificateChain.download(self.client, order.certificate_uri)
cert = CertificateChain.download(client=self.client, url=order.certificate_uri)
if cert.cert is None:
raise ModuleFailException(
f"Certificate at {order.certificate_uri} is empty!"
@@ -314,22 +324,26 @@ class ACMECertificateClient:
for identifier, authz in order.authorizations.items():
if authz.status != "valid":
authz.raise_error(
f'Status is {authz.status!r} and not "valid"',
error_msg=f'Status is {authz.status!r} and not "valid"',
module=self.module,
)
order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
order.finalize(
client=self.client,
csr_der=pem_to_der(pem_filename=self.csr, pem_content=self.csr_content),
)
return self.download_certificate(order, download_all_chains=download_all_chains)
def find_matching_chain(
self,
*,
chains: list[CertificateChain],
select_chain_matcher: t.Iterable[ChainMatcher],
) -> CertificateChain | None:
for criterium_idx, matcher in enumerate(select_chain_matcher):
for chain in chains:
if matcher.match(chain):
if matcher.match(certificate=chain):
self.module.debug(
f"Found matching chain for criterium {criterium_idx}"
)
@@ -338,6 +352,7 @@ class ACMECertificateClient:
def write_cert_chain(
self,
*,
cert: CertificateChain,
cert_dest: str | os.PathLike | None = None,
fullchain_dest: str | os.PathLike | None = None,
@@ -347,18 +362,22 @@ class ACMECertificateClient:
if cert.cert is None:
raise ValueError("Certificate is not present")
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
if cert_dest and write_file(
module=self.module, dest=cert_dest, content=cert.cert.encode("utf8")
):
changed = True
if fullchain_dest and write_file(
self.module,
fullchain_dest,
(cert.cert + "\n".join(cert.chain)).encode("utf8"),
module=self.module,
dest=fullchain_dest,
content=(cert.cert + "\n".join(cert.chain)).encode("utf8"),
):
changed = True
if chain_dest and write_file(
self.module, chain_dest, ("\n".join(cert.chain)).encode("utf8")
module=self.module,
dest=chain_dest,
content=("\n".join(cert.chain)).encode("utf8"),
):
changed = True
@@ -374,7 +393,9 @@ class ACMECertificateClient:
for authz_uri in order.authorization_uris:
authz = None
try:
authz = Authorization.deactivate_url(self.client, authz_uri)
authz = Authorization.deactivate_url(
client=self.client, url=authz_uri
)
except Exception:
# ignore errors
pass
@@ -385,7 +406,7 @@ class ACMECertificateClient:
else:
for authz in order.authorizations.values():
try:
authz.deactivate(self.client)
authz.deactivate(client=self.client)
except Exception:
# ignore errors
pass
@@ -393,3 +414,6 @@ class ACMECertificateClient:
self.module.warn(
warning=f"Could not deactivate authz object {authz.url}."
)
__all__ = ("ACMECertificateClient",)

View File

@@ -46,7 +46,7 @@ class CertificateChain:
@classmethod
def download(
cls: t.Type[_CertificateChain], client: ACMEClient, url: str
cls: t.Type[_CertificateChain], *, client: ACMEClient, url: str
) -> _CertificateChain:
content, info = client.get_request(
url,
@@ -70,7 +70,10 @@ class CertificateChain:
result.chain = certs[1:]
process_links(
info, lambda link, relation: result._process_links(client, link, relation)
info=info,
callback=lambda link, relation: result._process_links(
client=client, link=link, relation=relation
),
)
if result.cert is None:
@@ -80,7 +83,7 @@ class CertificateChain:
return result
def _process_links(self, client: ACMEClient, link: str, relation: str) -> None:
def _process_links(self, *, client: ACMEClient, link: str, relation: str) -> None:
if relation == "up":
# Process link-up headers if there was no chain in reply
if not self.chain:
@@ -105,7 +108,7 @@ class CertificateChain:
class Criterium:
def __init__(self, criterium: dict[str, t.Any], index: int):
def __init__(self, *, criterium: dict[str, t.Any], index: int):
self.index = index
self.test_certificates: t.Literal["first", "last", "all"] = criterium[
"test_certificates"
@@ -120,7 +123,10 @@ class Criterium:
class ChainMatcher(metaclass=abc.ABCMeta):
@abc.abstractmethod
def match(self, certificate: CertificateChain) -> bool:
def match(self, *, certificate: CertificateChain) -> bool:
"""
Check whether a certificate chain (CertificateChain instance) matches.
"""
__all__ = ("CertificateChain", "Criterium", "ChainMatcher")

View File

@@ -34,7 +34,7 @@ if t.TYPE_CHECKING:
)
def create_key_authorization(client: ACMEClient, token: str) -> str:
def create_key_authorization(*, client: ACMEClient, token: str) -> str:
"""
Returns the key authorization for the given token
https://tools.ietf.org/html/rfc8555#section-8.1
@@ -46,7 +46,7 @@ def create_key_authorization(client: ACMEClient, token: str) -> str:
return f"{token}.{thumbprint}"
def combine_identifier(identifier_type: str, identifier: str) -> str:
def combine_identifier(*, identifier_type: str, identifier: str) -> str:
return f"{identifier_type}:{identifier}"
@@ -54,7 +54,7 @@ def normalize_combined_identifier(identifier: str) -> str:
identifier_type, identifier = split_identifier(identifier)
# Normalize DNS names and IPs
identifier = identifier.lower()
return combine_identifier(identifier_type, identifier)
return combine_identifier(identifier_type=identifier_type, identifier=identifier)
def split_identifier(identifier: str) -> tuple[str, str]:
@@ -70,7 +70,7 @@ _Challenge = t.TypeVar("_Challenge", bound="Challenge")
class Challenge:
def __init__(self, data: dict[str, t.Any], url: str) -> None:
def __init__(self, *, data: dict[str, t.Any], url: str) -> None:
self.data = data
self.type: str = data["type"]
@@ -81,11 +81,12 @@ class Challenge:
@classmethod
def from_json(
cls: t.Type[_Challenge],
*,
client: ACMEClient,
data: dict[str, t.Any],
url: str | None = None,
) -> _Challenge:
return cls(data, url or data["url"])
return cls(data=data, url=url or data["url"])
def call_validate(self, client: ACMEClient) -> None:
challenge_response: dict[str, t.Any] = {}
@@ -100,13 +101,13 @@ class Challenge:
return self.data.copy()
def get_validation_data(
self, client: ACMEClient, identifier_type: str, identifier: str
self, *, client: ACMEClient, identifier_type: str, identifier: str
) -> dict[str, t.Any] | None:
if self.token is None:
return None
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
key_authorization = create_key_authorization(client, token)
key_authorization = create_key_authorization(client=client, token=token)
if self.type == "http-01":
# https://tools.ietf.org/html/rfc8555#section-8.3
@@ -142,7 +143,9 @@ class Challenge:
)
return {
"resource": resource,
"resource_original": combine_identifier(identifier_type, identifier),
"resource_original": combine_identifier(
identifier_type=identifier_type, identifier=identifier
),
"resource_value": b_value,
}
@@ -154,7 +157,7 @@ _Authorization = t.TypeVar("_Authorization", bound="Authorization")
class Authorization:
def __init__(self, url: str) -> None:
def __init__(self, *, url: str) -> None:
self.url = url
self.data: dict[str, t.Any] | None = None
@@ -163,14 +166,14 @@ class Authorization:
self.identifier_type: str | None = None
self.identifier: str | None = None
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
def _setup(self, *, client: ACMEClient, data: dict[str, t.Any]) -> None:
data["uri"] = self.url
self.data = data
# While 'challenges' is a required field, apparently not every CA cares
# (https://github.com/ansible-collections/community.crypto/issues/824)
if data.get("challenges"):
self.challenges = [
Challenge.from_json(client, challenge)
Challenge.from_json(client=client, data=challenge)
for challenge in data["challenges"]
]
else:
@@ -184,25 +187,27 @@ class Authorization:
@classmethod
def from_json(
cls: t.Type[_Authorization],
*,
client: ACMEClient,
data: dict[str, t.Any],
url: str,
) -> _Authorization:
result = cls(url)
result._setup(client, data)
result = cls(url=url)
result._setup(client=client, data=data)
return result
@classmethod
def from_url(
cls: t.Type[_Authorization], client: ACMEClient, url: str
cls: t.Type[_Authorization], *, client: ACMEClient, url: str
) -> _Authorization:
result = cls(url)
result.refresh(client)
result = cls(url=url)
result.refresh(client=client)
return result
@classmethod
def create(
cls: t.Type[_Authorization],
*,
client: ACMEClient,
identifier_type: str,
identifier: str,
@@ -220,7 +225,8 @@ class Authorization:
}
if "newAuthz" not in client.directory.directory:
raise ACMEProtocolException(
client.module, "ACME endpoint does not support pre-authorization"
module=client.module,
msg="ACME endpoint does not support pre-authorization",
)
url = client.directory["newAuthz"]
@@ -230,26 +236,28 @@ class Authorization:
error_msg="Failed to request challenges",
expected_status_codes=[200, 201],
)
return cls.from_json(client, result, info["location"])
return cls.from_json(client=client, data=result, url=info["location"])
@property
def combined_identifier(self) -> str:
if self.identifier_type is None or self.identifier is None:
raise ValueError("Data not present")
return combine_identifier(self.identifier_type, self.identifier)
return combine_identifier(
identifier_type=self.identifier_type, identifier=self.identifier
)
def to_json(self) -> dict[str, t.Any]:
if self.data is None:
raise ValueError("Data not present")
return self.data.copy()
def refresh(self, client: ACMEClient) -> bool:
def refresh(self, *, client: ACMEClient) -> bool:
result, dummy = client.get_request(self.url)
changed = self.data != result
self._setup(client, result)
self._setup(client=client, data=result)
return changed
def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]:
def get_challenge_data(self, *, client: ACMEClient) -> dict[str, t.Any]:
"""
Returns a dict with the data for all proposed (and supported) challenges
of the given authorization.
@@ -259,13 +267,15 @@ class Authorization:
data = {}
for challenge in self.challenges:
validation_data = challenge.get_validation_data(
client, self.identifier_type, self.identifier
client=client,
identifier_type=self.identifier_type,
identifier=self.identifier,
)
if validation_data is not None:
data[challenge.type] = validation_data
return data
def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn:
def raise_error(self, *, error_msg: str, module: AnsibleModule) -> t.NoReturn:
"""
Aborts with a specific error for a challenge.
"""
@@ -283,40 +293,40 @@ class Authorization:
msg = f"{msg}: {problem}"
error_details.append(msg)
raise ACMEProtocolException(
module,
f"Failed to validate challenge for {self.combined_identifier}: {error_msg}. {'; '.join(error_details)}",
module=module,
msg=f"Failed to validate challenge for {self.combined_identifier}: {error_msg}. {'; '.join(error_details)}",
extras=dict(
identifier=self.combined_identifier,
authorization=self.data,
),
)
def find_challenge(self, challenge_type: str) -> Challenge | None:
def find_challenge(self, *, challenge_type: str) -> Challenge | None:
for challenge in self.challenges:
if challenge_type == challenge.type:
return challenge
return None
def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool:
def wait_for_validation(self, *, client: ACMEClient) -> bool:
while True:
self.refresh(client)
self.refresh(client=client)
if self.status in ["valid", "invalid", "revoked"]:
break
time.sleep(2)
if self.status == "invalid":
self.raise_error('Status is "invalid"', module=client.module)
self.raise_error(error_msg='Status is "invalid"', module=client.module)
return self.status == "valid"
def call_validate(
self, client: ACMEClient, challenge_type: str, wait: bool = True
self, *, client: ACMEClient, challenge_type: str, wait: bool = True
) -> bool:
"""
Validate the authorization provided in the auth dict. Returns True
when the validation was successful and False when it was not.
"""
challenge = self.find_challenge(challenge_type)
challenge = self.find_challenge(challenge_type=challenge_type)
if challenge is None:
raise ModuleFailException(
f'Found no challenge of type "{challenge_type}" for identifier {self.combined_identifier}!'
@@ -326,7 +336,7 @@ class Authorization:
if not wait:
return self.status == "valid"
return self.wait_for_validation(client, challenge_type)
return self.wait_for_validation(client=client)
def can_deactivate(self) -> bool:
"""
@@ -336,7 +346,7 @@ class Authorization:
"""
return self.status in ("valid", "pending")
def deactivate(self, client: ACMEClient) -> bool | None:
def deactivate(self, *, client: ACMEClient) -> bool | None:
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
@@ -355,35 +365,50 @@ class Authorization:
@classmethod
def deactivate_url(
cls: t.Type[_Authorization], client: ACMEClient, url: str
cls: t.Type[_Authorization], *, client: ACMEClient, url: str
) -> _Authorization:
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
"""
authz = cls(url)
authz = cls(url=url)
authz_deactivate = {"status": "deactivated"}
result, info = client.send_signed_request(
url, authz_deactivate, fail_on_error=True
)
authz._setup(client, result)
authz._setup(client=client, data=result)
return authz
def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None:
def wait_for_validation(
*, authzs: t.Iterable[Authorization], client: ACMEClient
) -> None:
"""
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
"""
while authzs:
authzs_next = []
for authz in authzs:
authz.refresh(client)
authz.refresh(client=client)
if authz.status in ["valid", "invalid", "revoked"]:
if authz.status != "valid":
authz.raise_error('Status is not "valid"', module=client.module)
authz.raise_error(
error_msg='Status is not "valid"', module=client.module
)
else:
authzs_next.append(authz)
if authzs_next:
time.sleep(2)
authzs = authzs_next
__all__ = (
"create_key_authorization",
"combine_identifier",
"normalize_combined_identifier",
"split_identifier",
"Challenge",
"Authorization",
"wait_for_validation",
)

View File

@@ -25,7 +25,9 @@ def format_http_status(status_code: int) -> str:
return f"{status_code} {expl}"
def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str:
def format_error_problem(
problem: dict[str, t.Any], *, subproblem_prefix: str = ""
) -> str:
error_type = problem.get(
"type", "about:blank"
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
@@ -57,13 +59,14 @@ class ModuleFailException(Exception):
self.msg = msg
self.module_fail_args = args
def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn:
def do_fail(self, *, module: AnsibleModule, **arguments) -> t.NoReturn:
module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments)
class ACMEProtocolException(ModuleFailException):
def __init__(
self,
*,
module: AnsibleModule,
msg: str | None = None,
info: dict[str, t.Any] | None = None,
@@ -168,3 +171,14 @@ class NetworkException(ModuleFailException):
class KeyParsingError(ModuleFailException):
pass
__all__ = (
"format_http_status",
"format_error_problem",
"ModuleFailException",
"ACMEProtocolException",
"BackendException",
"NetworkException",
"KeyParsingError",
)

View File

@@ -33,7 +33,9 @@ def read_file(fn: str | os.PathLike) -> bytes:
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool:
def write_file(
*, module: AnsibleModule, dest: str | os.PathLike, content: bytes
) -> bool:
"""
Write content to destination file dest, only if the content
has changed.
@@ -95,3 +97,6 @@ def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -
)
os.remove(tmpsrc)
return changed
__all__ = ("read_file", "write_file")

View File

@@ -34,7 +34,7 @@ _Order = t.TypeVar("_Order", bound="Order")
class Order:
def __init__(self, url: str) -> None:
def __init__(self, *, url: str) -> None:
self.url = url
self.data: dict[str, t.Any] | None = None
@@ -47,7 +47,7 @@ class Order:
self.authorization_uris: list[str] = []
self.authorizations: dict[str, Authorization] = {}
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
def _setup(self, *, client: ACMEClient, data: dict[str, t.Any]) -> None:
self.data = data
self.status = data["status"]
@@ -62,21 +62,22 @@ class Order:
@classmethod
def from_json(
cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str
cls: t.Type[_Order], *, client: ACMEClient, data: dict[str, t.Any], url: str
) -> _Order:
result = cls(url)
result._setup(client, data)
result = cls(url=url)
result._setup(client=client, data=data)
return result
@classmethod
def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order:
result = cls(url)
result.refresh(client)
def from_url(cls: t.Type[_Order], *, client: ACMEClient, url: str) -> _Order:
result = cls(url=url)
result.refresh(client=client)
return result
@classmethod
def create(
cls: t.Type[_Order],
*,
client: ACMEClient,
identifiers: list[tuple[str, str]],
replaces_cert_id: str | None = None,
@@ -105,11 +106,12 @@ class Order:
error_msg="Failed to start new order",
expected_status_codes=[201],
)
return cls.from_json(client, result, info["location"])
return cls.from_json(client=client, data=result, url=info["location"])
@classmethod
def create_with_error_handling(
cls: t.Type[_Order],
*,
client: ACMEClient,
identifiers: list[tuple[str, str]],
error_strategy: t.Literal[
@@ -136,8 +138,8 @@ class Order:
tries += 1
try:
return cls.create(
client,
identifiers,
client=client,
identifiers=identifiers,
replaces_cert_id=replaces_cert_id,
profile=profile,
)
@@ -164,34 +166,36 @@ class Order:
raise
def refresh(self, client: ACMEClient) -> bool:
def refresh(self, *, client: ACMEClient) -> bool:
result, dummy = client.get_request(self.url)
changed = self.data != result
self._setup(client, result)
self._setup(client=client, data=result)
return changed
def load_authorizations(self, client: ACMEClient) -> None:
def load_authorizations(self, *, client: ACMEClient) -> None:
for auth_uri in self.authorization_uris:
authz = Authorization.from_url(client, auth_uri)
authz = Authorization.from_url(client=client, url=auth_uri)
self.authorizations[
normalize_combined_identifier(authz.combined_identifier)
] = authz
def wait_for_finalization(self, client: ACMEClient) -> None:
def wait_for_finalization(self, *, client: ACMEClient) -> None:
while True:
self.refresh(client)
self.refresh(client=client)
if self.status in ["valid", "invalid", "pending", "ready"]:
break
time.sleep(2)
if self.status != "valid":
raise ACMEProtocolException(
client.module,
f'Failed to wait for order to complete; got status "{self.status}"',
module=client.module,
msg=f'Failed to wait for order to complete; got status "{self.status}"',
content_json=self.data,
)
def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None:
def finalize(
self, *, client: ACMEClient, csr_der: bytes, wait: bool = True
) -> None:
"""
Create a new certificate based on the csr.
Return the certificate object as dict
@@ -212,13 +216,16 @@ class Order:
# Instead of using the result, we call self.refresh(client) below.
if wait:
self.wait_for_finalization(client)
self.wait_for_finalization(client=client)
else:
self.refresh(client)
self.refresh(client=client)
if self.status not in ["procesing", "valid", "invalid"]:
raise ACMEProtocolException(
client.module,
f'Failed to finalize order; got status "{self.status}"',
module=client.module,
msg=f'Failed to finalize order; got status "{self.status}"',
info=info,
content_json=result,
)
__all__ = ("Order",)

View File

@@ -48,7 +48,7 @@ def der_to_pem(der_cert: bytes) -> str:
def pem_to_der(
pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
*, pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
) -> bytes:
"""
Load PEM file, or use PEM file's content, and convert to DER.
@@ -85,7 +85,7 @@ def pem_to_der(
def process_links(
info: dict[str, t.Any], callback: t.Callable[[str, str], None]
*, info: dict[str, t.Any], callback: t.Callable[[str, str], None]
) -> None:
"""
Process link header, calls callback for every link header with the URL and relation as options.
@@ -100,6 +100,7 @@ def process_links(
def parse_retry_after(
value: str,
*,
relative_with_timezone: bool = True,
now: datetime.datetime | None = None,
) -> datetime.datetime:
@@ -112,7 +113,7 @@ def parse_retry_after(
try:
delta = datetime.timedelta(seconds=int(value))
if now is None:
now = get_now_datetime(relative_with_timezone)
now = get_now_datetime(with_timezone=relative_with_timezone)
return now + delta
except ValueError:
pass
@@ -126,6 +127,7 @@ def parse_retry_after(
def compute_cert_id(
*,
backend: CryptoBackend,
cert_info: CertificateInformation | None = None,
cert_filename: str | os.PathLike | None = None,
@@ -159,3 +161,13 @@ def compute_cert_id(
# Compose cert ID
return f"{aki}.{serial}"
__all__ = (
"nopad_b64",
"der_to_pem",
"pem_to_der",
"process_links",
"parse_retry_after",
"compute_cert_id",
)

View File

@@ -25,6 +25,7 @@ class ArgumentSpec:
def __init__(
self,
argument_spec: dict[str, t.Any] | None = None,
*,
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
@@ -50,6 +51,7 @@ class ArgumentSpec:
def update(
self,
*,
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of: list[list[str] | tuple[str, ...]] | None = None,

View File

@@ -93,7 +93,12 @@ def serialize_asn1_string_as_der(value: str) -> bytes:
# We should only do a universal type tag if not IMPLICITLY tagged or the tag class is not universal.
if not tag_type or (tag_type == "EXPLICIT" and tag_class != "U"):
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
b_value = pack_asn1(
tag_class=TagClass.universal,
constructed=False,
tag_number=TagNumber.utf8_string,
b_data=b_value,
)
if tag_type:
tag_class_enum = {
@@ -105,13 +110,22 @@ def serialize_asn1_string_as_der(value: str) -> bytes:
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
b_value = pack_asn1(
tag_class=tag_class_enum,
constructed=constructed,
tag_number=int(tag_number),
b_data=b_value,
)
return b_value
def pack_asn1(
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
*,
tag_class: TagClass,
constructed: bool,
tag_number: TagNumber | int,
b_data: bytes,
) -> bytes:
"""Pack the value into an ASN.1 data structure.
@@ -159,3 +173,6 @@ def pack_asn1(
b_asn1_data.extend(length_octets)
return bytes(b_asn1_data) + b_data
__all__ = ("TagClass", "TagNumber", "serialize_asn1_string_as_der", "pack_asn1")

View File

@@ -58,3 +58,6 @@ def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
buf = openssl_ffi.new("char[]", buf_len)
res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
return openssl_ffi.buffer(buf, res)[:].decode()
__all__ = ("obj2txt",)

View File

@@ -33,3 +33,6 @@ for alias, original in [("userID", "userId")]:
NORMALIZE_NAMES[alias] = original
NORMALIZE_NAMES_SHORT[alias] = NORMALIZE_NAMES_SHORT[original]
OID_LOOKUP[alias] = OID_LOOKUP[original]
__all__ = ("OID_LOOKUP", "NORMALIZE_NAMES", "NORMALIZE_NAMES_SHORT")

View File

@@ -1175,3 +1175,6 @@ OID_MAP = {
"2.23.43.1.4.11": ("wap-wsg-idm-ecid-wtls11",),
"2.23.43.1.4.12": ("wap-wsg-idm-ecid-wtls12",),
}
__all__ = ("OID_MAP",)

View File

@@ -24,3 +24,6 @@ class OpenSSLObjectError(Exception):
class OpenSSLBadPassphraseError(OpenSSLObjectError):
pass
__all__ = ("HAS_CRYPTOGRAPHY", "OpenSSLObjectError", "OpenSSLBadPassphraseError")

View File

@@ -107,6 +107,7 @@ def cryptography_decode_revoked_certificate(
def cryptography_dump_revoked(
entry: dict[str, t.Any],
*,
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> dict[str, t.Any]:
return {
@@ -174,18 +175,34 @@ def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
def set_next_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
builder: x509.CertificateRevocationListBuilder, *, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.next_update(value)
def set_last_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
builder: x509.CertificateRevocationListBuilder, *, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.last_update(value)
def set_revocation_date(
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
builder: x509.RevokedCertificateBuilder, *, value: datetime.datetime
) -> x509.RevokedCertificateBuilder:
return builder.revocation_date(value)
__all__ = (
"REVOCATION_REASON_MAP",
"REVOCATION_REASON_MAP_INVERSE",
"cryptography_decode_revoked_certificate",
"cryptography_dump_revoked",
"cryptography_get_signature_algorithm_oid_from_crl",
"get_next_update",
"get_last_update",
"get_revocation_date",
"get_invalidity_date",
"set_next_update",
"set_last_update",
"set_revocation_date",
)

View File

@@ -263,7 +263,7 @@ def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
def cryptography_oid_to_name(
oid: x509.oid.ObjectIdentifier, short: bool = False
oid: x509.oid.ObjectIdentifier, *, short: bool = False
) -> str:
dotted_string = oid.dotted_string
names = OID_MAP.get(dotted_string)
@@ -315,7 +315,7 @@ def _int_to_byte(value: int) -> bytes:
def _parse_dn_component(
name: bytes, sep: bytes = b",", decode_remainder: bool = True
name: bytes, *, sep: bytes = b",", decode_remainder: bool = True
) -> tuple[x509.NameAttribute, bytes]:
m = DN_COMPONENT_START_RE.match(name)
if not m:
@@ -428,7 +428,9 @@ def _is_ascii(value: str) -> bool:
return False
def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
def _adjust_idn(
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
if idn_rewrite == "ignore" or not value:
return value
if idn_rewrite == "idna" and _is_ascii(value):
@@ -472,19 +474,19 @@ def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"])
def _adjust_idn_email(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
idx = value.find("@")
if idx < 0:
return value
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite=idn_rewrite)}"
def _adjust_idn_url(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
url = urlparse(value)
host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
host = _adjust_idn(url.hostname, idn_rewrite=idn_rewrite) if url.hostname else None
if url.username is not None and url.password is not None:
host = f"{url.username}:{url.password}@{host}"
elif url.username is not None:
@@ -504,7 +506,7 @@ def _adjust_idn_url(
def cryptography_get_name(
name: str, what: str = "Subject Alternative Name"
name: str, *, what: str = "Subject Alternative Name"
) -> x509.GeneralName:
"""
Given a name string, returns a cryptography x509.GeneralName object.
@@ -512,17 +514,19 @@ def cryptography_get_name(
"""
try:
if name.startswith("DNS:"):
return x509.DNSName(_adjust_idn(to_text(name[4:]), "idna"))
return x509.DNSName(_adjust_idn(to_text(name[4:]), idn_rewrite="idna"))
if name.startswith("IP:"):
address = to_text(name[3:])
if "/" in address:
return x509.IPAddress(ipaddress.ip_network(address))
return x509.IPAddress(ipaddress.ip_address(address))
if name.startswith("email:"):
return x509.RFC822Name(_adjust_idn_email(to_text(name[6:]), "idna"))
return x509.RFC822Name(
_adjust_idn_email(to_text(name[6:]), idn_rewrite="idna")
)
if name.startswith("URI:"):
return x509.UniformResourceIdentifier(
_adjust_idn_url(to_text(name[4:]), "idna")
_adjust_idn_url(to_text(name[4:]), idn_rewrite="idna")
)
if name.startswith("RID:"):
m = re.match(r"^([0-9]+(?:\.[0-9]+)*)$", to_text(name[4:]))
@@ -585,6 +589,7 @@ def _dn_escape_value(value: str) -> str:
def cryptography_decode_name(
name: x509.GeneralName,
*,
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> str:
"""
@@ -596,15 +601,15 @@ def cryptography_decode_name(
'idn_rewrite must be one of "ignore", "idna", or "unicode"'
)
if isinstance(name, x509.DNSName):
return f"DNS:{_adjust_idn(name.value, idn_rewrite)}"
return f"DNS:{_adjust_idn(name.value, idn_rewrite=idn_rewrite)}"
if isinstance(name, x509.IPAddress):
if isinstance(name.value, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
return f"IP:{name.value.network_address.compressed}/{name.value.prefixlen}"
return f"IP:{name.value.compressed}"
if isinstance(name, x509.RFC822Name):
return f"email:{_adjust_idn_email(name.value, idn_rewrite)}"
return f"email:{_adjust_idn_email(name.value, idn_rewrite=idn_rewrite)}"
if isinstance(name, x509.UniformResourceIdentifier):
return f"URI:{_adjust_idn_url(name.value, idn_rewrite)}"
return f"URI:{_adjust_idn_url(name.value, idn_rewrite=idn_rewrite)}"
if isinstance(name, x509.DirectoryName):
# According to https://datatracker.ietf.org/doc/html/rfc4514.html#section-2.1 the
# list needs to be reversed, and joined by commas
@@ -718,7 +723,7 @@ def cryptography_key_needs_digest_for_signing(
def _compare_public_keys(
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
key1: PublicKeyTypes, key2: PublicKeyTypes, *, clazz: type[PublicKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz)
b = isinstance(key2, clazz)
@@ -745,24 +750,24 @@ def cryptography_compare_public_keys(
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
clazz=cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
)
if res is not None:
return res
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
clazz=cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
)
if res is not None:
return res
res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
)
if res is not None:
return res
res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
)
if res is not None:
return res
@@ -773,7 +778,7 @@ def cryptography_compare_public_keys(
def _compare_private_keys(
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
key1: PrivateKeyTypes, key2: PrivateKeyTypes, *, clazz: type[PrivateKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz)
b = isinstance(key2, clazz)
@@ -805,24 +810,26 @@ def cryptography_compare_private_keys(
res = _compare_private_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
clazz=cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
)
if res is not None:
return res
res = _compare_private_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
clazz=cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
)
if res is not None:
return res
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey
key1,
key2,
clazz=cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
)
if res is not None:
return res
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
)
if res is not None:
return res
@@ -832,7 +839,9 @@ def cryptography_compare_private_keys(
)
def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
def parse_pkcs12(
pkcs12_bytes: bytes, *, passphrase: bytes | str | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -845,15 +854,17 @@ def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) ->
# Main code for cryptography 36.0.0 and forward
if _load_pkcs12 is not None:
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase=passphrase_bytes)
def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
def _parse_pkcs12_36_0_0(
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -871,7 +882,9 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
def _parse_pkcs12_35_0_0(
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -918,7 +931,9 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
def _parse_pkcs12_legacy(
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
@@ -940,6 +955,7 @@ def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
def cryptography_verify_signature(
*,
signature: bytes,
data: bytes,
hash_algorithm: hashes.HashAlgorithm | None,
@@ -999,16 +1015,16 @@ def cryptography_verify_signature(
def cryptography_verify_certificate_signature(
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
*, certificate: x509.Certificate, signer_public_key: PublicKeyTypes
) -> bool:
"""
Check whether the given X509 certificate object was signed by the given public key object.
"""
return cryptography_verify_signature(
certificate.signature,
certificate.tbs_certificate_bytes,
certificate.signature_hash_algorithm,
signer_public_key,
signature=certificate.signature,
data=certificate.tbs_certificate_bytes,
hash_algorithm=certificate.signature_hash_algorithm,
signer_public_key=signer_public_key,
)
@@ -1074,3 +1090,31 @@ def is_potential_certificate_issuer_public_key(
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
),
)
__all__ = (
"CRYPTOGRAPHY_TIMEZONE",
"cryptography_get_extensions_from_cert",
"cryptography_get_extensions_from_csr",
"cryptography_name_to_oid",
"cryptography_oid_to_name",
"cryptography_parse_relative_distinguished_name",
"cryptography_get_name",
"cryptography_decode_name",
"cryptography_parse_key_usage_params",
"cryptography_get_basic_constraints",
"cryptography_key_needs_digest_for_signing",
"cryptography_compare_public_keys",
"cryptography_compare_private_keys",
"parse_pkcs12",
"cryptography_verify_signature",
"cryptography_verify_certificate_signature",
"get_not_valid_after",
"get_not_valid_before",
"set_not_valid_after",
"set_not_valid_before",
"is_potential_certificate_private_key",
"is_potential_certificate_issuer_private_key",
"is_potential_certificate_public_key",
"is_potential_certificate_issuer_public_key",
)

View File

@@ -8,7 +8,7 @@
from __future__ import annotations
def binary_exp_mod(f: int, e: int, m: int) -> int:
def binary_exp_mod(f: int, e: int, *, m: int) -> int:
"""Computes f^e mod m in O(log e) multiplications modulo m."""
# Compute len_e = floor(log_2(e))
len_e = -1
@@ -120,7 +120,7 @@ def count_bits(no: int) -> int:
return no.bit_length()
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
def convert_int_to_bytes(no: int, *, count: int | None = None) -> bytes:
"""
Convert the absolute value of an integer to a byte string in network byte order.
@@ -136,7 +136,7 @@ def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
return no.to_bytes(count, byteorder="big")
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
def convert_int_to_hex(no: int, *, digits: int | None = None) -> str:
"""
Convert the absolute value of an integer to a string of hexadecimal digits.
@@ -156,3 +156,15 @@ def convert_bytes_to_int(data: bytes) -> int:
Convert a byte string to an unsigned integer in network byte order.
"""
return int.from_bytes(data, byteorder="big", signed=False)
__all__ = (
"binary_exp_mod",
"simple_gcd",
"quick_is_not_prime",
"count_bytes",
"count_bits",
"convert_int_to_bytes",
"convert_int_to_hex",
"convert_bytes_to_int",
)

View File

@@ -63,7 +63,7 @@ class CertificateError(OpenSSLObjectError):
class CertificateBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.force: bool = module.params["force"]
@@ -104,7 +104,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return {}
try:
result = get_certificate_info(
self.module, data, prefer_one_fingerprint=True
module=self.module, content=data, prefer_one_fingerprint=True
)
result["can_parse_certificate"] = True
return result
@@ -289,6 +289,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -330,7 +331,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
@@ -372,7 +373,7 @@ class CertificateProvider(metaclass=abc.ABCMeta):
def select_backend(
module: AnsibleModule, provider: CertificateProvider
*, module: AnsibleModule, provider: CertificateProvider
) -> CertificateBackend:
provider.validate_module_args(module)
@@ -415,3 +416,11 @@ def get_certificate_argument_spec() -> ArgumentSpec:
["privatekey_path", "privatekey_content"],
],
)
__all__ = (
"CertificateError",
"CertificateBackend",
"CertificateProvider",
"get_certificate_argument_spec",
)

View File

@@ -29,8 +29,8 @@ if t.TYPE_CHECKING:
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module=module)
self.accountkey_path: str = module.params["acme_accountkey_path"]
self.challenge_path: str = module.params["acme_challenge_path"]
self.use_chain: bool = module.params["acme_chain"]
@@ -102,8 +102,10 @@ class AcmeCertificateBackend(CertificateBackend):
raise AssertionError("Contract violation: cert_bytes is None")
return self.cert_bytes
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(include_certificate)
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(
include_certificate=include_certificate
)
result["accountkey"] = self.accountkey_path
return result
@@ -123,7 +125,7 @@ class AcmeCertificateProvider(CertificateProvider):
return False
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
return AcmeCertificateBackend(module)
return AcmeCertificateBackend(module=module)
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -138,3 +140,10 @@ def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
),
)
)
__all__ = (
"AcmeCertificateBackend",
"AcmeCertificateProvider",
"add_acme_provider_to_argument_spec",
)

View File

@@ -50,12 +50,12 @@ except ImportError:
class EntrustCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module=module)
self.trackingId = None
self.notAfter = get_relative_time_option(
module.params["entrust_not_after"],
"entrust_not_after",
input_name="entrust_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
@@ -159,6 +159,7 @@ class EntrustCertificateBackend(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -229,7 +230,7 @@ class EntrustCertificateProvider(CertificateProvider):
return False
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
return EntrustCertificateBackend(module)
return EntrustCertificateBackend(module=module)
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -281,3 +282,10 @@ def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
],
)
)
__all__ = (
"EntrustCertificateBackend",
"EntrustCertificateProvider",
"add_entrust_provider_to_argument_spec",
)

View File

@@ -70,7 +70,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@@ -158,11 +158,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
pass
def get_info(
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
None,
content=self.content,
der_support_enabled=der_support_enabled,
)
@@ -204,7 +203,7 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
@@ -249,8 +248,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(
module=module, content=content
)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self) -> bytes:
@@ -465,16 +466,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def get_certificate_info(
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
*,
module: GeneralAnsibleModule,
content: bytes,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module, content)
info = CertificateInfoRetrievalCryptography(module=module, content=content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes
*, module: GeneralAnsibleModule, content: bytes
) -> CertificateInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateInfoRetrievalCryptography(module, content)
return CertificateInfoRetrievalCryptography(module=module, content=content)
__all__ = ("CertificateInfoRetrieval", "get_certificate_info", "select_backend")

View File

@@ -62,8 +62,8 @@ except ImportError:
class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module=module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
@@ -73,12 +73,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
]
self.notBefore = get_relative_time_option(
module.params["ownca_not_before"],
"ownca_not_before",
input_name="ownca_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["ownca_not_after"],
"ownca_not_after",
input_name="ownca_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["ownca_digest"])
@@ -220,6 +220,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -233,7 +234,8 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
# Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.ca_cert.public_key()
certificate=self.existing_certificate,
signer_public_key=self.ca_cert.public_key(),
):
return True
@@ -270,9 +272,9 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate
include_certificate=include_certificate
)
result.update(
{
@@ -339,7 +341,7 @@ class OwnCACertificateProvider(CertificateProvider):
def create_backend(
self, module: AnsibleModule
) -> OwnCACertificateBackendCryptography:
return OwnCACertificateBackendCryptography(module)
return OwnCACertificateBackendCryptography(module=module)
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -369,3 +371,10 @@ def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
["ownca_privatekey_path", "ownca_privatekey_content"],
]
)
__all__ = (
"OwnCACertificateBackendCryptography",
"OwnCACertificateProvider",
"add_ownca_provider_to_argument_spec",
)

View File

@@ -58,20 +58,20 @@ except ImportError:
class SelfSignedCertificateBackendCryptography(CertificateBackend):
privatekey: CertificateIssuerPrivateKeyTypes
def __init__(self, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module=module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["selfsigned_create_subject_key_identifier"]
self.notBefore = get_relative_time_option(
module.params["selfsigned_not_before"],
"selfsigned_not_before",
input_name="selfsigned_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["selfsigned_not_after"],
"selfsigned_not_after",
input_name="selfsigned_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["selfsigned_digest"])
@@ -162,6 +162,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
def needs_regeneration(
self,
*,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
@@ -177,15 +178,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
# Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.privatekey.public_key()
certificate=self.existing_certificate,
signer_public_key=self.privatekey.public_key(),
):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate
include_certificate=include_certificate
)
if self.module.check_mode:
@@ -239,7 +241,7 @@ class SelfSignedCertificateProvider(CertificateProvider):
def create_backend(
self, module: AnsibleModule
) -> SelfSignedCertificateBackendCryptography:
return SelfSignedCertificateBackendCryptography(module)
return SelfSignedCertificateBackendCryptography(module=module)
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
@@ -261,3 +263,10 @@ def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> Non
),
)
)
__all__ = (
"SelfSignedCertificateBackendCryptography",
"SelfSignedCertificateProvider",
"add_selfsigned_provider_to_argument_spec",
)

View File

@@ -55,6 +55,7 @@ except ImportError:
class CRLInfoRetrieval:
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
@@ -113,12 +114,20 @@ class CRLInfoRetrieval:
def get_crl_info(
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
*,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
) -> dict[str, t.Any]:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
info = CRLInfoRetrieval(
module, content, list_revoked_certificates=list_revoked_certificates
module=module,
content=content,
list_revoked_certificates=list_revoked_certificates,
)
return info.get_info()
__all__ = ("CRLInfoRetrieval", "get_crl_info")

View File

@@ -87,7 +87,7 @@ class CertificateSigningRequestError(OpenSSLObjectError):
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.digest: str = module.params["digest"]
self.privatekey_path: str | None = module.params["privatekey_path"]
@@ -158,7 +158,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
try:
if module.params["subject"]:
self.subject = self.subject + parse_name_field(
module.params["subject"], "subject"
module.params["subject"], name_field_name="subject"
)
if module.params["subject_ordered"]:
if self.subject:
@@ -166,7 +166,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
"subject_ordered cannot be combined with any other subject field"
)
self.subject = parse_ordered_name_field(
module.params["subject_ordered"], "subject_ordered"
module.params["subject_ordered"], name_field_name="subject_ordered"
)
self.ordered_subject = True
except ValueError as exc:
@@ -205,16 +205,16 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
self.existing_csr_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
try:
result = get_csr_info(
self.module,
data,
module=self.module,
content=data,
validate_signature=False,
prefer_one_fingerprint=True,
)
@@ -231,10 +231,12 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
def set_existing(self, csr_bytes: bytes | None) -> None:
def set_existing(self, *, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
self.diff_after = self.diff_before = self._get_info(
data=self.existing_csr_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
@@ -263,7 +265,6 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return True
try:
self.existing_csr = load_certificate_request(
None,
content=self.existing_csr_bytes,
)
except Exception:
@@ -271,7 +272,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, include_csr: bool) -> dict[str, t.Any]:
def dump(self, *, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
@@ -288,7 +289,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
csr_bytes = self.existing_csr_bytes
if self.csr is not None:
csr_bytes = self.get_csr_data()
self.diff_after = self._get_info(csr_bytes)
self.diff_after = self._get_info(data=csr_bytes)
if include_csr:
# Store result
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
@@ -301,7 +302,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def parse_crl_distribution_points(
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
*, module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
@@ -314,7 +315,7 @@ def parse_crl_distribution_points(
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
full_name = [
cryptography_get_name(name, "full name")
cryptography_get_name(name, what="full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
@@ -327,7 +328,7 @@ def parse_crl_distribution_points(
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
crl_issuer = [
cryptography_get_name(name, "CRL issuer")
cryptography_get_name(name, what="CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
@@ -352,8 +353,10 @@ def parse_crl_distribution_points(
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(
module=module
)
if self.version != 1:
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
@@ -364,7 +367,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module, crl_distribution_points
module=module, crl_distribution_points=crl_distribution_points
)
def generate_csr(self) -> None:
@@ -431,12 +434,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr = csr.add_extension(
cryptography.x509.NameConstraints(
[
cryptography_get_name(name, "name constraints permitted")
cryptography_get_name(
name, what="name constraints permitted"
)
for name in self.name_constraints_permitted
]
or None,
[
cryptography_get_name(name, "name constraints excluded")
cryptography_get_name(
name, what="name constraints excluded"
)
for name in self.name_constraints_excluded
]
or None,
@@ -473,7 +480,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
issuers = None
if self.authority_cert_issuer is not None:
issuers = [
cryptography_get_name(n, "authority cert issuer")
cryptography_get_name(n, what="authority cert issuer")
for n in self.authority_cert_issuer
]
csr = csr.add_extension(
@@ -679,11 +686,15 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else []
)
nc_perm = [
to_text(cryptography_get_name(altname, "name constraints permitted"))
to_text(
cryptography_get_name(altname, what="name constraints permitted")
)
for altname in self.name_constraints_permitted
]
nc_excl = [
to_text(cryptography_get_name(altname, "name constraints excluded"))
to_text(
cryptography_get_name(altname, what="name constraints excluded")
)
for altname in self.name_constraints_excluded
]
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(
@@ -731,7 +742,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr_aci = None
if self.authority_cert_issuer is not None:
aci = [
to_text(cryptography_get_name(n, "authority cert issuer"))
to_text(cryptography_get_name(n, what="authority cert issuer"))
for n in self.authority_cert_issuer
]
if ext.value.authority_cert_issuer is not None:
@@ -798,7 +809,7 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateSigningRequestCryptographyBackend(module)
return CertificateSigningRequestCryptographyBackend(module=module)
def get_csr_argument_spec() -> ArgumentSpec:
@@ -903,3 +914,11 @@ def get_csr_argument_spec() -> ArgumentSpec:
["privatekey_path", "privatekey_content"],
],
)
__all__ = (
"CertificateSigningRequestError",
"CertificateSigningRequestBackend",
"select_backend",
"get_csr_argument_spec",
)

View File

@@ -62,7 +62,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
self.module = module
self.content = content
@@ -122,10 +122,9 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def _is_signature_valid(self) -> bool:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
None,
content=self.content,
)
@@ -156,7 +155,7 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
module=self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
@@ -196,10 +195,10 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super(CSRInfoRetrievalCryptography, self).__init__(
module, content, validate_signature
module=module, content=content, validate_signature=validate_signature
)
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
@@ -372,23 +371,27 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def get_csr_info(
*,
module: GeneralAnsibleModule,
content: bytes,
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
module=module, content=content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
*, module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
) -> CSRInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
module=module, content=content, validate_signature=validate_signature
)
__all__ = ("CSRInfoRetrieval", "get_csr_info", "select_backend")

View File

@@ -80,7 +80,7 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyBackend(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule) -> None:
def __init__(self, *, module: GeneralAnsibleModule) -> None:
self.module = module
self.type: t.Literal[
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
@@ -104,18 +104,18 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self.existing_private_key: PrivateKeyTypes | None = None
self.existing_private_key_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
self.diff_before = self._get_info(data=None)
self.diff_after = self._get_info(data=None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
result: dict[str, t.Any] = {"can_parse_key": False}
try:
result.update(
get_privatekey_info(
self.module,
data,
module=self.module,
content=data,
passphrase=self.passphrase,
return_private_key_data=False,
prefer_one_fingerprint=True,
@@ -148,11 +148,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key."""
def set_existing(self, privatekey_bytes: bytes | None) -> None:
def set_existing(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes
data=self.existing_private_key_bytes
)
def has_existing(self) -> bool:
@@ -235,7 +235,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return get_fingerprint_of_privatekey(self.existing_private_key)
return None
def dump(self, include_key: bool) -> dict[str, t.Any]:
def dump(self, *, include_key: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
if not self.private_key:
@@ -255,7 +255,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
pk_bytes = self.existing_private_key_bytes
if self.private_key is not None:
pk_bytes = self.get_private_key_data()
self.diff_after = self._get_info(pk_bytes)
self.diff_after = self._get_info(data=pk_bytes)
if include_key:
# Store result
if pk_bytes:
@@ -276,6 +276,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
class _Curve:
def __init__(
self,
*,
name: str,
ectype: str,
deprecated: bool,
@@ -285,7 +286,7 @@ class _Curve:
self.deprecated = deprecated
def _get_ec_class(
self, module: GeneralAnsibleModule
self, *, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
if ecclass is None:
@@ -295,17 +296,18 @@ class _Curve:
return ecclass
def create(
self, size: int, module: GeneralAnsibleModule
self, *, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module)
ecclass = self._get_ec_class(module=module)
return ecclass()
def verify(
self,
*,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module)
ecclass = self._get_ec_class(module=module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
@@ -316,6 +318,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
self,
name: str,
ectype: str,
*,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
@@ -575,7 +578,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
if self.curve not in self.curves:
return False
return self.curves[self.curve].verify(
self.existing_private_key, module=self.module
privatekey=self.existing_private_key, module=self.module
)
return False
@@ -596,7 +599,7 @@ def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyCryptographyBackend(module)
return PrivateKeyCryptographyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:
@@ -661,3 +664,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
("type", "ECC", ["curve"]),
],
)
__all__ = (
"PrivateKeyError",
"PrivateKeyBackend",
"select_backend",
"get_privatekey_argument_spec",
)

View File

@@ -69,7 +69,7 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
@@ -79,7 +79,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
self.src_private_key: PrivateKeyTypes | None = None
if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module)
self.src_private_key_bytes = load_file(path=self.src_path, module=module)
else:
if self.src_content is None:
raise AssertionError("src_content is None")
@@ -93,7 +93,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
"""Return bytes for self.src_private_key in output format."""
pass
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
def set_existing_destination(self, *, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
@@ -104,6 +104,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
@abc.abstractmethod
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
@@ -113,7 +114,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase
data=self.src_private_key_bytes, passphrase=self.src_passphrase
)
if not self.has_existing_destination():
@@ -122,8 +123,8 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
try:
format, self.dest_private_key = self._load_private_key(
self.dest_private_key_bytes,
self.dest_passphrase,
data=self.dest_private_key_bytes,
passphrase=self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
@@ -140,7 +141,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
def get_private_key_data(self) -> bytes:
@@ -201,6 +202,7 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def _load_private_key(
self,
*,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
@@ -276,7 +278,7 @@ def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyConvertCryptographyBackend(module)
return PrivateKeyConvertCryptographyBackend(module=module)
def get_privatekey_argument_spec() -> ArgumentSpec:
@@ -295,3 +297,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
["src_path", "src_content"],
],
)
__all__ = (
"PrivateKeyError",
"PrivateKeyConvertBackend",
"select_backend",
"get_privatekey_argument_spec",
)

View File

@@ -60,7 +60,7 @@ SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
key: PrivateKeyTypes, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data: dict[str, t.Any] = {}
@@ -84,7 +84,7 @@ def _get_cryptography_private_key_info(
def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
*, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters
p: int | None = key_public_data.get("p")
@@ -112,10 +112,10 @@ def _check_dsa_consistency(
if (p - 1) % q != 0:
return False
# Check that g**q mod p == 1
if binary_exp_mod(g, q, p) != 1:
if binary_exp_mod(g, q, m=p) != 1:
return False
# Check whether g**x mod p == y
if binary_exp_mod(g, x, p) != y:
if binary_exp_mod(g, x, m=p) != y:
return False
# Check (quickly) whether p or q are not primes
if quick_is_not_prime(q) or quick_is_not_prime(p):
@@ -125,6 +125,7 @@ def _check_dsa_consistency(
def _is_cryptography_key_consistent(
key: PrivateKeyTypes,
*,
key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
@@ -135,7 +136,9 @@ def _is_cryptography_key_consistent(
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data)
result = _check_dsa_consistency(
key_public_data=key_public_data, key_private_data=key_private_data
)
if result is not None:
return result
signature = key.sign(
@@ -191,14 +194,14 @@ def _is_cryptography_key_consistent(
class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -207,6 +210,7 @@ class PrivateKeyParseError(OpenSSLObjectError):
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -220,22 +224,22 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
@abc.abstractmethod
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"can_parse_key": False,
"key_is_consistent": None,
@@ -253,7 +257,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
)
result["can_parse_key"] = True
except OpenSSLObjectError as exc:
raise PrivateKeyParseError(str(exc), result)
raise PrivateKeyParseError(str(exc), result=result)
result["public_key"] = to_native(self._get_public_key(binary=False))
pk = self._get_public_key(binary=True)
@@ -273,7 +277,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
if self.check_consistency:
result["key_is_consistent"] = self._is_key_consistent(
key_public_data, key_private_data
key_public_data=key_public_data, key_private_data=key_private_data
)
if result["key_is_consistent"] is False:
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
@@ -281,40 +285,46 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
"Private key is not consistent! (See "
"https://blog.hboeck.de/archives/888-How-I-tricked-Symantec-with-a-Fake-Private-Key.html)"
)
raise PrivateKeyConsistencyError(msg, result)
raise PrivateKeyConsistencyError(msg, result=result)
return result
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
def __init__(
self, *, module: GeneralAnsibleModule, content: bytes, **kwargs
) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs
module=module, content=content, **kwargs
)
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
self.key,
key_public_data=key_public_data,
key_private_data=key_private_data,
warn_func=self.module.warn,
)
def get_privatekey_info(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -322,8 +332,8 @@ def get_privatekey_info(
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
)
@@ -331,6 +341,7 @@ def get_privatekey_info(
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -341,9 +352,18 @@ def select_backend(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
check_consistency=check_consistency,
)
__all__ = (
"PrivateKeyConsistencyError",
"PrivateKeyParseError",
"PrivateKeyInfoRetrieval",
"get_privatekey_info",
"select_backend",
)

View File

@@ -99,7 +99,7 @@ def _get_cryptography_public_key_info(
class PublicKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PublicKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -108,6 +108,7 @@ class PublicKeyParseError(OpenSSLObjectError):
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
@@ -125,13 +126,13 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if self.key is None:
try:
self.key = load_publickey(content=self.content)
except OpenSSLObjectError as e:
raise PublicKeyParseError(str(e), {})
raise PublicKeyParseError(str(e), result={})
pk = self._get_public_key(binary=True)
result["fingerprints"] = (
@@ -151,12 +152,13 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, content=content, key=key
module=module, content=content, key=key
)
def _get_public_key(self, binary: bool) -> bytes:
@@ -174,16 +176,18 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
def get_publickey_info(
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
info = PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
@@ -191,4 +195,12 @@ def select_backend(
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
__all__ = (
"PublicKeyParseError",
"PublicKeyInfoRetrieval",
"get_publickey_info",
"select_backend",
)

View File

@@ -17,7 +17,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
def identify_pem_format(content: bytes, *, encoding: str = "utf-8") -> bool:
"""Given the contents of a binary file, tests whether this could be a PEM file."""
try:
first_pem = extract_first_pem(content.decode(encoding))
@@ -36,7 +36,7 @@ def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
def identify_private_key_format(
content: bytes, encoding: str = "utf-8"
content: bytes, *, encoding: str = "utf-8"
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
"""Given the contents of a private key file, identifies its format."""
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
@@ -66,7 +66,7 @@ def identify_private_key_format(
return "raw"
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
def split_pem_list(text: str, *, keep_inbetween: bool = False) -> list[str]:
"""
Split concatenated PEM objects into a list of strings, where each is one PEM object.
"""
@@ -94,7 +94,7 @@ def extract_first_pem(text: str) -> str | None:
return all_pems[0]
def _extract_type(line: str, start: str = PEM_START) -> str | None:
def _extract_type(line: str, *, start: str = PEM_START) -> str | None:
if not line.startswith(start):
return None
if not line.endswith(PEM_END):
@@ -102,7 +102,7 @@ def _extract_type(line: str, start: str = PEM_START) -> str | None:
return line[len(start) : -len(PEM_END)]
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
def extract_pem(content: str, *, strict: bool = False) -> tuple[str, str]:
lines = content.splitlines()
if len(lines) < 3:
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
@@ -125,3 +125,17 @@ def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
)
return header_type, "".join(lines[1:-1])
__all__ = (
"PEM_START",
"PEM_END_START",
"PEM_END",
"PKCS8_PRIVATEKEY_NAMES",
"PKCS1_PRIVATEKEY_SUFFIX",
"identify_pem_format",
"identify_private_key_format",
"split_pem_list",
"extract_first_pem",
"extract_pem",
)

View File

@@ -63,7 +63,9 @@ PREFERRED_FINGERPRINTS = (
)
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
def get_fingerprint_of_bytes(
source: bytes, *, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the given bytes."""
fingerprint = {}
@@ -107,7 +109,7 @@ def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[st
def get_fingerprint_of_privatekey(
privatekey: PrivateKeyTypes, prefer_one: bool = False
privatekey: PrivateKeyTypes, *, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the public key."""
@@ -119,6 +121,7 @@ def get_fingerprint_of_privatekey(
def get_fingerprint(
*,
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
content: bytes | None = None,
@@ -137,6 +140,7 @@ def get_fingerprint(
def load_privatekey(
*,
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
@@ -219,7 +223,7 @@ def load_certificate_issuer_privatekey(
def load_publickey(
path: os.PathLike | str | None = None, content: bytes | None = None
*, path: os.PathLike | str | None = None, content: bytes | None = None
) -> PublicKeyTypes:
if content is None:
if path is None:
@@ -237,6 +241,7 @@ def load_publickey(
def load_certificate(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
der_support_enabled: bool = False,
@@ -266,7 +271,7 @@ def load_certificate(
def load_certificate_request(
path: os.PathLike | str | None = None, content: bytes | None = None
*, path: os.PathLike | str | None = None, content: bytes | None = None
) -> x509.CertificateSigningRequest:
"""Load the specified certificate signing request."""
try:
@@ -287,6 +292,7 @@ def load_certificate_request(
def parse_name_field(
input_dict: dict[str, list[str | bytes] | str | bytes],
*,
name_field_name: str | None = None,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
@@ -321,7 +327,9 @@ def parse_name_field(
def parse_ordered_name_field(
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
input_list: list[dict[str, list[str | bytes] | str | bytes]],
*,
name_field_name: str,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
@@ -372,7 +380,7 @@ def select_message_digest(
class OpenSSLObject(metaclass=abc.ABCMeta):
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
def __init__(self, *, path: str, state: str, force: bool, check_mode: bool) -> None:
self.path = path
self.state = state
self.force = force
@@ -380,7 +388,7 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
self.changed = False
self.check_mode = check_mode
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
def check(self, module: AnsibleModule, *, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
def _check_state() -> bool:
@@ -420,3 +428,20 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
raise OpenSSLObjectError(exc)
else:
pass
__all__ = (
"get_fingerprint_of_bytes",
"get_fingerprint_of_privatekey",
"get_fingerprint",
"load_privatekey",
"load_certificate_privatekey",
"load_certificate_issuer_privatekey",
"load_publickey",
"load_certificate",
"load_certificate_request",
"parse_name_field",
"parse_ordered_name_field",
"select_message_digest",
"OpenSSLObject",
)

View File

@@ -385,3 +385,11 @@ def ECSClient(
entrust_api_cert_key=entrust_api_cert_key,
entrust_api_specification_path=entrust_api_specification_path,
).client()
__all__ = (
"ecs_client_argument_spec",
"SessionConfigurationException",
"RestOperationException",
"ECSClient",
)

View File

@@ -18,7 +18,7 @@ class GPGError(Exception):
class GPGRunner(metaclass=abc.ABCMeta):
@abc.abstractmethod
def run_command(
self, command: list[str], check_rc: bool = True, data: bytes | None = None
self, command: list[str], *, check_rc: bool = True, data: bytes | None = None
) -> tuple[int, str, str]:
"""
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
@@ -34,7 +34,7 @@ class GPGRunner(metaclass=abc.ABCMeta):
pass
def get_fingerprint_from_stdout(stdout: str) -> str:
def get_fingerprint_from_stdout(*, stdout: str) -> str:
lines = stdout.splitlines(False)
for line in lines:
if line.startswith("fpr:"):
@@ -47,7 +47,7 @@ def get_fingerprint_from_stdout(stdout: str) -> str:
raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"')
def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
def get_fingerprint_from_file(*, gpg_runner: GPGRunner, path: str) -> str:
if not os.path.exists(path):
raise GPGError(f"{path} does not exist")
stdout = gpg_runner.run_command(
@@ -61,10 +61,10 @@ def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
],
check_rc=True,
)[1]
return get_fingerprint_from_stdout(stdout)
return get_fingerprint_from_stdout(stdout=stdout)
def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
def get_fingerprint_from_bytes(*, gpg_runner: GPGRunner, content: bytes) -> str:
stdout = gpg_runner.run_command(
[
"--no-keyring",
@@ -77,4 +77,13 @@ def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
data=content,
check_rc=True,
)[1]
return get_fingerprint_from_stdout(stdout)
return get_fingerprint_from_stdout(stdout=stdout)
__all__ = (
"GPGError",
"GPGRunner",
"get_fingerprint_from_stdout",
"get_fingerprint_from_file",
"get_fingerprint_from_bytes",
)

View File

@@ -17,7 +17,7 @@ if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
def load_file(*, path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
"""
Load the file as a bytes string.
"""
@@ -31,6 +31,7 @@ def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> b
def load_file_if_exists(
*,
path: str | os.PathLike,
module: AnsibleModule | None = None,
ignore_errors: bool = False,
@@ -62,6 +63,7 @@ def load_file_if_exists(
def write_file(
*,
module: AnsibleModule,
content: bytes,
default_mode: str | int | None = None,
@@ -117,3 +119,6 @@ def write_file(
except Exception:
pass
module.fail_json(msg=f"Error while writing result: {e}")
__all__ = ("load_file", "load_file_if_exists", "write_file")

View File

@@ -99,7 +99,7 @@ def _restore_all_on_failure(
class OpensshModule(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
self.changed: bool = False
@@ -210,6 +210,7 @@ class KeygenCommand:
def generate_certificate(
self,
*,
certificate_path: str,
identifier: str,
options: list[str] | None,
@@ -247,7 +248,13 @@ class KeygenCommand:
return self._run_command(args, **kwargs)
def generate_keypair(
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
self,
*,
private_key_path: str,
size: int,
type: str,
comment: str | None,
**kwargs,
) -> tuple[int, str, str]:
args = [
self._bin_path,
@@ -270,26 +277,29 @@ class KeygenCommand:
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(
self, certificate_path: str, **kwargs
self, *, certificate_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(
self, private_key_path: str, **kwargs
self, *, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
def get_private_key(
self, *, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
)
def update_comment(
self,
*,
private_key_path: str,
comment: str,
force_new_format: bool = True,
@@ -317,7 +327,7 @@ _PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
class PrivateKey:
def __init__(
self, size: int, key_type: str, fingerprint: str, format: str = ""
self, *, size: int, key_type: str, fingerprint: str, format: str = ""
) -> None:
self._size = size
self._type = key_type
@@ -363,7 +373,7 @@ _PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
class PublicKey:
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
def __init__(self, *, type_string: str, data: str, comment: str | None) -> None:
self._type_string = type_string
self._data = data
self._comment = comment
@@ -441,6 +451,7 @@ class PublicKey:
def parse_private_key_format(
*,
path: str | os.PathLike,
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
with open(path, "r") as file:
@@ -454,3 +465,14 @@ def parse_private_key_format(
return "PKCS1"
return ""
__all__ = (
"restore_on_failure",
"safe_atomic_move",
"OpensshModule",
"KeygenCommand",
"PrivateKey",
"PublicKey",
"parse_private_key_format",
)

View File

@@ -53,8 +53,8 @@ if t.TYPE_CHECKING:
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module=module)
self.comment: str | None = self.module.params["comment"]
self.private_key_path: str = self.module.params["path"]
@@ -296,9 +296,9 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
try:
secure_write(
temp_public_key,
existing_permissions or default_permissions,
to_bytes(content),
path=temp_public_key,
mode=existing_permissions or default_permissions,
content=to_bytes(content),
)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
@@ -357,8 +357,8 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module=module)
if self.module.params["private_key_format"] != "auto":
self.module.fail_json(
@@ -369,12 +369,16 @@ class KeypairBackendOpensshBin(KeypairBackend):
def _generate_keypair(self, private_key_path: str) -> None:
self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True
private_key_path=private_key_path,
size=self.size,
type=self.type,
comment=self.comment,
check_rc=True,
)
def _get_private_key(self) -> PrivateKey:
rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False
private_key_path=self.private_key_path, check_rc=False
)
if rc != 0:
raise ValueError(err)
@@ -382,13 +386,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
def _get_public_key(self) -> PublicKey | t.Literal[""]:
public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True
private_key_path=self.private_key_path, check_rc=True
)[1]
return PublicKey.from_string(public_key_content)
def _private_key_readable(self) -> bool:
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False
private_key_path=self.private_key_path, check_rc=False
)
return not (
rc == 255
@@ -407,8 +411,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
)
self.ssh_keygen.update_comment(
self.private_key_path,
self.comment or "",
private_key_path=self.private_key_path,
comment=self.comment or "",
force_new_format=force_new_format,
check_rc=True,
)
@@ -420,8 +424,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module)
def __init__(self, *, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module=module)
if self.type == "rsa1":
self.module.fail_json(
@@ -469,12 +473,12 @@ class KeypairBackendCryptography(KeypairBackend):
)
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
keypair.asymmetric_keypair, self.private_key_format
asym_keypair=keypair.asymmetric_keypair, key_format=self.private_key_format
)
secure_write(private_key_path, 0o600, encoded_private_key)
secure_write(path=private_key_path, mode=0o600, content=encoded_private_key)
public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key)
secure_write(path=public_key_path, mode=0o644, content=keypair.public_key)
def _get_private_key(self) -> PrivateKey:
keypair = OpensshKeypair.load(
@@ -485,7 +489,7 @@ class KeypairBackendCryptography(KeypairBackend):
size=keypair.size,
key_type=keypair.key_type,
fingerprint=keypair.fingerprint,
format=parse_private_key_format(self.private_key_path),
format=parse_private_key_format(path=self.private_key_path),
)
def _get_public_key(self) -> PublicKey | t.Literal[""]:
@@ -550,7 +554,7 @@ class KeypairBackendCryptography(KeypairBackend):
def select_backend(
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
*, module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
) -> KeypairBackend:
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
CRYPTOGRAPHY_VERSION
@@ -573,7 +577,7 @@ def select_backend(
if backend == "opensshbin":
if not can_use_opensshbin:
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
return KeypairBackendOpensshBin(module)
return KeypairBackendOpensshBin(module=module)
if backend == "cryptography":
if not can_use_cryptography:
module.fail_json(
@@ -581,5 +585,8 @@ def select_backend(
f"cryptography >= {COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION}"
)
)
return KeypairBackendCryptography(module)
return KeypairBackendCryptography(module=module)
raise ValueError(f"Unsupported value for backend: {backend}")
__all__ = ("KeypairBackend", "select_backend")

View File

@@ -111,7 +111,7 @@ _EXTENSIONS = (
class OpensshCertificateTimeParameters:
def __init__(
self, valid_from: str | bytes | int, valid_to: str | bytes | int
self, *, valid_from: str | bytes | int, valid_to: str | bytes | int
) -> None:
self._valid_from = self.to_datetime(valid_from)
self._valid_to = self.to_datetime(valid_to)
@@ -149,7 +149,7 @@ class OpensshCertificateTimeParameters:
def valid_from(self, date_format: DateFormat) -> str | int: ...
def valid_from(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_from, date_format)
return self.format_datetime(self._valid_from, date_format=date_format)
@t.overload
def valid_to(self, date_format: DateFormatStr) -> str: ...
@@ -161,7 +161,7 @@ class OpensshCertificateTimeParameters:
def valid_to(self, date_format: DateFormat) -> str | int: ...
def valid_to(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_to, date_format)
return self.format_datetime(self._valid_to, date_format=date_format)
def within_range(self, valid_at: str | bytes | int | None) -> bool:
if valid_at is not None:
@@ -171,18 +171,18 @@ class OpensshCertificateTimeParameters:
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
def format_datetime(dt: datetime, *, date_format: DateFormatStr) -> str: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
def format_datetime(dt: datetime, *, date_format: DateFormatInt) -> int: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int: ...
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int:
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
return "always"
@@ -264,6 +264,7 @@ _OpensshCertificateOption = t.TypeVar(
class OpensshCertificateOption:
def __init__(
self,
*,
option_type: t.Literal["critical", "extension"],
name: str | bytes,
data: str | bytes,
@@ -350,6 +351,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
def __init__(
self,
*,
nonce: bytes | None = None,
serial: int | None = None,
cert_type: int | None = None,
@@ -409,7 +411,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
def __init__(self, *, e: int | None = None, n: int | None = None, **kwargs) -> None:
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
@@ -435,6 +437,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self,
*,
p: int | None = None,
q: int | None = None,
g: int | None = None,
@@ -471,7 +474,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
):
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
self._curve = None
@@ -515,7 +518,7 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
def __init__(self, *, pk: bytes | None = None, **kwargs) -> None:
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
@@ -541,8 +544,7 @@ _OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate
class OpensshCertificate:
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
def __init__(self, *, cert_info: OpensshCertificateInfo, signature: bytes):
self._cert_info = cert_info
self.signature = signature
@@ -574,7 +576,7 @@ class OpensshCertificate:
f"Invalid certificate format identifier: {format_identifier!r}"
)
parser = OpensshParser(cert)
parser = OpensshParser(data=cert)
if format_identifier != parser.string():
raise ValueError("Certificate formats do not match")
@@ -649,7 +651,9 @@ class OpensshCertificate:
if self._cert_info.critical_options is None:
raise ValueError
return [
OpensshCertificateOption("critical", to_text(n), to_text(d))
OpensshCertificateOption(
option_type="critical", name=to_text(n), data=to_text(d)
)
for n, d in self._cert_info.critical_options
]
@@ -658,7 +662,9 @@ class OpensshCertificate:
if self._cert_info.extensions is None:
raise ValueError
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
OpensshCertificateOption(
option_type="extension", name=to_text(n), data=to_text(d)
)
for n, d in self._cert_info.extensions
]
@@ -674,7 +680,7 @@ class OpensshCertificate:
@property
def signature_type(self) -> str:
signature_data = OpensshParser.signature_data(self.signature)
signature_data = OpensshParser.signature_data(signature_string=self.signature)
return to_text(signature_data["signature_type"])
@staticmethod
@@ -727,16 +733,20 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
directive_to_option = {
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
option_type="extension", name="permit-x11-forwarding", data=""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
option_type="extension", name="permit-agent-forwarding", data=""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
option_type="extension", name="permit-port-forwarding", data=""
),
"no-pty": OpensshCertificateOption(
option_type="extension", name="permit-pty", data=""
),
"no-user-rc": OpensshCertificateOption(
option_type="extension", name="permit-user-rc", data=""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if "clear" in directives:
@@ -748,7 +758,10 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
def default_options() -> list[OpensshCertificateOption]:
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
return [
OpensshCertificateOption(option_type="extension", name=name, data="")
for name in _EXTENSIONS
]
def fingerprint(public_key: bytes) -> bytes:
@@ -803,3 +816,22 @@ def parse_option_list(
extensions.append(option_object)
return critical_options, list(set(extensions + apply_directives(directives)))
__all__ = (
"OpensshCertificateTimeParameters",
"OpensshCertificateOption",
"OpensshCertificateInfo",
"OpensshRSACertificateInfo",
"OpensshDSACertificateInfo",
"OpensshECDSACertificateInfo",
"OpensshED25519CertificateInfo",
"OpensshCertificate",
"apply_directives",
"default_options",
"fingerprint",
"get_cert_info_object",
"get_option_type",
"is_relative_time_string",
"parse_option_list",
)

View File

@@ -145,6 +145,7 @@ class AsymmetricKeypair:
@classmethod
def generate(
cls: t.Type[_AsymmetricKeypair],
*,
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
@@ -208,6 +209,7 @@ class AsymmetricKeypair:
@classmethod
def load(
cls: t.Type[_AsymmetricKeypair],
*,
path: str | os.PathLike,
passphrase: bytes | None = None,
private_key_format: KeySerializationFormat = "PEM",
@@ -228,13 +230,15 @@ class AsymmetricKeypair:
else:
encryption_algorithm = serialization.NoEncryption()
privatekey = load_privatekey(path, passphrase, private_key_format)
privatekey = load_privatekey(
path=path, passphrase=passphrase, key_format=private_key_format
)
if no_public_key:
publickey = privatekey.public_key()
else:
# TODO: BUG: load_publickey() can return unsupported key types
# (Also we should check whether the public key fits the private key...)
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
publickey = load_publickey(path=path + ".pub", key_format=public_key_format) # type: ignore
# Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey):
@@ -264,6 +268,7 @@ class AsymmetricKeypair:
def __init__(
self,
*,
keytype: KeyType,
size: int,
privatekey: PrivateKeyTypes,
@@ -285,7 +290,7 @@ class AsymmetricKeypair:
self.__encryption_algorithm = encryption_algorithm
try:
self.verify(self.sign(b"message"), b"message")
self.verify(signature=self.sign(b"message"), data=b"message")
except InvalidSignatureError:
raise InvalidPublicKeyFileError(
"The private key and public key of this keypair do not match"
@@ -347,7 +352,7 @@ class AsymmetricKeypair:
except TypeError as e:
raise InvalidDataError(e)
def verify(self, signature: bytes, data: bytes) -> None:
def verify(self, *, signature: bytes, data: bytes) -> None:
"""Verifies that the signature associated with the provided data was signed
by the private key of this key pair.
@@ -384,6 +389,7 @@ class OpensshKeypair:
@classmethod
def generate(
cls: t.Type[_OpensshKeypair],
*,
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
@@ -400,9 +406,15 @@ class OpensshKeypair:
if comment is None:
comment = f"{getuser()}@{gethostname()}"
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
asym_keypair = AsymmetricKeypair.generate(
keytype=keytype, size=size, passphrase=passphrase
)
openssh_privatekey = cls.encode_openssh_privatekey(
asym_keypair=asym_keypair, key_format="SSH"
)
openssh_publickey = cls.encode_openssh_publickey(
asym_keypair=asym_keypair, comment=comment
)
fingerprint = calculate_fingerprint(openssh_publickey)
return cls(
@@ -416,6 +428,7 @@ class OpensshKeypair:
@classmethod
def load(
cls: t.Type[_OpensshKeypair],
*,
path: str | os.PathLike,
passphrase: bytes | None = None,
no_public_key: bool = False,
@@ -433,10 +446,18 @@ class OpensshKeypair:
comment = extract_comment(str(path) + ".pub")
asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key
path=path,
passphrase=passphrase,
private_key_format="SSH",
public_key_format="SSH",
no_public_key=no_public_key,
)
openssh_privatekey = cls.encode_openssh_privatekey(
asym_keypair=asym_keypair, key_format="SSH"
)
openssh_publickey = cls.encode_openssh_publickey(
asym_keypair=asym_keypair, comment=comment
)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
return cls(
@@ -449,7 +470,7 @@ class OpensshKeypair:
@staticmethod
def encode_openssh_privatekey(
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
*, asym_keypair: AsymmetricKeypair, key_format: KeyFormat
) -> bytes:
"""Returns an OpenSSH encoded private key for a given keypair
@@ -482,7 +503,7 @@ class OpensshKeypair:
@staticmethod
def encode_openssh_publickey(
asym_keypair: AsymmetricKeypair, comment: str
*, asym_keypair: AsymmetricKeypair, comment: str
) -> bytes:
"""Returns an OpenSSH encoded public key for a given keypair
@@ -504,6 +525,7 @@ class OpensshKeypair:
def __init__(
self,
*,
asym_keypair: AsymmetricKeypair,
openssh_privatekey: bytes,
openssh_publickey: bytes,
@@ -603,11 +625,12 @@ class OpensshKeypair:
self.__asym_keypair.update_passphrase(passphrase)
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
self.__asym_keypair, "SSH"
asym_keypair=self.__asym_keypair, key_format="SSH"
)
def load_privatekey(
*,
path: str | os.PathLike,
passphrase: bytes | None,
key_format: KeySerializationFormat,
@@ -662,7 +685,7 @@ def load_privatekey(
def load_publickey(
path: str | os.PathLike, key_format: KeySerializationFormat
*, path: str | os.PathLike, key_format: KeySerializationFormat
) -> AllPublicKeyTypes:
publickey_loaders = {
"PEM": serialization.load_pem_public_key,
@@ -767,3 +790,30 @@ def calculate_fingerprint(openssh_publickey: bytes) -> str:
value = b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip("=")
return f"SHA256:{value}"
__all__ = (
"HAS_OPENSSH_SUPPORT",
"CRYPTOGRAPHY_VERSION",
"OpenSSHError",
"InvalidAlgorithmError",
"InvalidCommentError",
"InvalidDataError",
"InvalidPrivateKeyFileError",
"InvalidPublicKeyFileError",
"InvalidKeyFormatError",
"InvalidKeySizeError",
"InvalidKeyTypeError",
"InvalidPassphraseError",
"InvalidSignatureError",
"AsymmetricKeypair",
"OpensshKeypair",
"load_privatekey",
"load_publickey",
"compare_publickeys",
"compare_encryption_algorithms",
"get_encryption_algorithm",
"validate_comment",
"extract_comment",
"calculate_fingerprint",
)

View File

@@ -70,7 +70,7 @@ def parse_openssh_version(version_string: str) -> str | None:
@contextmanager
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
def secure_open(*, path: str | os.PathLike, mode: int) -> t.Iterator[int]:
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
try:
yield fd
@@ -78,8 +78,8 @@ def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
os.close(fd)
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path, mode) as fd:
def secure_write(*, path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path=path, mode=mode) as fd:
os.write(fd, content)
@@ -91,7 +91,7 @@ class OpensshParser:
UINT32_OFFSET = 4
UINT64_OFFSET = 8
def __init__(self, data: bytes | bytearray) -> None:
def __init__(self, *, data: bytes | bytearray) -> None:
if not isinstance(data, (bytes, bytearray)):
raise TypeError(f"Data must be bytes-like not {type(data)}")
@@ -142,7 +142,7 @@ class OpensshParser:
raw_string = self.string()
if raw_string:
parser = OpensshParser(raw_string)
parser = OpensshParser(data=raw_string)
while parser.remaining_bytes():
result.append(parser.string())
@@ -154,14 +154,14 @@ class OpensshParser:
raw_string = self.string()
if raw_string:
parser = OpensshParser(raw_string)
parser = OpensshParser(data=raw_string)
while parser.remaining_bytes():
name = parser.string()
data = parser.string()
if data:
# data is doubly-encoded
data = OpensshParser(data).string()
data = OpensshParser(data=data).string()
result.append((name, data))
return result
@@ -183,14 +183,14 @@ class OpensshParser:
return self._pos + offset
@classmethod
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
def signature_data(cls, *, signature_string: bytes) -> dict[str, bytes | int]:
signature_data: dict[str, bytes | int] = {}
parser = cls(signature_string)
parser = cls(data=signature_string)
signature_type = parser.string()
signature_blob = parser.string()
blob_parser = cls(signature_blob)
blob_parser = cls(data=signature_blob)
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
@@ -242,7 +242,7 @@ class _OpensshWriter:
in validating parsed material.
"""
def __init__(self, buffer: bytearray | None = None):
def __init__(self, *, buffer: bytearray | None = None):
if buffer is not None:
if not isinstance(buffer, bytearray):
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
@@ -347,3 +347,13 @@ class _OpensshWriter:
def bytes(self) -> bytes:
return bytes(self._buff)
__all__ = (
"any_in",
"file_mode",
"parse_openssh_version",
"secure_open",
"secure_write",
"OpensshParser",
)

View File

@@ -54,3 +54,6 @@ def to_serial(value: int) -> str:
if len(value_str) % 2 != 0:
value_str = f"0{value_str}"
return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2))
__all__ = ("parse_serial", "to_serial")

View File

@@ -19,7 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils._crypto.basic imp
UTC = datetime.timezone.utc
def get_now_datetime(with_timezone: bool) -> datetime.datetime:
def get_now_datetime(*, with_timezone: bool) -> datetime.datetime:
if with_timezone:
return datetime.datetime.now(tz=UTC)
return datetime.datetime.utcnow()
@@ -44,7 +44,7 @@ def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime:
def add_or_remove_timezone(
timestamp: datetime.datetime, with_timezone: bool
timestamp: datetime.datetime, *, with_timezone: bool
) -> datetime.datetime:
return (
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
@@ -59,7 +59,7 @@ def get_epoch_seconds(timestamp: datetime.datetime) -> float:
def from_epoch_seconds(
timestamp: int | float, with_timezone: bool
timestamp: int | float, *, with_timezone: bool
) -> datetime.datetime:
if with_timezone:
return datetime.datetime.fromtimestamp(timestamp, UTC)
@@ -68,6 +68,7 @@ def from_epoch_seconds(
def convert_relative_to_datetime(
relative_time_string: str,
*,
with_timezone: bool = False,
now: datetime.datetime | None = None,
) -> datetime.datetime | None:
@@ -107,6 +108,7 @@ def convert_relative_to_datetime(
def get_relative_time_option(
input_string: str,
*,
input_name: str,
with_timezone: bool = False,
now: datetime.datetime | None = None,
@@ -155,3 +157,15 @@ def get_relative_time_option(
raise OpenSSLObjectError(
f'The time spec "{input_string}" for {input_name} is invalid'
)
__all__ = (
"get_now_datetime",
"ensure_utc_timezone",
"remove_timezone",
"add_or_remove_timezone",
"get_epoch_seconds",
"from_epoch_seconds",
"convert_relative_to_datetime",
"get_relative_time_option",
)