mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-03-27 05:43:22 +00:00
Code refactoring (#889)
* Add __all__ to all module and plugin utils. * Convert quite a few positional args to keyword args. * Avoid Python 3.8+ syntax.
This commit is contained in:
@@ -29,7 +29,7 @@ class ACMEAccount:
|
||||
retrieve account data.
|
||||
"""
|
||||
|
||||
def __init__(self, client: ACMEClient) -> None:
|
||||
def __init__(self, *, client: ACMEClient) -> None:
|
||||
# Set to true to enable logging of all signed requests
|
||||
self._debug: bool = False
|
||||
|
||||
@@ -37,6 +37,7 @@ class ACMEAccount:
|
||||
|
||||
def _new_reg(
|
||||
self,
|
||||
*,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
@@ -82,14 +83,15 @@ class ACMEAccount:
|
||||
url = self.client.directory["newAccount"]
|
||||
if external_account_binding is not None:
|
||||
new_reg["externalAccountBinding"] = self.client.sign_request(
|
||||
{
|
||||
protected={
|
||||
"alg": external_account_binding["alg"],
|
||||
"kid": external_account_binding["kid"],
|
||||
"url": url,
|
||||
},
|
||||
self.client.account_jwk,
|
||||
self.client.backend.create_mac_key(
|
||||
external_account_binding["alg"], external_account_binding["key"]
|
||||
payload=self.client.account_jwk,
|
||||
key_data=self.client.backend.create_mac_key(
|
||||
alg=external_account_binding["alg"],
|
||||
key=external_account_binding["key"],
|
||||
),
|
||||
)
|
||||
elif (
|
||||
@@ -106,7 +108,7 @@ class ACMEAccount:
|
||||
)
|
||||
if not isinstance(result, Mapping):
|
||||
raise ACMEProtocolException(
|
||||
self.client.module,
|
||||
module=self.client.module,
|
||||
msg="Invalid account creation reply from ACME server",
|
||||
info=info,
|
||||
content_json=result,
|
||||
@@ -156,7 +158,7 @@ class ACMEAccount:
|
||||
raise ModuleFailException("Account is deactivated")
|
||||
else:
|
||||
raise ACMEProtocolException(
|
||||
self.client.module,
|
||||
module=self.client.module,
|
||||
msg="Registering ACME account failed",
|
||||
info=info,
|
||||
content_json=result,
|
||||
@@ -187,7 +189,7 @@ class ACMEAccount:
|
||||
)
|
||||
if not isinstance(result, Mapping):
|
||||
raise ACMEProtocolException(
|
||||
self.client.module,
|
||||
module=self.client.module,
|
||||
msg="Invalid account data retrieved from ACME server",
|
||||
info=info,
|
||||
content_json=result,
|
||||
@@ -206,7 +208,7 @@ class ACMEAccount:
|
||||
return None
|
||||
if info["status"] < 200 or info["status"] >= 300:
|
||||
raise ACMEProtocolException(
|
||||
self.client.module,
|
||||
module=self.client.module,
|
||||
msg="Error retrieving account data",
|
||||
info=info,
|
||||
content_json=result,
|
||||
@@ -216,6 +218,7 @@ class ACMEAccount:
|
||||
@t.overload
|
||||
def setup_account(
|
||||
self,
|
||||
*,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: t.Literal[True] = True,
|
||||
@@ -226,6 +229,7 @@ class ACMEAccount:
|
||||
@t.overload
|
||||
def setup_account(
|
||||
self,
|
||||
*,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
@@ -235,6 +239,7 @@ class ACMEAccount:
|
||||
|
||||
def setup_account(
|
||||
self,
|
||||
*,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
@@ -281,7 +286,7 @@ class ACMEAccount:
|
||||
)
|
||||
else:
|
||||
created, account_data = self._new_reg(
|
||||
contact,
|
||||
contact=contact,
|
||||
terms_agreed=terms_agreed,
|
||||
allow_creation=allow_creation and not self.client.module.check_mode,
|
||||
external_account_binding=external_account_binding,
|
||||
@@ -296,7 +301,7 @@ class ACMEAccount:
|
||||
return created, account_data
|
||||
|
||||
def update_account(
|
||||
self, account_data: dict[str, t.Any], contact: list[str] | None = None
|
||||
self, *, account_data: dict[str, t.Any], contact: list[str] | None = None
|
||||
) -> tuple[bool, dict[str, t.Any]]:
|
||||
"""
|
||||
Update an account on the ACME server. Check mode is fully respected.
|
||||
@@ -332,10 +337,13 @@ class ACMEAccount:
|
||||
)
|
||||
if not isinstance(account_data, Mapping):
|
||||
raise ACMEProtocolException(
|
||||
self.client.module,
|
||||
module=self.client.module,
|
||||
msg="Invalid account updating reply from ACME server",
|
||||
info=info,
|
||||
content_json=account_data,
|
||||
)
|
||||
|
||||
return True, account_data
|
||||
|
||||
|
||||
__all__ = ("ACMEAccount",)
|
||||
|
||||
@@ -65,14 +65,14 @@ RETRY_COUNT = 10
|
||||
|
||||
|
||||
def _decode_retry(
|
||||
module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
|
||||
*, module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
|
||||
) -> bool:
|
||||
if info["status"] not in RETRY_STATUS_CODES:
|
||||
return False
|
||||
|
||||
if retry_count >= RETRY_COUNT:
|
||||
raise ACMEProtocolException(
|
||||
module,
|
||||
module=module,
|
||||
msg=f"Giving up after {RETRY_COUNT} retries",
|
||||
info=info,
|
||||
response=response,
|
||||
@@ -93,6 +93,7 @@ def _decode_retry(
|
||||
|
||||
|
||||
def _assert_fetch_url_success(
|
||||
*,
|
||||
module: AnsibleModule,
|
||||
response: t.Any,
|
||||
info: dict[str, t.Any],
|
||||
@@ -108,11 +109,11 @@ def _assert_fetch_url_success(
|
||||
or (400 <= info["status"] < 500 and not allow_client_error)
|
||||
or (info["status"] >= 500 and not allow_server_error)
|
||||
):
|
||||
raise ACMEProtocolException(module, info=info, response=response)
|
||||
raise ACMEProtocolException(module=module, info=info, response=response)
|
||||
|
||||
|
||||
def _is_failed(
|
||||
info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
|
||||
*, info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
|
||||
) -> bool:
|
||||
if info["status"] < 200 or info["status"] >= 400:
|
||||
return True
|
||||
@@ -133,7 +134,7 @@ class ACMEDirectory:
|
||||
https://tools.ietf.org/html/rfc8555#section-7.1.1
|
||||
"""
|
||||
|
||||
def __init__(self, module: AnsibleModule, client: ACMEClient) -> None:
|
||||
def __init__(self, *, module: AnsibleModule, client: ACMEClient) -> None:
|
||||
self.module = module
|
||||
self.directory_root = module.params["acme_directory"]
|
||||
self.version = module.params["acme_version"]
|
||||
@@ -171,7 +172,12 @@ class ACMEDirectory:
|
||||
response, info = fetch_url(
|
||||
self.module, url, method="HEAD", timeout=self.request_timeout
|
||||
)
|
||||
if _decode_retry(self.module, response, info, retry_count):
|
||||
if _decode_retry(
|
||||
module=self.module,
|
||||
response=response,
|
||||
info=info,
|
||||
retry_count=retry_count,
|
||||
):
|
||||
retry_count += 1
|
||||
continue
|
||||
if info["status"] not in (200, 204):
|
||||
@@ -185,7 +191,7 @@ class ACMEDirectory:
|
||||
)
|
||||
if retry_count >= 5:
|
||||
raise ACMEProtocolException(
|
||||
self.module,
|
||||
module=self.module,
|
||||
msg="Was not able to obtain nonce, giving up after 5 retries",
|
||||
info=info,
|
||||
response=response,
|
||||
@@ -202,7 +208,7 @@ class ACMEClient:
|
||||
ACME server.
|
||||
"""
|
||||
|
||||
def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None:
|
||||
def __init__(self, *, module: AnsibleModule, backend: CryptoBackend) -> None:
|
||||
# Set to true to enable logging of all signed requests
|
||||
self._debug = False
|
||||
|
||||
@@ -241,7 +247,7 @@ class ACMEClient:
|
||||
# Make sure self.account_jws_header is updated
|
||||
self.set_account_uri(self.account_uri)
|
||||
|
||||
self.directory = ACMEDirectory(module, self)
|
||||
self.directory = ACMEDirectory(module=module, client=self)
|
||||
|
||||
def set_account_uri(self, uri: str) -> None:
|
||||
"""
|
||||
@@ -255,6 +261,7 @@ class ACMEClient:
|
||||
|
||||
def parse_key(
|
||||
self,
|
||||
*,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
@@ -265,10 +272,13 @@ class ACMEClient:
|
||||
"""
|
||||
if key_file is None and key_content is None:
|
||||
raise AssertionError("One of key_file and key_content must be specified!")
|
||||
return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
|
||||
return self.backend.parse_key(
|
||||
key_file=key_file, key_content=key_content, passphrase=passphrase
|
||||
)
|
||||
|
||||
def sign_request(
|
||||
self,
|
||||
*,
|
||||
protected: dict[str, t.Any],
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
key_data: dict[str, t.Any],
|
||||
@@ -292,9 +302,11 @@ class ACMEClient:
|
||||
f"Failed to encode payload / headers as JSON: {e}"
|
||||
)
|
||||
|
||||
return self.backend.sign(payload64, protected64, key_data)
|
||||
return self.backend.sign(
|
||||
payload64=payload64, protected64=protected64, key_data=key_data
|
||||
)
|
||||
|
||||
def _log(self, msg: str, data: t.Any = None) -> None:
|
||||
def _log(self, msg: str, *, data: t.Any = None) -> None:
|
||||
"""
|
||||
Write arguments to acme.log when logging is enabled.
|
||||
"""
|
||||
@@ -373,13 +385,16 @@ class ACMEClient:
|
||||
protected["nonce"] = self.directory.get_nonce()
|
||||
protected["url"] = url
|
||||
|
||||
self._log("URL", url)
|
||||
self._log("protected", protected)
|
||||
self._log("payload", payload)
|
||||
self._log("URL", data=url)
|
||||
self._log("protected", data=protected)
|
||||
self._log("payload", data=payload)
|
||||
data = self.sign_request(
|
||||
protected, payload, key_data, encode_payload=encode_payload
|
||||
protected=protected,
|
||||
payload=payload,
|
||||
key_data=key_data,
|
||||
encode_payload=encode_payload,
|
||||
)
|
||||
self._log("signed request", data)
|
||||
self._log("signed request", data=data)
|
||||
data = self.module.jsonify(data)
|
||||
|
||||
headers = {
|
||||
@@ -393,10 +408,12 @@ class ACMEClient:
|
||||
method="POST",
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
if _decode_retry(self.module, resp, info, failed_tries):
|
||||
if _decode_retry(
|
||||
module=self.module, response=resp, info=info, retry_count=failed_tries
|
||||
):
|
||||
failed_tries += 1
|
||||
continue
|
||||
_assert_fetch_url_success(self.module, resp, info)
|
||||
_assert_fetch_url_success(module=self.module, response=resp, info=info)
|
||||
result = {}
|
||||
|
||||
try:
|
||||
@@ -415,7 +432,7 @@ class ACMEClient:
|
||||
) or 400 <= info["status"] < 600:
|
||||
try:
|
||||
decoded_result = self.module.from_json(content.decode("utf8"))
|
||||
self._log("parsed result", decoded_result)
|
||||
self._log("parsed result", data=decoded_result)
|
||||
# In case of badNonce error, try again (up to 5 times)
|
||||
# (https://tools.ietf.org/html/rfc8555#section-6.7)
|
||||
if all(
|
||||
@@ -440,10 +457,10 @@ class ACMEClient:
|
||||
result = content
|
||||
|
||||
if fail_on_error and _is_failed(
|
||||
info, expected_status_codes=expected_status_codes
|
||||
info=info, expected_status_codes=expected_status_codes
|
||||
):
|
||||
raise ACMEProtocolException(
|
||||
self.module,
|
||||
module=self.module,
|
||||
msg=error_msg,
|
||||
info=info,
|
||||
content=content,
|
||||
@@ -515,11 +532,16 @@ class ACMEClient:
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
if not _decode_retry(self.module, resp, info, retry_count):
|
||||
if not _decode_retry(
|
||||
module=self.module,
|
||||
response=resp,
|
||||
info=info,
|
||||
retry_count=retry_count,
|
||||
):
|
||||
break
|
||||
retry_count += 1
|
||||
|
||||
_assert_fetch_url_success(self.module, resp, info)
|
||||
_assert_fetch_url_success(module=self.module, response=resp, info=info)
|
||||
|
||||
try:
|
||||
# In Python 2, reading from a closed response yields a TypeError.
|
||||
@@ -550,10 +572,10 @@ class ACMEClient:
|
||||
result = content
|
||||
|
||||
if fail_on_error and _is_failed(
|
||||
info, expected_status_codes=expected_status_codes
|
||||
info=info, expected_status_codes=expected_status_codes
|
||||
):
|
||||
raise ACMEProtocolException(
|
||||
self.module,
|
||||
module=self.module,
|
||||
msg=error_msg,
|
||||
info=info,
|
||||
content=content,
|
||||
@@ -565,6 +587,7 @@ class ACMEClient:
|
||||
|
||||
def get_renewal_info(
|
||||
self,
|
||||
*,
|
||||
cert_id: str | None = None,
|
||||
cert_info: CertificateInformation | None = None,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
@@ -579,7 +602,7 @@ class ACMEClient:
|
||||
|
||||
if cert_id is None:
|
||||
cert_id = compute_cert_id(
|
||||
self.backend,
|
||||
backend=self.backend,
|
||||
cert_info=cert_info,
|
||||
cert_filename=cert_filename,
|
||||
cert_content=cert_content,
|
||||
@@ -603,6 +626,7 @@ class ACMEClient:
|
||||
|
||||
|
||||
def create_default_argspec(
|
||||
*,
|
||||
with_account: bool = True,
|
||||
require_account_key: bool = True,
|
||||
with_certificate: bool = False,
|
||||
@@ -643,7 +667,9 @@ def create_default_argspec(
|
||||
return result
|
||||
|
||||
|
||||
def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend:
|
||||
def create_backend(
|
||||
module: AnsibleModule, *, needs_acme_v2: bool = True
|
||||
) -> CryptoBackend:
|
||||
backend = module.params["select_crypto_backend"]
|
||||
|
||||
# Backend autodetect
|
||||
@@ -671,10 +697,10 @@ def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoB
|
||||
module.debug(
|
||||
f"Using cryptography backend (library version {CRYPTOGRAPHY_VERSION})"
|
||||
)
|
||||
module_backend = CryptographyBackend(module)
|
||||
module_backend = CryptographyBackend(module=module)
|
||||
elif backend == "openssl":
|
||||
module.debug("Using OpenSSL binary backend")
|
||||
module_backend = OpenSSLCLIBackend(module)
|
||||
module_backend = OpenSSLCLIBackend(module=module)
|
||||
else:
|
||||
module.fail_json(msg=f'Unknown crypto backend "{backend}"!')
|
||||
|
||||
@@ -691,3 +717,11 @@ def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoB
|
||||
locale.setlocale(locale.LC_ALL, "C")
|
||||
|
||||
return module_backend
|
||||
|
||||
|
||||
__all__ = (
|
||||
"ACMEDirectory",
|
||||
"ACMEClient",
|
||||
"create_default_argspec",
|
||||
"create_backend",
|
||||
)
|
||||
|
||||
@@ -92,7 +92,11 @@ if t.TYPE_CHECKING:
|
||||
class CryptographyChainMatcher(ChainMatcher):
|
||||
@staticmethod
|
||||
def _parse_key_identifier(
|
||||
key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule
|
||||
*,
|
||||
key_identifier: str | None,
|
||||
name: str,
|
||||
criterium_idx: int,
|
||||
module: AnsibleModule,
|
||||
) -> bytes | None:
|
||||
if key_identifier:
|
||||
try:
|
||||
@@ -109,7 +113,7 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
)
|
||||
return None
|
||||
|
||||
def __init__(self, criterium: Criterium, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, criterium: Criterium, module: AnsibleModule) -> None:
|
||||
self.criterium = criterium
|
||||
self.test_certificates = criterium.test_certificates
|
||||
self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
|
||||
@@ -117,29 +121,32 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
if criterium.subject:
|
||||
self.subject = [
|
||||
(cryptography_name_to_oid(k), to_native(v))
|
||||
for k, v in parse_name_field(criterium.subject, "subject")
|
||||
for k, v in parse_name_field(
|
||||
criterium.subject, name_field_name="subject"
|
||||
)
|
||||
]
|
||||
if criterium.issuer:
|
||||
self.issuer = [
|
||||
(cryptography_name_to_oid(k), to_native(v))
|
||||
for k, v in parse_name_field(criterium.issuer, "issuer")
|
||||
for k, v in parse_name_field(criterium.issuer, name_field_name="issuer")
|
||||
]
|
||||
self.subject_key_identifier = CryptographyChainMatcher._parse_key_identifier(
|
||||
criterium.subject_key_identifier,
|
||||
"subject_key_identifier",
|
||||
criterium.index,
|
||||
module,
|
||||
key_identifier=criterium.subject_key_identifier,
|
||||
name="subject_key_identifier",
|
||||
criterium_idx=criterium.index,
|
||||
module=module,
|
||||
)
|
||||
self.authority_key_identifier = CryptographyChainMatcher._parse_key_identifier(
|
||||
criterium.authority_key_identifier,
|
||||
"authority_key_identifier",
|
||||
criterium.index,
|
||||
module,
|
||||
key_identifier=criterium.authority_key_identifier,
|
||||
name="authority_key_identifier",
|
||||
criterium_idx=criterium.index,
|
||||
module=module,
|
||||
)
|
||||
self.module = module
|
||||
|
||||
def _match_subject(
|
||||
self,
|
||||
*,
|
||||
x509_subject: cryptography.x509.Name,
|
||||
match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]],
|
||||
) -> bool:
|
||||
@@ -153,7 +160,7 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
return False
|
||||
return True
|
||||
|
||||
def match(self, certificate: CertificateChain) -> bool:
|
||||
def match(self, *, certificate: CertificateChain) -> bool:
|
||||
"""
|
||||
Check whether an alternate chain matches the specified criterium.
|
||||
"""
|
||||
@@ -166,9 +173,13 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
try:
|
||||
x509 = cryptography.x509.load_pem_x509_certificate(to_bytes(cert))
|
||||
matches = True
|
||||
if not self._match_subject(x509.subject, self.subject):
|
||||
if not self._match_subject(
|
||||
x509_subject=x509.subject, match_subject=self.subject
|
||||
):
|
||||
matches = False
|
||||
if not self._match_subject(x509.issuer, self.issuer):
|
||||
if not self._match_subject(
|
||||
x509_subject=x509.issuer, match_subject=self.issuer
|
||||
):
|
||||
matches = False
|
||||
if self.subject_key_identifier:
|
||||
try:
|
||||
@@ -199,13 +210,14 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
|
||||
|
||||
class CryptographyBackend(CryptoBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(CryptographyBackend, self).__init__(
|
||||
module, with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
module=module, with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
)
|
||||
|
||||
def parse_key(
|
||||
self,
|
||||
*,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
@@ -288,7 +300,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
raise KeyParsingError(f'unknown key type "{type(key)}"')
|
||||
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
self, *, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
sign_payload = f"{protected64}.{payload64}".encode("utf8")
|
||||
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
|
||||
@@ -317,8 +329,8 @@ class CryptographyBackend(CryptoBackend):
|
||||
r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(
|
||||
key_data["key_obj"].sign(sign_payload, ecdsa)
|
||||
)
|
||||
rr = convert_int_to_hex(r, 2 * key_data["point_size"])
|
||||
ss = convert_int_to_hex(s, 2 * key_data["point_size"])
|
||||
rr = convert_int_to_hex(r, digits=2 * key_data["point_size"])
|
||||
ss = convert_int_to_hex(s, digits=2 * key_data["point_size"])
|
||||
signature = binascii.unhexlify(rr) + binascii.unhexlify(ss)
|
||||
|
||||
return {
|
||||
@@ -327,7 +339,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
"signature": nopad_b64(signature),
|
||||
}
|
||||
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
def create_mac_key(self, *, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
|
||||
if alg == "HS256":
|
||||
@@ -362,6 +374,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
*,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
@@ -413,6 +426,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
*,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
@@ -429,6 +443,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
def get_cert_days(
|
||||
self,
|
||||
*,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
@@ -466,14 +481,15 @@ class CryptographyBackend(CryptoBackend):
|
||||
now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE)
|
||||
return (get_not_valid_after(cert) - now).days
|
||||
|
||||
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
|
||||
def create_chain_matcher(self, *, criterium: Criterium) -> ChainMatcher:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
return CryptographyChainMatcher(criterium, self.module)
|
||||
return CryptographyChainMatcher(criterium=criterium, module=self.module)
|
||||
|
||||
def get_cert_information(
|
||||
self,
|
||||
*,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
@@ -520,3 +536,12 @@ class CryptographyBackend(CryptoBackend):
|
||||
subject_key_identifier=ski,
|
||||
authority_key_identifier=aki,
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"CRYPTOGRAPHY_MINIMAL_VERSION",
|
||||
"CRYPTOGRAPHY_ERROR",
|
||||
"CRYPTOGRAPHY_VERSION",
|
||||
"CRYPTOGRAPHY_ERROR",
|
||||
"CryptographyBackend",
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ _OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTY
|
||||
|
||||
|
||||
def _extract_date(
|
||||
out_text: str, name: str, cert_filename_suffix: str = ""
|
||||
out_text: str, *, name: str, cert_filename_suffix: str = ""
|
||||
) -> datetime.datetime:
|
||||
matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text)
|
||||
if matcher is None:
|
||||
@@ -76,6 +76,7 @@ def _decode_octets(octets_text: str) -> bytes:
|
||||
@t.overload
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
*,
|
||||
name: str,
|
||||
required: t.Literal[False],
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
@@ -85,6 +86,7 @@ def _extract_octets(
|
||||
@t.overload
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
*,
|
||||
name: str,
|
||||
required: t.Literal[True],
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
@@ -93,6 +95,7 @@ def _extract_octets(
|
||||
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
*,
|
||||
name: str,
|
||||
required: bool = True,
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
@@ -113,15 +116,16 @@ def _extract_octets(
|
||||
|
||||
class OpenSSLCLIBackend(CryptoBackend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, openssl_binary: str | None = None
|
||||
self, *, module: AnsibleModule, openssl_binary: str | None = None
|
||||
) -> None:
|
||||
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
|
||||
super(OpenSSLCLIBackend, self).__init__(module=module, with_timezone=True)
|
||||
if openssl_binary is None:
|
||||
openssl_binary = module.get_bin_path("openssl", True)
|
||||
self.openssl_binary = openssl_binary
|
||||
|
||||
def parse_key(
|
||||
self,
|
||||
*,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
@@ -282,7 +286,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
)
|
||||
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
self, *, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
sign_payload = f"{protected64}.{payload64}".encode("utf8")
|
||||
if key_data["type"] == "hmac":
|
||||
@@ -343,7 +347,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
"signature": nopad_b64(to_bytes(out)),
|
||||
}
|
||||
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
def create_mac_key(self, *, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
if alg == "HS256":
|
||||
hashalg = "sha256"
|
||||
@@ -383,6 +387,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
*,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
@@ -454,6 +459,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
*,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
@@ -470,6 +476,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
|
||||
def get_cert_days(
|
||||
self,
|
||||
*,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
@@ -516,7 +523,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
|
||||
out_text = to_text(out, errors="surrogate_or_strict")
|
||||
not_after = _extract_date(
|
||||
out_text, "Not After", cert_filename_suffix=cert_filename_suffix
|
||||
out_text, name="Not After", cert_filename_suffix=cert_filename_suffix
|
||||
)
|
||||
if now is None:
|
||||
now = self.get_now()
|
||||
@@ -524,7 +531,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
now = ensure_utc_timezone(now)
|
||||
return (not_after - now).days
|
||||
|
||||
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
|
||||
def create_chain_matcher(self, *, criterium: Criterium) -> t.NoReturn:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
@@ -534,6 +541,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
|
||||
def get_cert_information(
|
||||
self,
|
||||
*,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
@@ -572,10 +580,10 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
out_text = to_text(out, errors="surrogate_or_strict")
|
||||
|
||||
not_after = _extract_date(
|
||||
out_text, "Not After", cert_filename_suffix=cert_filename_suffix
|
||||
out_text, name="Not After", cert_filename_suffix=cert_filename_suffix
|
||||
)
|
||||
not_before = _extract_date(
|
||||
out_text, "Not Before", cert_filename_suffix=cert_filename_suffix
|
||||
out_text, name="Not Before", cert_filename_suffix=cert_filename_suffix
|
||||
)
|
||||
|
||||
sn = re.search(
|
||||
@@ -587,13 +595,15 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
serial = int(sn.group(1))
|
||||
else:
|
||||
serial = convert_bytes_to_int(
|
||||
_extract_octets(out_text, "Serial Number", required=True)
|
||||
_extract_octets(out_text, name="Serial Number", required=True)
|
||||
)
|
||||
|
||||
ski = _extract_octets(out_text, "X509v3 Subject Key Identifier", required=False)
|
||||
ski = _extract_octets(
|
||||
out_text, name="X509v3 Subject Key Identifier", required=False
|
||||
)
|
||||
aki = _extract_octets(
|
||||
out_text,
|
||||
"X509v3 Authority Key Identifier",
|
||||
name="X509v3 Authority Key Identifier",
|
||||
required=False,
|
||||
potential_prefixes=["keyid:", ""],
|
||||
)
|
||||
@@ -605,3 +615,6 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
subject_key_identifier=ski,
|
||||
authority_key_identifier=aki,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("OpenSSLCLIBackend",)
|
||||
|
||||
@@ -69,7 +69,9 @@ def _reduce_fractional_digits(timestamp_str: str) -> str:
|
||||
return f"{timestamp}{fractional}{timezone}"
|
||||
|
||||
|
||||
def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime:
|
||||
def _parse_acme_timestamp(
|
||||
timestamp_str: str, *, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
"""
|
||||
Parses a RFC 3339 timestamp.
|
||||
"""
|
||||
@@ -95,7 +97,7 @@ def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.d
|
||||
|
||||
|
||||
class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None:
|
||||
def __init__(self, *, module: AnsibleModule, with_timezone: bool = False) -> None:
|
||||
self.module = module
|
||||
self._with_timezone = with_timezone
|
||||
|
||||
@@ -106,10 +108,10 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
|
||||
return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone)
|
||||
|
||||
def parse_module_parameter(self, value: str, name: str) -> datetime.datetime:
|
||||
def parse_module_parameter(self, *, value: str, name: str) -> datetime.datetime:
|
||||
try:
|
||||
result = get_relative_time_option(
|
||||
value, name, with_timezone=self._with_timezone
|
||||
value, input_name=name, with_timezone=self._with_timezone
|
||||
)
|
||||
if result is None:
|
||||
raise BackendException(f"Invalid value for {name}: {value!r}")
|
||||
@@ -121,6 +123,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
self,
|
||||
timestamp_start: datetime.datetime,
|
||||
timestamp_end: datetime.datetime,
|
||||
*,
|
||||
percentage: float,
|
||||
) -> datetime.datetime:
|
||||
start = get_epoch_seconds(timestamp_start)
|
||||
@@ -141,6 +144,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def parse_key(
|
||||
self,
|
||||
*,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
@@ -152,17 +156,18 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
self, *, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
def create_mac_key(self, *, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
*,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
@@ -178,6 +183,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
*,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
@@ -190,6 +196,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def get_cert_days(
|
||||
self,
|
||||
*,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
@@ -203,7 +210,7 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
|
||||
def create_chain_matcher(self, *, criterium: Criterium) -> ChainMatcher:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
@@ -211,9 +218,13 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def get_cert_information(
|
||||
self,
|
||||
*,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
|
||||
|
||||
__all__ = ("CertificateInformation", "CryptoBackend")
|
||||
|
||||
@@ -58,6 +58,7 @@ class ACMECertificateClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
module: AnsibleModule,
|
||||
backend: CryptoBackend,
|
||||
client: ACMEClient | None = None,
|
||||
@@ -68,10 +69,10 @@ class ACMECertificateClient:
|
||||
self.csr = module.params.get("csr")
|
||||
self.csr_content = module.params.get("csr_content")
|
||||
if client is None:
|
||||
client = ACMEClient(module, backend)
|
||||
client = ACMEClient(module=module, backend=backend)
|
||||
self.client = client
|
||||
if account is None:
|
||||
account = ACMEAccount(self.client)
|
||||
account = ACMEAccount(client=self.client)
|
||||
self.account = account
|
||||
self.order_uri = module.params.get("order_uri")
|
||||
self.order_creation_error_strategy = module.params.get(
|
||||
@@ -108,7 +109,9 @@ class ACMECertificateClient:
|
||||
try:
|
||||
select_chain_matcher.append(
|
||||
self.client.backend.create_chain_matcher(
|
||||
Criterium(criterium, index=criterium_idx)
|
||||
criterium=Criterium(
|
||||
criterium=criterium, index=criterium_idx
|
||||
)
|
||||
)
|
||||
)
|
||||
except ValueError as exc:
|
||||
@@ -120,12 +123,12 @@ class ACMECertificateClient:
|
||||
def load_order(self) -> Order:
|
||||
if not self.order_uri:
|
||||
raise ModuleFailException("The order URI has not been provided")
|
||||
order = Order.from_url(self.client, self.order_uri)
|
||||
order.load_authorizations(self.client)
|
||||
order = Order.from_url(client=self.client, url=self.order_uri)
|
||||
order.load_authorizations(client=self.client)
|
||||
return order
|
||||
|
||||
def create_order(
|
||||
self, replaces_cert_id: str | None = None, profile: str | None = None
|
||||
self, *, replaces_cert_id: str | None = None, profile: str | None = None
|
||||
) -> Order:
|
||||
"""
|
||||
Create a new order.
|
||||
@@ -133,8 +136,8 @@ class ACMECertificateClient:
|
||||
if self.identifiers is None:
|
||||
raise ModuleFailException("No identifiers have been provided")
|
||||
order = Order.create_with_error_handling(
|
||||
self.client,
|
||||
self.identifiers,
|
||||
client=self.client,
|
||||
identifiers=self.identifiers,
|
||||
error_strategy=self.order_creation_error_strategy,
|
||||
error_max_retries=self.order_creation_max_retries,
|
||||
replaces_cert_id=replaces_cert_id,
|
||||
@@ -142,7 +145,7 @@ class ACMECertificateClient:
|
||||
message_callback=self.module.warn,
|
||||
)
|
||||
self.order_uri = order.url
|
||||
order.load_authorizations(self.client)
|
||||
order.load_authorizations(client=self.client)
|
||||
return order
|
||||
|
||||
def get_challenges_data(
|
||||
@@ -161,7 +164,7 @@ class ACMECertificateClient:
|
||||
# and do not need to be returned
|
||||
if authz.status == "valid":
|
||||
continue
|
||||
challenge_data = authz.get_challenge_data(self.client)
|
||||
challenge_data = authz.get_challenge_data(client=self.client)
|
||||
data.append(
|
||||
dict(
|
||||
identifier=authz.identifier,
|
||||
@@ -209,20 +212,27 @@ class ACMECertificateClient:
|
||||
def call_validate(
|
||||
self,
|
||||
pending_authzs: list[Authorization],
|
||||
*,
|
||||
get_challenge: t.Callable[[Authorization], str],
|
||||
wait: bool = True,
|
||||
) -> list[tuple[Authorization, str, Challenge | None]]:
|
||||
authzs_with_challenges_to_wait_for = []
|
||||
for authz in pending_authzs:
|
||||
challenge_type = get_challenge(authz)
|
||||
authz.call_validate(self.client, challenge_type, wait=wait)
|
||||
authz.call_validate(
|
||||
client=self.client, challenge_type=challenge_type, wait=wait
|
||||
)
|
||||
authzs_with_challenges_to_wait_for.append(
|
||||
(authz, challenge_type, authz.find_challenge(challenge_type))
|
||||
(
|
||||
authz,
|
||||
challenge_type,
|
||||
authz.find_challenge(challenge_type=challenge_type),
|
||||
)
|
||||
)
|
||||
return authzs_with_challenges_to_wait_for
|
||||
|
||||
def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None:
|
||||
wait_for_validation(authzs_to_wait_for, self.client)
|
||||
wait_for_validation(authzs=authzs_to_wait_for, client=self.client)
|
||||
|
||||
def _download_alternate_chains(
|
||||
self, cert: CertificateChain
|
||||
@@ -230,7 +240,7 @@ class ACMECertificateClient:
|
||||
alternate_chains = []
|
||||
for alternate in cert.alternates:
|
||||
try:
|
||||
alt_cert = CertificateChain.download(self.client, alternate)
|
||||
alt_cert = CertificateChain.download(client=self.client, url=alternate)
|
||||
except ModuleFailException as e:
|
||||
self.module.warn(
|
||||
f"Error while downloading alternative certificate {alternate}: {e}"
|
||||
@@ -275,7 +285,7 @@ class ACMECertificateClient:
|
||||
f"Order's crtificate URL {order.certificate_uri!r} is empty!"
|
||||
)
|
||||
|
||||
cert = CertificateChain.download(self.client, order.certificate_uri)
|
||||
cert = CertificateChain.download(client=self.client, url=order.certificate_uri)
|
||||
if cert.cert is None:
|
||||
raise ModuleFailException(
|
||||
f"Certificate at {order.certificate_uri} is empty!"
|
||||
@@ -314,22 +324,26 @@ class ACMECertificateClient:
|
||||
for identifier, authz in order.authorizations.items():
|
||||
if authz.status != "valid":
|
||||
authz.raise_error(
|
||||
f'Status is {authz.status!r} and not "valid"',
|
||||
error_msg=f'Status is {authz.status!r} and not "valid"',
|
||||
module=self.module,
|
||||
)
|
||||
|
||||
order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
|
||||
order.finalize(
|
||||
client=self.client,
|
||||
csr_der=pem_to_der(pem_filename=self.csr, pem_content=self.csr_content),
|
||||
)
|
||||
|
||||
return self.download_certificate(order, download_all_chains=download_all_chains)
|
||||
|
||||
def find_matching_chain(
|
||||
self,
|
||||
*,
|
||||
chains: list[CertificateChain],
|
||||
select_chain_matcher: t.Iterable[ChainMatcher],
|
||||
) -> CertificateChain | None:
|
||||
for criterium_idx, matcher in enumerate(select_chain_matcher):
|
||||
for chain in chains:
|
||||
if matcher.match(chain):
|
||||
if matcher.match(certificate=chain):
|
||||
self.module.debug(
|
||||
f"Found matching chain for criterium {criterium_idx}"
|
||||
)
|
||||
@@ -338,6 +352,7 @@ class ACMECertificateClient:
|
||||
|
||||
def write_cert_chain(
|
||||
self,
|
||||
*,
|
||||
cert: CertificateChain,
|
||||
cert_dest: str | os.PathLike | None = None,
|
||||
fullchain_dest: str | os.PathLike | None = None,
|
||||
@@ -347,18 +362,22 @@ class ACMECertificateClient:
|
||||
if cert.cert is None:
|
||||
raise ValueError("Certificate is not present")
|
||||
|
||||
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
|
||||
if cert_dest and write_file(
|
||||
module=self.module, dest=cert_dest, content=cert.cert.encode("utf8")
|
||||
):
|
||||
changed = True
|
||||
|
||||
if fullchain_dest and write_file(
|
||||
self.module,
|
||||
fullchain_dest,
|
||||
(cert.cert + "\n".join(cert.chain)).encode("utf8"),
|
||||
module=self.module,
|
||||
dest=fullchain_dest,
|
||||
content=(cert.cert + "\n".join(cert.chain)).encode("utf8"),
|
||||
):
|
||||
changed = True
|
||||
|
||||
if chain_dest and write_file(
|
||||
self.module, chain_dest, ("\n".join(cert.chain)).encode("utf8")
|
||||
module=self.module,
|
||||
dest=chain_dest,
|
||||
content=("\n".join(cert.chain)).encode("utf8"),
|
||||
):
|
||||
changed = True
|
||||
|
||||
@@ -374,7 +393,9 @@ class ACMECertificateClient:
|
||||
for authz_uri in order.authorization_uris:
|
||||
authz = None
|
||||
try:
|
||||
authz = Authorization.deactivate_url(self.client, authz_uri)
|
||||
authz = Authorization.deactivate_url(
|
||||
client=self.client, url=authz_uri
|
||||
)
|
||||
except Exception:
|
||||
# ignore errors
|
||||
pass
|
||||
@@ -385,7 +406,7 @@ class ACMECertificateClient:
|
||||
else:
|
||||
for authz in order.authorizations.values():
|
||||
try:
|
||||
authz.deactivate(self.client)
|
||||
authz.deactivate(client=self.client)
|
||||
except Exception:
|
||||
# ignore errors
|
||||
pass
|
||||
@@ -393,3 +414,6 @@ class ACMECertificateClient:
|
||||
self.module.warn(
|
||||
warning=f"Could not deactivate authz object {authz.url}."
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("ACMECertificateClient",)
|
||||
|
||||
@@ -46,7 +46,7 @@ class CertificateChain:
|
||||
|
||||
@classmethod
|
||||
def download(
|
||||
cls: t.Type[_CertificateChain], client: ACMEClient, url: str
|
||||
cls: t.Type[_CertificateChain], *, client: ACMEClient, url: str
|
||||
) -> _CertificateChain:
|
||||
content, info = client.get_request(
|
||||
url,
|
||||
@@ -70,7 +70,10 @@ class CertificateChain:
|
||||
result.chain = certs[1:]
|
||||
|
||||
process_links(
|
||||
info, lambda link, relation: result._process_links(client, link, relation)
|
||||
info=info,
|
||||
callback=lambda link, relation: result._process_links(
|
||||
client=client, link=link, relation=relation
|
||||
),
|
||||
)
|
||||
|
||||
if result.cert is None:
|
||||
@@ -80,7 +83,7 @@ class CertificateChain:
|
||||
|
||||
return result
|
||||
|
||||
def _process_links(self, client: ACMEClient, link: str, relation: str) -> None:
|
||||
def _process_links(self, *, client: ACMEClient, link: str, relation: str) -> None:
|
||||
if relation == "up":
|
||||
# Process link-up headers if there was no chain in reply
|
||||
if not self.chain:
|
||||
@@ -105,7 +108,7 @@ class CertificateChain:
|
||||
|
||||
|
||||
class Criterium:
|
||||
def __init__(self, criterium: dict[str, t.Any], index: int):
|
||||
def __init__(self, *, criterium: dict[str, t.Any], index: int):
|
||||
self.index = index
|
||||
self.test_certificates: t.Literal["first", "last", "all"] = criterium[
|
||||
"test_certificates"
|
||||
@@ -120,7 +123,10 @@ class Criterium:
|
||||
|
||||
class ChainMatcher(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def match(self, certificate: CertificateChain) -> bool:
|
||||
def match(self, *, certificate: CertificateChain) -> bool:
|
||||
"""
|
||||
Check whether a certificate chain (CertificateChain instance) matches.
|
||||
"""
|
||||
|
||||
|
||||
__all__ = ("CertificateChain", "Criterium", "ChainMatcher")
|
||||
|
||||
@@ -34,7 +34,7 @@ if t.TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
def create_key_authorization(client: ACMEClient, token: str) -> str:
|
||||
def create_key_authorization(*, client: ACMEClient, token: str) -> str:
|
||||
"""
|
||||
Returns the key authorization for the given token
|
||||
https://tools.ietf.org/html/rfc8555#section-8.1
|
||||
@@ -46,7 +46,7 @@ def create_key_authorization(client: ACMEClient, token: str) -> str:
|
||||
return f"{token}.{thumbprint}"
|
||||
|
||||
|
||||
def combine_identifier(identifier_type: str, identifier: str) -> str:
|
||||
def combine_identifier(*, identifier_type: str, identifier: str) -> str:
|
||||
return f"{identifier_type}:{identifier}"
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def normalize_combined_identifier(identifier: str) -> str:
|
||||
identifier_type, identifier = split_identifier(identifier)
|
||||
# Normalize DNS names and IPs
|
||||
identifier = identifier.lower()
|
||||
return combine_identifier(identifier_type, identifier)
|
||||
return combine_identifier(identifier_type=identifier_type, identifier=identifier)
|
||||
|
||||
|
||||
def split_identifier(identifier: str) -> tuple[str, str]:
|
||||
@@ -70,7 +70,7 @@ _Challenge = t.TypeVar("_Challenge", bound="Challenge")
|
||||
|
||||
|
||||
class Challenge:
|
||||
def __init__(self, data: dict[str, t.Any], url: str) -> None:
|
||||
def __init__(self, *, data: dict[str, t.Any], url: str) -> None:
|
||||
self.data = data
|
||||
|
||||
self.type: str = data["type"]
|
||||
@@ -81,11 +81,12 @@ class Challenge:
|
||||
@classmethod
|
||||
def from_json(
|
||||
cls: t.Type[_Challenge],
|
||||
*,
|
||||
client: ACMEClient,
|
||||
data: dict[str, t.Any],
|
||||
url: str | None = None,
|
||||
) -> _Challenge:
|
||||
return cls(data, url or data["url"])
|
||||
return cls(data=data, url=url or data["url"])
|
||||
|
||||
def call_validate(self, client: ACMEClient) -> None:
|
||||
challenge_response: dict[str, t.Any] = {}
|
||||
@@ -100,13 +101,13 @@ class Challenge:
|
||||
return self.data.copy()
|
||||
|
||||
def get_validation_data(
|
||||
self, client: ACMEClient, identifier_type: str, identifier: str
|
||||
self, *, client: ACMEClient, identifier_type: str, identifier: str
|
||||
) -> dict[str, t.Any] | None:
|
||||
if self.token is None:
|
||||
return None
|
||||
|
||||
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
|
||||
key_authorization = create_key_authorization(client, token)
|
||||
key_authorization = create_key_authorization(client=client, token=token)
|
||||
|
||||
if self.type == "http-01":
|
||||
# https://tools.ietf.org/html/rfc8555#section-8.3
|
||||
@@ -142,7 +143,9 @@ class Challenge:
|
||||
)
|
||||
return {
|
||||
"resource": resource,
|
||||
"resource_original": combine_identifier(identifier_type, identifier),
|
||||
"resource_original": combine_identifier(
|
||||
identifier_type=identifier_type, identifier=identifier
|
||||
),
|
||||
"resource_value": b_value,
|
||||
}
|
||||
|
||||
@@ -154,7 +157,7 @@ _Authorization = t.TypeVar("_Authorization", bound="Authorization")
|
||||
|
||||
|
||||
class Authorization:
|
||||
def __init__(self, url: str) -> None:
|
||||
def __init__(self, *, url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
self.data: dict[str, t.Any] | None = None
|
||||
@@ -163,14 +166,14 @@ class Authorization:
|
||||
self.identifier_type: str | None = None
|
||||
self.identifier: str | None = None
|
||||
|
||||
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
def _setup(self, *, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
data["uri"] = self.url
|
||||
self.data = data
|
||||
# While 'challenges' is a required field, apparently not every CA cares
|
||||
# (https://github.com/ansible-collections/community.crypto/issues/824)
|
||||
if data.get("challenges"):
|
||||
self.challenges = [
|
||||
Challenge.from_json(client, challenge)
|
||||
Challenge.from_json(client=client, data=challenge)
|
||||
for challenge in data["challenges"]
|
||||
]
|
||||
else:
|
||||
@@ -184,25 +187,27 @@ class Authorization:
|
||||
@classmethod
|
||||
def from_json(
|
||||
cls: t.Type[_Authorization],
|
||||
*,
|
||||
client: ACMEClient,
|
||||
data: dict[str, t.Any],
|
||||
url: str,
|
||||
) -> _Authorization:
|
||||
result = cls(url)
|
||||
result._setup(client, data)
|
||||
result = cls(url=url)
|
||||
result._setup(client=client, data=data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_url(
|
||||
cls: t.Type[_Authorization], client: ACMEClient, url: str
|
||||
cls: t.Type[_Authorization], *, client: ACMEClient, url: str
|
||||
) -> _Authorization:
|
||||
result = cls(url)
|
||||
result.refresh(client)
|
||||
result = cls(url=url)
|
||||
result.refresh(client=client)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls: t.Type[_Authorization],
|
||||
*,
|
||||
client: ACMEClient,
|
||||
identifier_type: str,
|
||||
identifier: str,
|
||||
@@ -220,7 +225,8 @@ class Authorization:
|
||||
}
|
||||
if "newAuthz" not in client.directory.directory:
|
||||
raise ACMEProtocolException(
|
||||
client.module, "ACME endpoint does not support pre-authorization"
|
||||
module=client.module,
|
||||
msg="ACME endpoint does not support pre-authorization",
|
||||
)
|
||||
url = client.directory["newAuthz"]
|
||||
|
||||
@@ -230,26 +236,28 @@ class Authorization:
|
||||
error_msg="Failed to request challenges",
|
||||
expected_status_codes=[200, 201],
|
||||
)
|
||||
return cls.from_json(client, result, info["location"])
|
||||
return cls.from_json(client=client, data=result, url=info["location"])
|
||||
|
||||
@property
|
||||
def combined_identifier(self) -> str:
|
||||
if self.identifier_type is None or self.identifier is None:
|
||||
raise ValueError("Data not present")
|
||||
return combine_identifier(self.identifier_type, self.identifier)
|
||||
return combine_identifier(
|
||||
identifier_type=self.identifier_type, identifier=self.identifier
|
||||
)
|
||||
|
||||
def to_json(self) -> dict[str, t.Any]:
|
||||
if self.data is None:
|
||||
raise ValueError("Data not present")
|
||||
return self.data.copy()
|
||||
|
||||
def refresh(self, client: ACMEClient) -> bool:
|
||||
def refresh(self, *, client: ACMEClient) -> bool:
|
||||
result, dummy = client.get_request(self.url)
|
||||
changed = self.data != result
|
||||
self._setup(client, result)
|
||||
self._setup(client=client, data=result)
|
||||
return changed
|
||||
|
||||
def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]:
|
||||
def get_challenge_data(self, *, client: ACMEClient) -> dict[str, t.Any]:
|
||||
"""
|
||||
Returns a dict with the data for all proposed (and supported) challenges
|
||||
of the given authorization.
|
||||
@@ -259,13 +267,15 @@ class Authorization:
|
||||
data = {}
|
||||
for challenge in self.challenges:
|
||||
validation_data = challenge.get_validation_data(
|
||||
client, self.identifier_type, self.identifier
|
||||
client=client,
|
||||
identifier_type=self.identifier_type,
|
||||
identifier=self.identifier,
|
||||
)
|
||||
if validation_data is not None:
|
||||
data[challenge.type] = validation_data
|
||||
return data
|
||||
|
||||
def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn:
|
||||
def raise_error(self, *, error_msg: str, module: AnsibleModule) -> t.NoReturn:
|
||||
"""
|
||||
Aborts with a specific error for a challenge.
|
||||
"""
|
||||
@@ -283,40 +293,40 @@ class Authorization:
|
||||
msg = f"{msg}: {problem}"
|
||||
error_details.append(msg)
|
||||
raise ACMEProtocolException(
|
||||
module,
|
||||
f"Failed to validate challenge for {self.combined_identifier}: {error_msg}. {'; '.join(error_details)}",
|
||||
module=module,
|
||||
msg=f"Failed to validate challenge for {self.combined_identifier}: {error_msg}. {'; '.join(error_details)}",
|
||||
extras=dict(
|
||||
identifier=self.combined_identifier,
|
||||
authorization=self.data,
|
||||
),
|
||||
)
|
||||
|
||||
def find_challenge(self, challenge_type: str) -> Challenge | None:
|
||||
def find_challenge(self, *, challenge_type: str) -> Challenge | None:
|
||||
for challenge in self.challenges:
|
||||
if challenge_type == challenge.type:
|
||||
return challenge
|
||||
return None
|
||||
|
||||
def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool:
|
||||
def wait_for_validation(self, *, client: ACMEClient) -> bool:
|
||||
while True:
|
||||
self.refresh(client)
|
||||
self.refresh(client=client)
|
||||
if self.status in ["valid", "invalid", "revoked"]:
|
||||
break
|
||||
time.sleep(2)
|
||||
|
||||
if self.status == "invalid":
|
||||
self.raise_error('Status is "invalid"', module=client.module)
|
||||
self.raise_error(error_msg='Status is "invalid"', module=client.module)
|
||||
|
||||
return self.status == "valid"
|
||||
|
||||
def call_validate(
|
||||
self, client: ACMEClient, challenge_type: str, wait: bool = True
|
||||
self, *, client: ACMEClient, challenge_type: str, wait: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the authorization provided in the auth dict. Returns True
|
||||
when the validation was successful and False when it was not.
|
||||
"""
|
||||
challenge = self.find_challenge(challenge_type)
|
||||
challenge = self.find_challenge(challenge_type=challenge_type)
|
||||
if challenge is None:
|
||||
raise ModuleFailException(
|
||||
f'Found no challenge of type "{challenge_type}" for identifier {self.combined_identifier}!'
|
||||
@@ -326,7 +336,7 @@ class Authorization:
|
||||
|
||||
if not wait:
|
||||
return self.status == "valid"
|
||||
return self.wait_for_validation(client, challenge_type)
|
||||
return self.wait_for_validation(client=client)
|
||||
|
||||
def can_deactivate(self) -> bool:
|
||||
"""
|
||||
@@ -336,7 +346,7 @@ class Authorization:
|
||||
"""
|
||||
return self.status in ("valid", "pending")
|
||||
|
||||
def deactivate(self, client: ACMEClient) -> bool | None:
|
||||
def deactivate(self, *, client: ACMEClient) -> bool | None:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
@@ -355,35 +365,50 @@ class Authorization:
|
||||
|
||||
@classmethod
|
||||
def deactivate_url(
|
||||
cls: t.Type[_Authorization], client: ACMEClient, url: str
|
||||
cls: t.Type[_Authorization], *, client: ACMEClient, url: str
|
||||
) -> _Authorization:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
https://tools.ietf.org/html/rfc8555#section-7.5.2
|
||||
"""
|
||||
authz = cls(url)
|
||||
authz = cls(url=url)
|
||||
authz_deactivate = {"status": "deactivated"}
|
||||
result, info = client.send_signed_request(
|
||||
url, authz_deactivate, fail_on_error=True
|
||||
)
|
||||
authz._setup(client, result)
|
||||
authz._setup(client=client, data=result)
|
||||
return authz
|
||||
|
||||
|
||||
def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None:
|
||||
def wait_for_validation(
|
||||
*, authzs: t.Iterable[Authorization], client: ACMEClient
|
||||
) -> None:
|
||||
"""
|
||||
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
|
||||
"""
|
||||
while authzs:
|
||||
authzs_next = []
|
||||
for authz in authzs:
|
||||
authz.refresh(client)
|
||||
authz.refresh(client=client)
|
||||
if authz.status in ["valid", "invalid", "revoked"]:
|
||||
if authz.status != "valid":
|
||||
authz.raise_error('Status is not "valid"', module=client.module)
|
||||
authz.raise_error(
|
||||
error_msg='Status is not "valid"', module=client.module
|
||||
)
|
||||
else:
|
||||
authzs_next.append(authz)
|
||||
if authzs_next:
|
||||
time.sleep(2)
|
||||
authzs = authzs_next
|
||||
|
||||
|
||||
__all__ = (
|
||||
"create_key_authorization",
|
||||
"combine_identifier",
|
||||
"normalize_combined_identifier",
|
||||
"split_identifier",
|
||||
"Challenge",
|
||||
"Authorization",
|
||||
"wait_for_validation",
|
||||
)
|
||||
|
||||
@@ -25,7 +25,9 @@ def format_http_status(status_code: int) -> str:
|
||||
return f"{status_code} {expl}"
|
||||
|
||||
|
||||
def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str:
|
||||
def format_error_problem(
|
||||
problem: dict[str, t.Any], *, subproblem_prefix: str = ""
|
||||
) -> str:
|
||||
error_type = problem.get(
|
||||
"type", "about:blank"
|
||||
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
|
||||
@@ -57,13 +59,14 @@ class ModuleFailException(Exception):
|
||||
self.msg = msg
|
||||
self.module_fail_args = args
|
||||
|
||||
def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn:
|
||||
def do_fail(self, *, module: AnsibleModule, **arguments) -> t.NoReturn:
|
||||
module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments)
|
||||
|
||||
|
||||
class ACMEProtocolException(ModuleFailException):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
module: AnsibleModule,
|
||||
msg: str | None = None,
|
||||
info: dict[str, t.Any] | None = None,
|
||||
@@ -168,3 +171,14 @@ class NetworkException(ModuleFailException):
|
||||
|
||||
class KeyParsingError(ModuleFailException):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = (
|
||||
"format_http_status",
|
||||
"format_error_problem",
|
||||
"ModuleFailException",
|
||||
"ACMEProtocolException",
|
||||
"BackendException",
|
||||
"NetworkException",
|
||||
"KeyParsingError",
|
||||
)
|
||||
|
||||
@@ -33,7 +33,9 @@ def read_file(fn: str | os.PathLike) -> bytes:
|
||||
|
||||
|
||||
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
|
||||
def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool:
|
||||
def write_file(
|
||||
*, module: AnsibleModule, dest: str | os.PathLike, content: bytes
|
||||
) -> bool:
|
||||
"""
|
||||
Write content to destination file dest, only if the content
|
||||
has changed.
|
||||
@@ -95,3 +97,6 @@ def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -
|
||||
)
|
||||
os.remove(tmpsrc)
|
||||
return changed
|
||||
|
||||
|
||||
__all__ = ("read_file", "write_file")
|
||||
|
||||
@@ -34,7 +34,7 @@ _Order = t.TypeVar("_Order", bound="Order")
|
||||
|
||||
|
||||
class Order:
|
||||
def __init__(self, url: str) -> None:
|
||||
def __init__(self, *, url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
self.data: dict[str, t.Any] | None = None
|
||||
@@ -47,7 +47,7 @@ class Order:
|
||||
self.authorization_uris: list[str] = []
|
||||
self.authorizations: dict[str, Authorization] = {}
|
||||
|
||||
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
def _setup(self, *, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
self.data = data
|
||||
|
||||
self.status = data["status"]
|
||||
@@ -62,21 +62,22 @@ class Order:
|
||||
|
||||
@classmethod
|
||||
def from_json(
|
||||
cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str
|
||||
cls: t.Type[_Order], *, client: ACMEClient, data: dict[str, t.Any], url: str
|
||||
) -> _Order:
|
||||
result = cls(url)
|
||||
result._setup(client, data)
|
||||
result = cls(url=url)
|
||||
result._setup(client=client, data=data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order:
|
||||
result = cls(url)
|
||||
result.refresh(client)
|
||||
def from_url(cls: t.Type[_Order], *, client: ACMEClient, url: str) -> _Order:
|
||||
result = cls(url=url)
|
||||
result.refresh(client=client)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls: t.Type[_Order],
|
||||
*,
|
||||
client: ACMEClient,
|
||||
identifiers: list[tuple[str, str]],
|
||||
replaces_cert_id: str | None = None,
|
||||
@@ -105,11 +106,12 @@ class Order:
|
||||
error_msg="Failed to start new order",
|
||||
expected_status_codes=[201],
|
||||
)
|
||||
return cls.from_json(client, result, info["location"])
|
||||
return cls.from_json(client=client, data=result, url=info["location"])
|
||||
|
||||
@classmethod
|
||||
def create_with_error_handling(
|
||||
cls: t.Type[_Order],
|
||||
*,
|
||||
client: ACMEClient,
|
||||
identifiers: list[tuple[str, str]],
|
||||
error_strategy: t.Literal[
|
||||
@@ -136,8 +138,8 @@ class Order:
|
||||
tries += 1
|
||||
try:
|
||||
return cls.create(
|
||||
client,
|
||||
identifiers,
|
||||
client=client,
|
||||
identifiers=identifiers,
|
||||
replaces_cert_id=replaces_cert_id,
|
||||
profile=profile,
|
||||
)
|
||||
@@ -164,34 +166,36 @@ class Order:
|
||||
|
||||
raise
|
||||
|
||||
def refresh(self, client: ACMEClient) -> bool:
|
||||
def refresh(self, *, client: ACMEClient) -> bool:
|
||||
result, dummy = client.get_request(self.url)
|
||||
changed = self.data != result
|
||||
self._setup(client, result)
|
||||
self._setup(client=client, data=result)
|
||||
return changed
|
||||
|
||||
def load_authorizations(self, client: ACMEClient) -> None:
|
||||
def load_authorizations(self, *, client: ACMEClient) -> None:
|
||||
for auth_uri in self.authorization_uris:
|
||||
authz = Authorization.from_url(client, auth_uri)
|
||||
authz = Authorization.from_url(client=client, url=auth_uri)
|
||||
self.authorizations[
|
||||
normalize_combined_identifier(authz.combined_identifier)
|
||||
] = authz
|
||||
|
||||
def wait_for_finalization(self, client: ACMEClient) -> None:
|
||||
def wait_for_finalization(self, *, client: ACMEClient) -> None:
|
||||
while True:
|
||||
self.refresh(client)
|
||||
self.refresh(client=client)
|
||||
if self.status in ["valid", "invalid", "pending", "ready"]:
|
||||
break
|
||||
time.sleep(2)
|
||||
|
||||
if self.status != "valid":
|
||||
raise ACMEProtocolException(
|
||||
client.module,
|
||||
f'Failed to wait for order to complete; got status "{self.status}"',
|
||||
module=client.module,
|
||||
msg=f'Failed to wait for order to complete; got status "{self.status}"',
|
||||
content_json=self.data,
|
||||
)
|
||||
|
||||
def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None:
|
||||
def finalize(
|
||||
self, *, client: ACMEClient, csr_der: bytes, wait: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Create a new certificate based on the csr.
|
||||
Return the certificate object as dict
|
||||
@@ -212,13 +216,16 @@ class Order:
|
||||
# Instead of using the result, we call self.refresh(client) below.
|
||||
|
||||
if wait:
|
||||
self.wait_for_finalization(client)
|
||||
self.wait_for_finalization(client=client)
|
||||
else:
|
||||
self.refresh(client)
|
||||
self.refresh(client=client)
|
||||
if self.status not in ["procesing", "valid", "invalid"]:
|
||||
raise ACMEProtocolException(
|
||||
client.module,
|
||||
f'Failed to finalize order; got status "{self.status}"',
|
||||
module=client.module,
|
||||
msg=f'Failed to finalize order; got status "{self.status}"',
|
||||
info=info,
|
||||
content_json=result,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("Order",)
|
||||
|
||||
@@ -48,7 +48,7 @@ def der_to_pem(der_cert: bytes) -> str:
|
||||
|
||||
|
||||
def pem_to_der(
|
||||
pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
|
||||
*, pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Load PEM file, or use PEM file's content, and convert to DER.
|
||||
@@ -85,7 +85,7 @@ def pem_to_der(
|
||||
|
||||
|
||||
def process_links(
|
||||
info: dict[str, t.Any], callback: t.Callable[[str, str], None]
|
||||
*, info: dict[str, t.Any], callback: t.Callable[[str, str], None]
|
||||
) -> None:
|
||||
"""
|
||||
Process link header, calls callback for every link header with the URL and relation as options.
|
||||
@@ -100,6 +100,7 @@ def process_links(
|
||||
|
||||
def parse_retry_after(
|
||||
value: str,
|
||||
*,
|
||||
relative_with_timezone: bool = True,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime:
|
||||
@@ -112,7 +113,7 @@ def parse_retry_after(
|
||||
try:
|
||||
delta = datetime.timedelta(seconds=int(value))
|
||||
if now is None:
|
||||
now = get_now_datetime(relative_with_timezone)
|
||||
now = get_now_datetime(with_timezone=relative_with_timezone)
|
||||
return now + delta
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -126,6 +127,7 @@ def parse_retry_after(
|
||||
|
||||
|
||||
def compute_cert_id(
|
||||
*,
|
||||
backend: CryptoBackend,
|
||||
cert_info: CertificateInformation | None = None,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
@@ -159,3 +161,13 @@ def compute_cert_id(
|
||||
|
||||
# Compose cert ID
|
||||
return f"{aki}.{serial}"
|
||||
|
||||
|
||||
__all__ = (
|
||||
"nopad_b64",
|
||||
"der_to_pem",
|
||||
"pem_to_der",
|
||||
"process_links",
|
||||
"parse_retry_after",
|
||||
"compute_cert_id",
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ class ArgumentSpec:
|
||||
def __init__(
|
||||
self,
|
||||
argument_spec: dict[str, t.Any] | None = None,
|
||||
*,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
@@ -50,6 +51,7 @@ class ArgumentSpec:
|
||||
|
||||
def update(
|
||||
self,
|
||||
*,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
|
||||
@@ -93,7 +93,12 @@ def serialize_asn1_string_as_der(value: str) -> bytes:
|
||||
|
||||
# We should only do a universal type tag if not IMPLICITLY tagged or the tag class is not universal.
|
||||
if not tag_type or (tag_type == "EXPLICIT" and tag_class != "U"):
|
||||
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
|
||||
b_value = pack_asn1(
|
||||
tag_class=TagClass.universal,
|
||||
constructed=False,
|
||||
tag_number=TagNumber.utf8_string,
|
||||
b_data=b_value,
|
||||
)
|
||||
|
||||
if tag_type:
|
||||
tag_class_enum = {
|
||||
@@ -105,13 +110,22 @@ def serialize_asn1_string_as_der(value: str) -> bytes:
|
||||
|
||||
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
|
||||
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
|
||||
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
|
||||
b_value = pack_asn1(
|
||||
tag_class=tag_class_enum,
|
||||
constructed=constructed,
|
||||
tag_number=int(tag_number),
|
||||
b_data=b_value,
|
||||
)
|
||||
|
||||
return b_value
|
||||
|
||||
|
||||
def pack_asn1(
|
||||
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
|
||||
*,
|
||||
tag_class: TagClass,
|
||||
constructed: bool,
|
||||
tag_number: TagNumber | int,
|
||||
b_data: bytes,
|
||||
) -> bytes:
|
||||
"""Pack the value into an ASN.1 data structure.
|
||||
|
||||
@@ -159,3 +173,6 @@ def pack_asn1(
|
||||
b_asn1_data.extend(length_octets)
|
||||
|
||||
return bytes(b_asn1_data) + b_data
|
||||
|
||||
|
||||
__all__ = ("TagClass", "TagNumber", "serialize_asn1_string_as_der", "pack_asn1")
|
||||
|
||||
@@ -58,3 +58,6 @@ def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
|
||||
buf = openssl_ffi.new("char[]", buf_len)
|
||||
res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
|
||||
return openssl_ffi.buffer(buf, res)[:].decode()
|
||||
|
||||
|
||||
__all__ = ("obj2txt",)
|
||||
|
||||
@@ -33,3 +33,6 @@ for alias, original in [("userID", "userId")]:
|
||||
NORMALIZE_NAMES[alias] = original
|
||||
NORMALIZE_NAMES_SHORT[alias] = NORMALIZE_NAMES_SHORT[original]
|
||||
OID_LOOKUP[alias] = OID_LOOKUP[original]
|
||||
|
||||
|
||||
__all__ = ("OID_LOOKUP", "NORMALIZE_NAMES", "NORMALIZE_NAMES_SHORT")
|
||||
|
||||
@@ -1175,3 +1175,6 @@ OID_MAP = {
|
||||
"2.23.43.1.4.11": ("wap-wsg-idm-ecid-wtls11",),
|
||||
"2.23.43.1.4.12": ("wap-wsg-idm-ecid-wtls12",),
|
||||
}
|
||||
|
||||
|
||||
__all__ = ("OID_MAP",)
|
||||
|
||||
@@ -24,3 +24,6 @@ class OpenSSLObjectError(Exception):
|
||||
|
||||
class OpenSSLBadPassphraseError(OpenSSLObjectError):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ("HAS_CRYPTOGRAPHY", "OpenSSLObjectError", "OpenSSLBadPassphraseError")
|
||||
|
||||
@@ -107,6 +107,7 @@ def cryptography_decode_revoked_certificate(
|
||||
|
||||
def cryptography_dump_revoked(
|
||||
entry: dict[str, t.Any],
|
||||
*,
|
||||
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
) -> dict[str, t.Any]:
|
||||
return {
|
||||
@@ -174,18 +175,34 @@ def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
|
||||
|
||||
|
||||
def set_next_update(
|
||||
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
|
||||
builder: x509.CertificateRevocationListBuilder, *, value: datetime.datetime
|
||||
) -> x509.CertificateRevocationListBuilder:
|
||||
return builder.next_update(value)
|
||||
|
||||
|
||||
def set_last_update(
|
||||
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
|
||||
builder: x509.CertificateRevocationListBuilder, *, value: datetime.datetime
|
||||
) -> x509.CertificateRevocationListBuilder:
|
||||
return builder.last_update(value)
|
||||
|
||||
|
||||
def set_revocation_date(
|
||||
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
|
||||
builder: x509.RevokedCertificateBuilder, *, value: datetime.datetime
|
||||
) -> x509.RevokedCertificateBuilder:
|
||||
return builder.revocation_date(value)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"REVOCATION_REASON_MAP",
|
||||
"REVOCATION_REASON_MAP_INVERSE",
|
||||
"cryptography_decode_revoked_certificate",
|
||||
"cryptography_dump_revoked",
|
||||
"cryptography_get_signature_algorithm_oid_from_crl",
|
||||
"get_next_update",
|
||||
"get_last_update",
|
||||
"get_revocation_date",
|
||||
"get_invalidity_date",
|
||||
"set_next_update",
|
||||
"set_last_update",
|
||||
"set_revocation_date",
|
||||
)
|
||||
|
||||
@@ -263,7 +263,7 @@ def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
|
||||
|
||||
|
||||
def cryptography_oid_to_name(
|
||||
oid: x509.oid.ObjectIdentifier, short: bool = False
|
||||
oid: x509.oid.ObjectIdentifier, *, short: bool = False
|
||||
) -> str:
|
||||
dotted_string = oid.dotted_string
|
||||
names = OID_MAP.get(dotted_string)
|
||||
@@ -315,7 +315,7 @@ def _int_to_byte(value: int) -> bytes:
|
||||
|
||||
|
||||
def _parse_dn_component(
|
||||
name: bytes, sep: bytes = b",", decode_remainder: bool = True
|
||||
name: bytes, *, sep: bytes = b",", decode_remainder: bool = True
|
||||
) -> tuple[x509.NameAttribute, bytes]:
|
||||
m = DN_COMPONENT_START_RE.match(name)
|
||||
if not m:
|
||||
@@ -428,7 +428,9 @@ def _is_ascii(value: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
|
||||
def _adjust_idn(
|
||||
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
if idn_rewrite == "ignore" or not value:
|
||||
return value
|
||||
if idn_rewrite == "idna" and _is_ascii(value):
|
||||
@@ -472,19 +474,19 @@ def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"])
|
||||
|
||||
|
||||
def _adjust_idn_email(
|
||||
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
idx = value.find("@")
|
||||
if idx < 0:
|
||||
return value
|
||||
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
|
||||
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite=idn_rewrite)}"
|
||||
|
||||
|
||||
def _adjust_idn_url(
|
||||
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
value: str, *, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
url = urlparse(value)
|
||||
host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
|
||||
host = _adjust_idn(url.hostname, idn_rewrite=idn_rewrite) if url.hostname else None
|
||||
if url.username is not None and url.password is not None:
|
||||
host = f"{url.username}:{url.password}@{host}"
|
||||
elif url.username is not None:
|
||||
@@ -504,7 +506,7 @@ def _adjust_idn_url(
|
||||
|
||||
|
||||
def cryptography_get_name(
|
||||
name: str, what: str = "Subject Alternative Name"
|
||||
name: str, *, what: str = "Subject Alternative Name"
|
||||
) -> x509.GeneralName:
|
||||
"""
|
||||
Given a name string, returns a cryptography x509.GeneralName object.
|
||||
@@ -512,17 +514,19 @@ def cryptography_get_name(
|
||||
"""
|
||||
try:
|
||||
if name.startswith("DNS:"):
|
||||
return x509.DNSName(_adjust_idn(to_text(name[4:]), "idna"))
|
||||
return x509.DNSName(_adjust_idn(to_text(name[4:]), idn_rewrite="idna"))
|
||||
if name.startswith("IP:"):
|
||||
address = to_text(name[3:])
|
||||
if "/" in address:
|
||||
return x509.IPAddress(ipaddress.ip_network(address))
|
||||
return x509.IPAddress(ipaddress.ip_address(address))
|
||||
if name.startswith("email:"):
|
||||
return x509.RFC822Name(_adjust_idn_email(to_text(name[6:]), "idna"))
|
||||
return x509.RFC822Name(
|
||||
_adjust_idn_email(to_text(name[6:]), idn_rewrite="idna")
|
||||
)
|
||||
if name.startswith("URI:"):
|
||||
return x509.UniformResourceIdentifier(
|
||||
_adjust_idn_url(to_text(name[4:]), "idna")
|
||||
_adjust_idn_url(to_text(name[4:]), idn_rewrite="idna")
|
||||
)
|
||||
if name.startswith("RID:"):
|
||||
m = re.match(r"^([0-9]+(?:\.[0-9]+)*)$", to_text(name[4:]))
|
||||
@@ -585,6 +589,7 @@ def _dn_escape_value(value: str) -> str:
|
||||
|
||||
def cryptography_decode_name(
|
||||
name: x509.GeneralName,
|
||||
*,
|
||||
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
) -> str:
|
||||
"""
|
||||
@@ -596,15 +601,15 @@ def cryptography_decode_name(
|
||||
'idn_rewrite must be one of "ignore", "idna", or "unicode"'
|
||||
)
|
||||
if isinstance(name, x509.DNSName):
|
||||
return f"DNS:{_adjust_idn(name.value, idn_rewrite)}"
|
||||
return f"DNS:{_adjust_idn(name.value, idn_rewrite=idn_rewrite)}"
|
||||
if isinstance(name, x509.IPAddress):
|
||||
if isinstance(name.value, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
return f"IP:{name.value.network_address.compressed}/{name.value.prefixlen}"
|
||||
return f"IP:{name.value.compressed}"
|
||||
if isinstance(name, x509.RFC822Name):
|
||||
return f"email:{_adjust_idn_email(name.value, idn_rewrite)}"
|
||||
return f"email:{_adjust_idn_email(name.value, idn_rewrite=idn_rewrite)}"
|
||||
if isinstance(name, x509.UniformResourceIdentifier):
|
||||
return f"URI:{_adjust_idn_url(name.value, idn_rewrite)}"
|
||||
return f"URI:{_adjust_idn_url(name.value, idn_rewrite=idn_rewrite)}"
|
||||
if isinstance(name, x509.DirectoryName):
|
||||
# According to https://datatracker.ietf.org/doc/html/rfc4514.html#section-2.1 the
|
||||
# list needs to be reversed, and joined by commas
|
||||
@@ -718,7 +723,7 @@ def cryptography_key_needs_digest_for_signing(
|
||||
|
||||
|
||||
def _compare_public_keys(
|
||||
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
|
||||
key1: PublicKeyTypes, key2: PublicKeyTypes, *, clazz: type[PublicKeyTypes]
|
||||
) -> bool | None:
|
||||
a = isinstance(key1, clazz)
|
||||
b = isinstance(key2, clazz)
|
||||
@@ -745,24 +750,24 @@ def cryptography_compare_public_keys(
|
||||
res = _compare_public_keys(
|
||||
key1,
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
|
||||
clazz=cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
key1,
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
|
||||
clazz=cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
|
||||
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
|
||||
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
@@ -773,7 +778,7 @@ def cryptography_compare_public_keys(
|
||||
|
||||
|
||||
def _compare_private_keys(
|
||||
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
|
||||
key1: PrivateKeyTypes, key2: PrivateKeyTypes, *, clazz: type[PrivateKeyTypes]
|
||||
) -> bool | None:
|
||||
a = isinstance(key1, clazz)
|
||||
b = isinstance(key2, clazz)
|
||||
@@ -805,24 +810,26 @@ def cryptography_compare_private_keys(
|
||||
res = _compare_private_keys(
|
||||
key1,
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
|
||||
clazz=cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_private_keys(
|
||||
key1,
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
|
||||
clazz=cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_private_keys(
|
||||
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey
|
||||
key1,
|
||||
key2,
|
||||
clazz=cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_private_keys(
|
||||
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
|
||||
key1, key2, clazz=cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
@@ -832,7 +839,9 @@ def cryptography_compare_private_keys(
|
||||
)
|
||||
|
||||
|
||||
def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
|
||||
def parse_pkcs12(
|
||||
pkcs12_bytes: bytes, *, passphrase: bytes | str | None = None
|
||||
) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
@@ -845,15 +854,17 @@ def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) ->
|
||||
|
||||
# Main code for cryptography 36.0.0 and forward
|
||||
if _load_pkcs12 is not None:
|
||||
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
|
||||
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
|
||||
|
||||
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
|
||||
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
|
||||
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
|
||||
|
||||
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
|
||||
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase=passphrase_bytes)
|
||||
|
||||
|
||||
def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
def _parse_pkcs12_36_0_0(
|
||||
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
|
||||
) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
@@ -871,7 +882,9 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
def _parse_pkcs12_35_0_0(
|
||||
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
|
||||
) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
@@ -918,7 +931,9 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
def _parse_pkcs12_legacy(
|
||||
pkcs12_bytes: bytes, *, passphrase: bytes | None = None
|
||||
) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
@@ -940,6 +955,7 @@ def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -
|
||||
|
||||
|
||||
def cryptography_verify_signature(
|
||||
*,
|
||||
signature: bytes,
|
||||
data: bytes,
|
||||
hash_algorithm: hashes.HashAlgorithm | None,
|
||||
@@ -999,16 +1015,16 @@ def cryptography_verify_signature(
|
||||
|
||||
|
||||
def cryptography_verify_certificate_signature(
|
||||
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
|
||||
*, certificate: x509.Certificate, signer_public_key: PublicKeyTypes
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether the given X509 certificate object was signed by the given public key object.
|
||||
"""
|
||||
return cryptography_verify_signature(
|
||||
certificate.signature,
|
||||
certificate.tbs_certificate_bytes,
|
||||
certificate.signature_hash_algorithm,
|
||||
signer_public_key,
|
||||
signature=certificate.signature,
|
||||
data=certificate.tbs_certificate_bytes,
|
||||
hash_algorithm=certificate.signature_hash_algorithm,
|
||||
signer_public_key=signer_public_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -1074,3 +1090,31 @@ def is_potential_certificate_issuer_public_key(
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"CRYPTOGRAPHY_TIMEZONE",
|
||||
"cryptography_get_extensions_from_cert",
|
||||
"cryptography_get_extensions_from_csr",
|
||||
"cryptography_name_to_oid",
|
||||
"cryptography_oid_to_name",
|
||||
"cryptography_parse_relative_distinguished_name",
|
||||
"cryptography_get_name",
|
||||
"cryptography_decode_name",
|
||||
"cryptography_parse_key_usage_params",
|
||||
"cryptography_get_basic_constraints",
|
||||
"cryptography_key_needs_digest_for_signing",
|
||||
"cryptography_compare_public_keys",
|
||||
"cryptography_compare_private_keys",
|
||||
"parse_pkcs12",
|
||||
"cryptography_verify_signature",
|
||||
"cryptography_verify_certificate_signature",
|
||||
"get_not_valid_after",
|
||||
"get_not_valid_before",
|
||||
"set_not_valid_after",
|
||||
"set_not_valid_before",
|
||||
"is_potential_certificate_private_key",
|
||||
"is_potential_certificate_issuer_private_key",
|
||||
"is_potential_certificate_public_key",
|
||||
"is_potential_certificate_issuer_public_key",
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def binary_exp_mod(f: int, e: int, m: int) -> int:
|
||||
def binary_exp_mod(f: int, e: int, *, m: int) -> int:
|
||||
"""Computes f^e mod m in O(log e) multiplications modulo m."""
|
||||
# Compute len_e = floor(log_2(e))
|
||||
len_e = -1
|
||||
@@ -120,7 +120,7 @@ def count_bits(no: int) -> int:
|
||||
return no.bit_length()
|
||||
|
||||
|
||||
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
|
||||
def convert_int_to_bytes(no: int, *, count: int | None = None) -> bytes:
|
||||
"""
|
||||
Convert the absolute value of an integer to a byte string in network byte order.
|
||||
|
||||
@@ -136,7 +136,7 @@ def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
|
||||
return no.to_bytes(count, byteorder="big")
|
||||
|
||||
|
||||
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
|
||||
def convert_int_to_hex(no: int, *, digits: int | None = None) -> str:
|
||||
"""
|
||||
Convert the absolute value of an integer to a string of hexadecimal digits.
|
||||
|
||||
@@ -156,3 +156,15 @@ def convert_bytes_to_int(data: bytes) -> int:
|
||||
Convert a byte string to an unsigned integer in network byte order.
|
||||
"""
|
||||
return int.from_bytes(data, byteorder="big", signed=False)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"binary_exp_mod",
|
||||
"simple_gcd",
|
||||
"quick_is_not_prime",
|
||||
"count_bytes",
|
||||
"count_bits",
|
||||
"convert_int_to_bytes",
|
||||
"convert_int_to_hex",
|
||||
"convert_bytes_to_int",
|
||||
)
|
||||
|
||||
@@ -63,7 +63,7 @@ class CertificateError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.force: bool = module.params["force"]
|
||||
@@ -104,7 +104,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return {}
|
||||
try:
|
||||
result = get_certificate_info(
|
||||
self.module, data, prefer_one_fingerprint=True
|
||||
module=self.module, content=data, prefer_one_fingerprint=True
|
||||
)
|
||||
result["can_parse_certificate"] = True
|
||||
return result
|
||||
@@ -289,6 +289,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
def needs_regeneration(
|
||||
self,
|
||||
*,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
@@ -330,7 +331,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return True
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
@@ -372,7 +373,7 @@ class CertificateProvider(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
def select_backend(
|
||||
module: AnsibleModule, provider: CertificateProvider
|
||||
*, module: AnsibleModule, provider: CertificateProvider
|
||||
) -> CertificateBackend:
|
||||
provider.validate_module_args(module)
|
||||
|
||||
@@ -415,3 +416,11 @@ def get_certificate_argument_spec() -> ArgumentSpec:
|
||||
["privatekey_path", "privatekey_content"],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"CertificateError",
|
||||
"CertificateBackend",
|
||||
"CertificateProvider",
|
||||
"get_certificate_argument_spec",
|
||||
)
|
||||
|
||||
@@ -29,8 +29,8 @@ if t.TYPE_CHECKING:
|
||||
|
||||
|
||||
class AcmeCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(AcmeCertificateBackend, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(AcmeCertificateBackend, self).__init__(module=module)
|
||||
self.accountkey_path: str = module.params["acme_accountkey_path"]
|
||||
self.challenge_path: str = module.params["acme_challenge_path"]
|
||||
self.use_chain: bool = module.params["acme_chain"]
|
||||
@@ -102,8 +102,10 @@ class AcmeCertificateBackend(CertificateBackend):
|
||||
raise AssertionError("Contract violation: cert_bytes is None")
|
||||
return self.cert_bytes
|
||||
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(AcmeCertificateBackend, self).dump(include_certificate)
|
||||
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(AcmeCertificateBackend, self).dump(
|
||||
include_certificate=include_certificate
|
||||
)
|
||||
result["accountkey"] = self.accountkey_path
|
||||
return result
|
||||
|
||||
@@ -123,7 +125,7 @@ class AcmeCertificateProvider(CertificateProvider):
|
||||
return False
|
||||
|
||||
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
|
||||
return AcmeCertificateBackend(module)
|
||||
return AcmeCertificateBackend(module=module)
|
||||
|
||||
|
||||
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
@@ -138,3 +140,10 @@ def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"AcmeCertificateBackend",
|
||||
"AcmeCertificateProvider",
|
||||
"add_acme_provider_to_argument_spec",
|
||||
)
|
||||
|
||||
@@ -50,12 +50,12 @@ except ImportError:
|
||||
|
||||
|
||||
class EntrustCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(EntrustCertificateBackend, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(EntrustCertificateBackend, self).__init__(module=module)
|
||||
self.trackingId = None
|
||||
self.notAfter = get_relative_time_option(
|
||||
module.params["entrust_not_after"],
|
||||
"entrust_not_after",
|
||||
input_name="entrust_not_after",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
|
||||
@@ -159,6 +159,7 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
|
||||
def needs_regeneration(
|
||||
self,
|
||||
*,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
@@ -229,7 +230,7 @@ class EntrustCertificateProvider(CertificateProvider):
|
||||
return False
|
||||
|
||||
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
|
||||
return EntrustCertificateBackend(module)
|
||||
return EntrustCertificateBackend(module=module)
|
||||
|
||||
|
||||
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
@@ -281,3 +282,10 @@ def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"EntrustCertificateBackend",
|
||||
"EntrustCertificateProvider",
|
||||
"add_entrust_provider_to_argument_spec",
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
|
||||
|
||||
|
||||
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
@@ -158,11 +158,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
def get_info(
|
||||
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
|
||||
self, *, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
|
||||
) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
self.cert = load_certificate(
|
||||
None,
|
||||
content=self.content,
|
||||
der_support_enabled=der_support_enabled,
|
||||
)
|
||||
@@ -204,7 +203,7 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
result["public_key"] = to_native(self._get_public_key_pem())
|
||||
|
||||
public_key_info = get_publickey_info(
|
||||
self.module,
|
||||
module=self.module,
|
||||
key=self._get_public_key_object(),
|
||||
prefer_one_fingerprint=prefer_one_fingerprint,
|
||||
)
|
||||
@@ -249,8 +248,10 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
"""Validate the supplied cert, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
|
||||
def __init__(self, *, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
super(CertificateInfoRetrievalCryptography, self).__init__(
|
||||
module=module, content=content
|
||||
)
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
|
||||
def _get_der_bytes(self) -> bytes:
|
||||
@@ -465,16 +466,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
|
||||
|
||||
def get_certificate_info(
|
||||
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = CertificateInfoRetrievalCryptography(module, content)
|
||||
info = CertificateInfoRetrievalCryptography(module=module, content=content)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule, content: bytes
|
||||
*, module: GeneralAnsibleModule, content: bytes
|
||||
) -> CertificateInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return CertificateInfoRetrievalCryptography(module, content)
|
||||
return CertificateInfoRetrievalCryptography(module=module, content=content)
|
||||
|
||||
|
||||
__all__ = ("CertificateInfoRetrieval", "get_certificate_info", "select_backend")
|
||||
|
||||
@@ -62,8 +62,8 @@ except ImportError:
|
||||
|
||||
|
||||
class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(OwnCACertificateBackendCryptography, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(OwnCACertificateBackendCryptography, self).__init__(module=module)
|
||||
|
||||
self.create_subject_key_identifier: t.Literal[
|
||||
"create_if_not_provided", "always_create", "never_create"
|
||||
@@ -73,12 +73,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
]
|
||||
self.notBefore = get_relative_time_option(
|
||||
module.params["ownca_not_before"],
|
||||
"ownca_not_before",
|
||||
input_name="ownca_not_before",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.notAfter = get_relative_time_option(
|
||||
module.params["ownca_not_after"],
|
||||
"ownca_not_after",
|
||||
input_name="ownca_not_after",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.digest = select_message_digest(module.params["ownca_digest"])
|
||||
@@ -220,6 +220,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
def needs_regeneration(
|
||||
self,
|
||||
*,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
@@ -233,7 +234,8 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
# Check whether certificate is signed by CA certificate
|
||||
if not cryptography_verify_certificate_signature(
|
||||
self.existing_certificate, self.ca_cert.public_key()
|
||||
certificate=self.existing_certificate,
|
||||
signer_public_key=self.ca_cert.public_key(),
|
||||
):
|
||||
return True
|
||||
|
||||
@@ -270,9 +272,9 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(OwnCACertificateBackendCryptography, self).dump(
|
||||
include_certificate
|
||||
include_certificate=include_certificate
|
||||
)
|
||||
result.update(
|
||||
{
|
||||
@@ -339,7 +341,7 @@ class OwnCACertificateProvider(CertificateProvider):
|
||||
def create_backend(
|
||||
self, module: AnsibleModule
|
||||
) -> OwnCACertificateBackendCryptography:
|
||||
return OwnCACertificateBackendCryptography(module)
|
||||
return OwnCACertificateBackendCryptography(module=module)
|
||||
|
||||
|
||||
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
@@ -369,3 +371,10 @@ def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
["ownca_privatekey_path", "ownca_privatekey_content"],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"OwnCACertificateBackendCryptography",
|
||||
"OwnCACertificateProvider",
|
||||
"add_ownca_provider_to_argument_spec",
|
||||
)
|
||||
|
||||
@@ -58,20 +58,20 @@ except ImportError:
|
||||
class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
privatekey: CertificateIssuerPrivateKeyTypes
|
||||
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(SelfSignedCertificateBackendCryptography, self).__init__(module=module)
|
||||
|
||||
self.create_subject_key_identifier: t.Literal[
|
||||
"create_if_not_provided", "always_create", "never_create"
|
||||
] = module.params["selfsigned_create_subject_key_identifier"]
|
||||
self.notBefore = get_relative_time_option(
|
||||
module.params["selfsigned_not_before"],
|
||||
"selfsigned_not_before",
|
||||
input_name="selfsigned_not_before",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.notAfter = get_relative_time_option(
|
||||
module.params["selfsigned_not_after"],
|
||||
"selfsigned_not_after",
|
||||
input_name="selfsigned_not_after",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.digest = select_message_digest(module.params["selfsigned_digest"])
|
||||
@@ -162,6 +162,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
def needs_regeneration(
|
||||
self,
|
||||
*,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
@@ -177,15 +178,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
# Check whether certificate is signed by private key
|
||||
if not cryptography_verify_certificate_signature(
|
||||
self.existing_certificate, self.privatekey.public_key()
|
||||
certificate=self.existing_certificate,
|
||||
signer_public_key=self.privatekey.public_key(),
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
def dump(self, *, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(SelfSignedCertificateBackendCryptography, self).dump(
|
||||
include_certificate
|
||||
include_certificate=include_certificate
|
||||
)
|
||||
|
||||
if self.module.check_mode:
|
||||
@@ -239,7 +241,7 @@ class SelfSignedCertificateProvider(CertificateProvider):
|
||||
def create_backend(
|
||||
self, module: AnsibleModule
|
||||
) -> SelfSignedCertificateBackendCryptography:
|
||||
return SelfSignedCertificateBackendCryptography(module)
|
||||
return SelfSignedCertificateBackendCryptography(module=module)
|
||||
|
||||
|
||||
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
@@ -261,3 +263,10 @@ def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> Non
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"SelfSignedCertificateBackendCryptography",
|
||||
"SelfSignedCertificateProvider",
|
||||
"add_selfsigned_provider_to_argument_spec",
|
||||
)
|
||||
|
||||
@@ -55,6 +55,7 @@ except ImportError:
|
||||
class CRLInfoRetrieval:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
list_revoked_certificates: bool = True,
|
||||
@@ -113,12 +114,20 @@ class CRLInfoRetrieval:
|
||||
|
||||
|
||||
def get_crl_info(
|
||||
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
list_revoked_certificates: bool = True,
|
||||
) -> dict[str, t.Any]:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
info = CRLInfoRetrieval(
|
||||
module, content, list_revoked_certificates=list_revoked_certificates
|
||||
module=module,
|
||||
content=content,
|
||||
list_revoked_certificates=list_revoked_certificates,
|
||||
)
|
||||
return info.get_info()
|
||||
|
||||
|
||||
__all__ = ("CRLInfoRetrieval", "get_crl_info")
|
||||
|
||||
@@ -87,7 +87,7 @@ class CertificateSigningRequestError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.digest: str = module.params["digest"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
@@ -158,7 +158,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
try:
|
||||
if module.params["subject"]:
|
||||
self.subject = self.subject + parse_name_field(
|
||||
module.params["subject"], "subject"
|
||||
module.params["subject"], name_field_name="subject"
|
||||
)
|
||||
if module.params["subject_ordered"]:
|
||||
if self.subject:
|
||||
@@ -166,7 +166,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
"subject_ordered cannot be combined with any other subject field"
|
||||
)
|
||||
self.subject = parse_ordered_name_field(
|
||||
module.params["subject_ordered"], "subject_ordered"
|
||||
module.params["subject_ordered"], name_field_name="subject_ordered"
|
||||
)
|
||||
self.ordered_subject = True
|
||||
except ValueError as exc:
|
||||
@@ -205,16 +205,16 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
|
||||
self.existing_csr_bytes: bytes | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
self.diff_before = self._get_info(data=None)
|
||||
self.diff_after = self._get_info(data=None)
|
||||
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return {}
|
||||
try:
|
||||
result = get_csr_info(
|
||||
self.module,
|
||||
data,
|
||||
module=self.module,
|
||||
content=data,
|
||||
validate_signature=False,
|
||||
prefer_one_fingerprint=True,
|
||||
)
|
||||
@@ -231,10 +231,12 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
def get_csr_data(self) -> bytes:
|
||||
"""Return bytes for self.csr."""
|
||||
|
||||
def set_existing(self, csr_bytes: bytes | None) -> None:
|
||||
def set_existing(self, *, csr_bytes: bytes | None) -> None:
|
||||
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
|
||||
self.existing_csr_bytes = csr_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
|
||||
self.diff_after = self.diff_before = self._get_info(
|
||||
data=self.existing_csr_bytes
|
||||
)
|
||||
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing CSR is/has been there."""
|
||||
@@ -263,7 +265,6 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
return True
|
||||
try:
|
||||
self.existing_csr = load_certificate_request(
|
||||
None,
|
||||
content=self.existing_csr_bytes,
|
||||
)
|
||||
except Exception:
|
||||
@@ -271,7 +272,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
self._ensure_private_key_loaded()
|
||||
return not self._check_csr()
|
||||
|
||||
def dump(self, include_csr: bool) -> dict[str, t.Any]:
|
||||
def dump(self, *, include_csr: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
@@ -288,7 +289,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
csr_bytes = self.existing_csr_bytes
|
||||
if self.csr is not None:
|
||||
csr_bytes = self.get_csr_data()
|
||||
self.diff_after = self._get_info(csr_bytes)
|
||||
self.diff_after = self._get_info(data=csr_bytes)
|
||||
if include_csr:
|
||||
# Store result
|
||||
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
|
||||
@@ -301,7 +302,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
def parse_crl_distribution_points(
|
||||
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
|
||||
*, module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
|
||||
) -> list[cryptography.x509.DistributionPoint]:
|
||||
result = []
|
||||
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
|
||||
@@ -314,7 +315,7 @@ def parse_crl_distribution_points(
|
||||
if not parse_crl_distribution_point["full_name"]:
|
||||
raise OpenSSLObjectError("full_name must not be empty")
|
||||
full_name = [
|
||||
cryptography_get_name(name, "full name")
|
||||
cryptography_get_name(name, what="full name")
|
||||
for name in parse_crl_distribution_point["full_name"]
|
||||
]
|
||||
if parse_crl_distribution_point["relative_name"] is not None:
|
||||
@@ -327,7 +328,7 @@ def parse_crl_distribution_points(
|
||||
if not parse_crl_distribution_point["crl_issuer"]:
|
||||
raise OpenSSLObjectError("crl_issuer must not be empty")
|
||||
crl_issuer = [
|
||||
cryptography_get_name(name, "CRL issuer")
|
||||
cryptography_get_name(name, what="CRL issuer")
|
||||
for name in parse_crl_distribution_point["crl_issuer"]
|
||||
]
|
||||
if parse_crl_distribution_point["reasons"] is not None:
|
||||
@@ -352,8 +353,10 @@ def parse_crl_distribution_points(
|
||||
|
||||
# Implementation with using cryptography
|
||||
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(CertificateSigningRequestCryptographyBackend, self).__init__(
|
||||
module=module
|
||||
)
|
||||
if self.version != 1:
|
||||
module.warn(
|
||||
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
|
||||
@@ -364,7 +367,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
]
|
||||
if crl_distribution_points:
|
||||
self.crl_distribution_points = parse_crl_distribution_points(
|
||||
module, crl_distribution_points
|
||||
module=module, crl_distribution_points=crl_distribution_points
|
||||
)
|
||||
|
||||
def generate_csr(self) -> None:
|
||||
@@ -431,12 +434,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
csr = csr.add_extension(
|
||||
cryptography.x509.NameConstraints(
|
||||
[
|
||||
cryptography_get_name(name, "name constraints permitted")
|
||||
cryptography_get_name(
|
||||
name, what="name constraints permitted"
|
||||
)
|
||||
for name in self.name_constraints_permitted
|
||||
]
|
||||
or None,
|
||||
[
|
||||
cryptography_get_name(name, "name constraints excluded")
|
||||
cryptography_get_name(
|
||||
name, what="name constraints excluded"
|
||||
)
|
||||
for name in self.name_constraints_excluded
|
||||
]
|
||||
or None,
|
||||
@@ -473,7 +480,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
issuers = None
|
||||
if self.authority_cert_issuer is not None:
|
||||
issuers = [
|
||||
cryptography_get_name(n, "authority cert issuer")
|
||||
cryptography_get_name(n, what="authority cert issuer")
|
||||
for n in self.authority_cert_issuer
|
||||
]
|
||||
csr = csr.add_extension(
|
||||
@@ -679,11 +686,15 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else []
|
||||
)
|
||||
nc_perm = [
|
||||
to_text(cryptography_get_name(altname, "name constraints permitted"))
|
||||
to_text(
|
||||
cryptography_get_name(altname, what="name constraints permitted")
|
||||
)
|
||||
for altname in self.name_constraints_permitted
|
||||
]
|
||||
nc_excl = [
|
||||
to_text(cryptography_get_name(altname, "name constraints excluded"))
|
||||
to_text(
|
||||
cryptography_get_name(altname, what="name constraints excluded")
|
||||
)
|
||||
for altname in self.name_constraints_excluded
|
||||
]
|
||||
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(
|
||||
@@ -731,7 +742,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
csr_aci = None
|
||||
if self.authority_cert_issuer is not None:
|
||||
aci = [
|
||||
to_text(cryptography_get_name(n, "authority cert issuer"))
|
||||
to_text(cryptography_get_name(n, what="authority cert issuer"))
|
||||
for n in self.authority_cert_issuer
|
||||
]
|
||||
if ext.value.authority_cert_issuer is not None:
|
||||
@@ -798,7 +809,7 @@ def select_backend(
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return CertificateSigningRequestCryptographyBackend(module)
|
||||
return CertificateSigningRequestCryptographyBackend(module=module)
|
||||
|
||||
|
||||
def get_csr_argument_spec() -> ArgumentSpec:
|
||||
@@ -903,3 +914,11 @@ def get_csr_argument_spec() -> ArgumentSpec:
|
||||
["privatekey_path", "privatekey_content"],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"CertificateSigningRequestError",
|
||||
"CertificateSigningRequestBackend",
|
||||
"select_backend",
|
||||
"get_csr_argument_spec",
|
||||
)
|
||||
|
||||
@@ -62,7 +62,7 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
|
||||
|
||||
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(
|
||||
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
) -> None:
|
||||
self.module = module
|
||||
self.content = content
|
||||
@@ -122,10 +122,9 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def _is_signature_valid(self) -> bool:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
self.csr = load_certificate_request(
|
||||
None,
|
||||
content=self.content,
|
||||
)
|
||||
|
||||
@@ -156,7 +155,7 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
result["public_key"] = to_native(self._get_public_key_pem())
|
||||
|
||||
public_key_info = get_publickey_info(
|
||||
self.module,
|
||||
module=self.module,
|
||||
key=self._get_public_key_object(),
|
||||
prefer_one_fingerprint=prefer_one_fingerprint,
|
||||
)
|
||||
@@ -196,10 +195,10 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
"""Validate the supplied CSR, using the cryptography backend"""
|
||||
|
||||
def __init__(
|
||||
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
self, *, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
) -> None:
|
||||
super(CSRInfoRetrievalCryptography, self).__init__(
|
||||
module, content, validate_signature
|
||||
module=module, content=content, validate_signature=validate_signature
|
||||
)
|
||||
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
|
||||
"name_encoding", "ignore"
|
||||
@@ -372,23 +371,27 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
|
||||
|
||||
def get_csr_info(
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
validate_signature: bool = True,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = CSRInfoRetrievalCryptography(
|
||||
module, content, validate_signature=validate_signature
|
||||
module=module, content=content, validate_signature=validate_signature
|
||||
)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
|
||||
*, module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
|
||||
) -> CSRInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return CSRInfoRetrievalCryptography(
|
||||
module, content, validate_signature=validate_signature
|
||||
module=module, content=content, validate_signature=validate_signature
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("CSRInfoRetrieval", "get_csr_info", "select_backend")
|
||||
|
||||
@@ -80,7 +80,7 @@ class PrivateKeyError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: GeneralAnsibleModule) -> None:
|
||||
def __init__(self, *, module: GeneralAnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.type: t.Literal[
|
||||
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
|
||||
@@ -104,18 +104,18 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
self.existing_private_key: PrivateKeyTypes | None = None
|
||||
self.existing_private_key_bytes: bytes | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
self.diff_before = self._get_info(data=None)
|
||||
self.diff_after = self._get_info(data=None)
|
||||
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
def _get_info(self, *, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return {}
|
||||
result: dict[str, t.Any] = {"can_parse_key": False}
|
||||
try:
|
||||
result.update(
|
||||
get_privatekey_info(
|
||||
self.module,
|
||||
data,
|
||||
module=self.module,
|
||||
content=data,
|
||||
passphrase=self.passphrase,
|
||||
return_private_key_data=False,
|
||||
prefer_one_fingerprint=True,
|
||||
@@ -148,11 +148,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.private_key."""
|
||||
|
||||
def set_existing(self, privatekey_bytes: bytes | None) -> None:
|
||||
def set_existing(self, *, privatekey_bytes: bytes | None) -> None:
|
||||
"""Set existing private key bytes. None indicates that the key does not exist."""
|
||||
self.existing_private_key_bytes = privatekey_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(
|
||||
self.existing_private_key_bytes
|
||||
data=self.existing_private_key_bytes
|
||||
)
|
||||
|
||||
def has_existing(self) -> bool:
|
||||
@@ -235,7 +235,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
return get_fingerprint_of_privatekey(self.existing_private_key)
|
||||
return None
|
||||
|
||||
def dump(self, include_key: bool) -> dict[str, t.Any]:
|
||||
def dump(self, *, include_key: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
if not self.private_key:
|
||||
@@ -255,7 +255,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
pk_bytes = self.existing_private_key_bytes
|
||||
if self.private_key is not None:
|
||||
pk_bytes = self.get_private_key_data()
|
||||
self.diff_after = self._get_info(pk_bytes)
|
||||
self.diff_after = self._get_info(data=pk_bytes)
|
||||
if include_key:
|
||||
# Store result
|
||||
if pk_bytes:
|
||||
@@ -276,6 +276,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
class _Curve:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
ectype: str,
|
||||
deprecated: bool,
|
||||
@@ -285,7 +286,7 @@ class _Curve:
|
||||
self.deprecated = deprecated
|
||||
|
||||
def _get_ec_class(
|
||||
self, module: GeneralAnsibleModule
|
||||
self, *, module: GeneralAnsibleModule
|
||||
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
|
||||
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
|
||||
if ecclass is None:
|
||||
@@ -295,17 +296,18 @@ class _Curve:
|
||||
return ecclass
|
||||
|
||||
def create(
|
||||
self, size: int, module: GeneralAnsibleModule
|
||||
self, *, size: int, module: GeneralAnsibleModule
|
||||
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
|
||||
ecclass = self._get_ec_class(module)
|
||||
ecclass = self._get_ec_class(module=module)
|
||||
return ecclass()
|
||||
|
||||
def verify(
|
||||
self,
|
||||
*,
|
||||
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
|
||||
module: GeneralAnsibleModule,
|
||||
) -> bool:
|
||||
ecclass = self._get_ec_class(module)
|
||||
ecclass = self._get_ec_class(module=module)
|
||||
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
|
||||
|
||||
|
||||
@@ -316,6 +318,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
self,
|
||||
name: str,
|
||||
ectype: str,
|
||||
*,
|
||||
deprecated: bool = False,
|
||||
) -> None:
|
||||
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
|
||||
@@ -575,7 +578,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
if self.curve not in self.curves:
|
||||
return False
|
||||
return self.curves[self.curve].verify(
|
||||
self.existing_private_key, module=self.module
|
||||
privatekey=self.existing_private_key, module=self.module
|
||||
)
|
||||
|
||||
return False
|
||||
@@ -596,7 +599,7 @@ def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyCryptographyBackend(module)
|
||||
return PrivateKeyCryptographyBackend(module=module)
|
||||
|
||||
|
||||
def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
@@ -661,3 +664,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
("type", "ECC", ["curve"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"PrivateKeyError",
|
||||
"PrivateKeyBackend",
|
||||
"select_backend",
|
||||
"get_privatekey_argument_spec",
|
||||
)
|
||||
|
||||
@@ -69,7 +69,7 @@ class PrivateKeyError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.src_path: str | None = module.params["src_path"]
|
||||
self.src_content: str | None = module.params["src_content"]
|
||||
@@ -79,7 +79,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
self.src_private_key: PrivateKeyTypes | None = None
|
||||
if self.src_path is not None:
|
||||
self.src_private_key_bytes = load_file(self.src_path, module)
|
||||
self.src_private_key_bytes = load_file(path=self.src_path, module=module)
|
||||
else:
|
||||
if self.src_content is None:
|
||||
raise AssertionError("src_content is None")
|
||||
@@ -93,7 +93,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
"""Return bytes for self.src_private_key in output format."""
|
||||
pass
|
||||
|
||||
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
|
||||
def set_existing_destination(self, *, privatekey_bytes: bytes | None) -> None:
|
||||
"""Set existing private key bytes. None indicates that the key does not exist."""
|
||||
self.dest_private_key_bytes = privatekey_bytes
|
||||
|
||||
@@ -104,6 +104,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def _load_private_key(
|
||||
self,
|
||||
*,
|
||||
data: bytes,
|
||||
passphrase: str | None,
|
||||
current_hint: PrivateKeyTypes | None = None,
|
||||
@@ -113,7 +114,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
def needs_conversion(self) -> bool:
|
||||
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
|
||||
dummy, self.src_private_key = self._load_private_key(
|
||||
self.src_private_key_bytes, self.src_passphrase
|
||||
data=self.src_private_key_bytes, passphrase=self.src_passphrase
|
||||
)
|
||||
|
||||
if not self.has_existing_destination():
|
||||
@@ -122,8 +123,8 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
try:
|
||||
format, self.dest_private_key = self._load_private_key(
|
||||
self.dest_private_key_bytes,
|
||||
self.dest_passphrase,
|
||||
data=self.dest_private_key_bytes,
|
||||
passphrase=self.dest_passphrase,
|
||||
current_hint=self.src_private_key,
|
||||
)
|
||||
except Exception:
|
||||
@@ -140,7 +141,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
|
||||
|
||||
def get_private_key_data(self) -> bytes:
|
||||
@@ -201,6 +202,7 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
|
||||
def _load_private_key(
|
||||
self,
|
||||
*,
|
||||
data: bytes,
|
||||
passphrase: str | None,
|
||||
current_hint: PrivateKeyTypes | None = None,
|
||||
@@ -276,7 +278,7 @@ def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyConvertCryptographyBackend(module)
|
||||
return PrivateKeyConvertCryptographyBackend(module=module)
|
||||
|
||||
|
||||
def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
@@ -295,3 +297,11 @@ def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
["src_path", "src_content"],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"PrivateKeyError",
|
||||
"PrivateKeyConvertBackend",
|
||||
"select_backend",
|
||||
"get_privatekey_argument_spec",
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ SIGNATURE_TEST_DATA = b"1234"
|
||||
|
||||
|
||||
def _get_cryptography_private_key_info(
|
||||
key: PrivateKeyTypes, need_private_key_data: bool = False
|
||||
key: PrivateKeyTypes, *, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
|
||||
key_private_data: dict[str, t.Any] = {}
|
||||
@@ -84,7 +84,7 @@ def _get_cryptography_private_key_info(
|
||||
|
||||
|
||||
def _check_dsa_consistency(
|
||||
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
*, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
# Get parameters
|
||||
p: int | None = key_public_data.get("p")
|
||||
@@ -112,10 +112,10 @@ def _check_dsa_consistency(
|
||||
if (p - 1) % q != 0:
|
||||
return False
|
||||
# Check that g**q mod p == 1
|
||||
if binary_exp_mod(g, q, p) != 1:
|
||||
if binary_exp_mod(g, q, m=p) != 1:
|
||||
return False
|
||||
# Check whether g**x mod p == y
|
||||
if binary_exp_mod(g, x, p) != y:
|
||||
if binary_exp_mod(g, x, m=p) != y:
|
||||
return False
|
||||
# Check (quickly) whether p or q are not primes
|
||||
if quick_is_not_prime(q) or quick_is_not_prime(p):
|
||||
@@ -125,6 +125,7 @@ def _check_dsa_consistency(
|
||||
|
||||
def _is_cryptography_key_consistent(
|
||||
key: PrivateKeyTypes,
|
||||
*,
|
||||
key_public_data: dict[str, t.Any],
|
||||
key_private_data: dict[str, t.Any],
|
||||
warn_func: t.Callable[[str], None] | None = None,
|
||||
@@ -135,7 +136,9 @@ def _is_cryptography_key_consistent(
|
||||
if backend is not None:
|
||||
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
|
||||
result = _check_dsa_consistency(key_public_data, key_private_data)
|
||||
result = _check_dsa_consistency(
|
||||
key_public_data=key_public_data, key_private_data=key_private_data
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
signature = key.sign(
|
||||
@@ -191,14 +194,14 @@ def _is_cryptography_key_consistent(
|
||||
|
||||
|
||||
class PrivateKeyConsistencyError(OpenSSLObjectError):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
|
||||
super(PrivateKeyConsistencyError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
|
||||
|
||||
class PrivateKeyParseError(OpenSSLObjectError):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
|
||||
super(PrivateKeyParseError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
@@ -207,6 +210,7 @@ class PrivateKeyParseError(OpenSSLObjectError):
|
||||
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
@@ -220,22 +224,22 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
self.check_consistency = check_consistency
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
def _get_public_key(self, *, binary: bool) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_info(
|
||||
self, need_private_key_data: bool = False
|
||||
self, *, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_key_consistent(
|
||||
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {
|
||||
"can_parse_key": False,
|
||||
"key_is_consistent": None,
|
||||
@@ -253,7 +257,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
)
|
||||
result["can_parse_key"] = True
|
||||
except OpenSSLObjectError as exc:
|
||||
raise PrivateKeyParseError(str(exc), result)
|
||||
raise PrivateKeyParseError(str(exc), result=result)
|
||||
|
||||
result["public_key"] = to_native(self._get_public_key(binary=False))
|
||||
pk = self._get_public_key(binary=True)
|
||||
@@ -273,7 +277,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
|
||||
if self.check_consistency:
|
||||
result["key_is_consistent"] = self._is_key_consistent(
|
||||
key_public_data, key_private_data
|
||||
key_public_data=key_public_data, key_private_data=key_private_data
|
||||
)
|
||||
if result["key_is_consistent"] is False:
|
||||
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
|
||||
@@ -281,40 +285,46 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
"Private key is not consistent! (See "
|
||||
"https://blog.hboeck.de/archives/888-How-I-tricked-Symantec-with-a-Fake-Private-Key.html)"
|
||||
)
|
||||
raise PrivateKeyConsistencyError(msg, result)
|
||||
raise PrivateKeyConsistencyError(msg, result=result)
|
||||
return result
|
||||
|
||||
|
||||
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
|
||||
"""Validate the supplied private key, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
|
||||
def __init__(
|
||||
self, *, module: GeneralAnsibleModule, content: bytes, **kwargs
|
||||
) -> None:
|
||||
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
|
||||
module, content, **kwargs
|
||||
module=module, content=content, **kwargs
|
||||
)
|
||||
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
def _get_public_key(self, *, binary: bool) -> bytes:
|
||||
return self.key.public_key().public_bytes(
|
||||
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_key_info(
|
||||
self, need_private_key_data: bool = False
|
||||
self, *, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return _get_cryptography_private_key_info(
|
||||
self.key, need_private_key_data=need_private_key_data
|
||||
)
|
||||
|
||||
def _is_key_consistent(
|
||||
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
return _is_cryptography_key_consistent(
|
||||
self.key, key_public_data, key_private_data, warn_func=self.module.warn
|
||||
self.key,
|
||||
key_public_data=key_public_data,
|
||||
key_private_data=key_private_data,
|
||||
warn_func=self.module.warn,
|
||||
)
|
||||
|
||||
|
||||
def get_privatekey_info(
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
@@ -322,8 +332,8 @@ def get_privatekey_info(
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = PrivateKeyInfoRetrievalCryptography(
|
||||
module,
|
||||
content,
|
||||
module=module,
|
||||
content=content,
|
||||
passphrase=passphrase,
|
||||
return_private_key_data=return_private_key_data,
|
||||
)
|
||||
@@ -331,6 +341,7 @@ def get_privatekey_info(
|
||||
|
||||
|
||||
def select_backend(
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
@@ -341,9 +352,18 @@ def select_backend(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyInfoRetrievalCryptography(
|
||||
module,
|
||||
content,
|
||||
module=module,
|
||||
content=content,
|
||||
passphrase=passphrase,
|
||||
return_private_key_data=return_private_key_data,
|
||||
check_consistency=check_consistency,
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"PrivateKeyConsistencyError",
|
||||
"PrivateKeyParseError",
|
||||
"PrivateKeyInfoRetrieval",
|
||||
"get_privatekey_info",
|
||||
"select_backend",
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ def _get_cryptography_public_key_info(
|
||||
|
||||
|
||||
class PublicKeyParseError(OpenSSLObjectError):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
|
||||
super(PublicKeyParseError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
@@ -108,6 +108,7 @@ class PublicKeyParseError(OpenSSLObjectError):
|
||||
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
@@ -125,13 +126,13 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
if self.key is None:
|
||||
try:
|
||||
self.key = load_publickey(content=self.content)
|
||||
except OpenSSLObjectError as e:
|
||||
raise PublicKeyParseError(str(e), {})
|
||||
raise PublicKeyParseError(str(e), result={})
|
||||
|
||||
pk = self._get_public_key(binary=True)
|
||||
result["fingerprints"] = (
|
||||
@@ -151,12 +152,13 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> None:
|
||||
super(PublicKeyInfoRetrievalCryptography, self).__init__(
|
||||
module, content=content, key=key
|
||||
module=module, content=content, key=key
|
||||
)
|
||||
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
@@ -174,16 +176,18 @@ class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
|
||||
|
||||
|
||||
def get_publickey_info(
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
|
||||
info = PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(
|
||||
*,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
@@ -191,4 +195,12 @@ def select_backend(
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
|
||||
return PublicKeyInfoRetrievalCryptography(module=module, content=content, key=key)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"PublicKeyParseError",
|
||||
"PublicKeyInfoRetrieval",
|
||||
"get_publickey_info",
|
||||
"select_backend",
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
|
||||
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
|
||||
|
||||
|
||||
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
|
||||
def identify_pem_format(content: bytes, *, encoding: str = "utf-8") -> bool:
|
||||
"""Given the contents of a binary file, tests whether this could be a PEM file."""
|
||||
try:
|
||||
first_pem = extract_first_pem(content.decode(encoding))
|
||||
@@ -36,7 +36,7 @@ def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
|
||||
|
||||
|
||||
def identify_private_key_format(
|
||||
content: bytes, encoding: str = "utf-8"
|
||||
content: bytes, *, encoding: str = "utf-8"
|
||||
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
|
||||
"""Given the contents of a private key file, identifies its format."""
|
||||
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
|
||||
@@ -66,7 +66,7 @@ def identify_private_key_format(
|
||||
return "raw"
|
||||
|
||||
|
||||
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
|
||||
def split_pem_list(text: str, *, keep_inbetween: bool = False) -> list[str]:
|
||||
"""
|
||||
Split concatenated PEM objects into a list of strings, where each is one PEM object.
|
||||
"""
|
||||
@@ -94,7 +94,7 @@ def extract_first_pem(text: str) -> str | None:
|
||||
return all_pems[0]
|
||||
|
||||
|
||||
def _extract_type(line: str, start: str = PEM_START) -> str | None:
|
||||
def _extract_type(line: str, *, start: str = PEM_START) -> str | None:
|
||||
if not line.startswith(start):
|
||||
return None
|
||||
if not line.endswith(PEM_END):
|
||||
@@ -102,7 +102,7 @@ def _extract_type(line: str, start: str = PEM_START) -> str | None:
|
||||
return line[len(start) : -len(PEM_END)]
|
||||
|
||||
|
||||
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
|
||||
def extract_pem(content: str, *, strict: bool = False) -> tuple[str, str]:
|
||||
lines = content.splitlines()
|
||||
if len(lines) < 3:
|
||||
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
|
||||
@@ -125,3 +125,17 @@ def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
|
||||
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
|
||||
)
|
||||
return header_type, "".join(lines[1:-1])
|
||||
|
||||
|
||||
__all__ = (
|
||||
"PEM_START",
|
||||
"PEM_END_START",
|
||||
"PEM_END",
|
||||
"PKCS8_PRIVATEKEY_NAMES",
|
||||
"PKCS1_PRIVATEKEY_SUFFIX",
|
||||
"identify_pem_format",
|
||||
"identify_private_key_format",
|
||||
"split_pem_list",
|
||||
"extract_first_pem",
|
||||
"extract_pem",
|
||||
)
|
||||
|
||||
@@ -63,7 +63,9 @@ PREFERRED_FINGERPRINTS = (
|
||||
)
|
||||
|
||||
|
||||
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
|
||||
def get_fingerprint_of_bytes(
|
||||
source: bytes, *, prefer_one: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the given bytes."""
|
||||
|
||||
fingerprint = {}
|
||||
@@ -107,7 +109,7 @@ def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[st
|
||||
|
||||
|
||||
def get_fingerprint_of_privatekey(
|
||||
privatekey: PrivateKeyTypes, prefer_one: bool = False
|
||||
privatekey: PrivateKeyTypes, *, prefer_one: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the public key."""
|
||||
|
||||
@@ -119,6 +121,7 @@ def get_fingerprint_of_privatekey(
|
||||
|
||||
|
||||
def get_fingerprint(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
content: bytes | None = None,
|
||||
@@ -137,6 +140,7 @@ def get_fingerprint(
|
||||
|
||||
|
||||
def load_privatekey(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
@@ -219,7 +223,7 @@ def load_certificate_issuer_privatekey(
|
||||
|
||||
|
||||
def load_publickey(
|
||||
path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
*, path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
) -> PublicKeyTypes:
|
||||
if content is None:
|
||||
if path is None:
|
||||
@@ -237,6 +241,7 @@ def load_publickey(
|
||||
|
||||
|
||||
def load_certificate(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
der_support_enabled: bool = False,
|
||||
@@ -266,7 +271,7 @@ def load_certificate(
|
||||
|
||||
|
||||
def load_certificate_request(
|
||||
path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
*, path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
) -> x509.CertificateSigningRequest:
|
||||
"""Load the specified certificate signing request."""
|
||||
try:
|
||||
@@ -287,6 +292,7 @@ def load_certificate_request(
|
||||
|
||||
def parse_name_field(
|
||||
input_dict: dict[str, list[str | bytes] | str | bytes],
|
||||
*,
|
||||
name_field_name: str | None = None,
|
||||
) -> list[tuple[str, str | bytes]]:
|
||||
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
|
||||
@@ -321,7 +327,9 @@ def parse_name_field(
|
||||
|
||||
|
||||
def parse_ordered_name_field(
|
||||
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
|
||||
input_list: list[dict[str, list[str | bytes] | str | bytes]],
|
||||
*,
|
||||
name_field_name: str,
|
||||
) -> list[tuple[str, str | bytes]]:
|
||||
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
|
||||
|
||||
@@ -372,7 +380,7 @@ def select_message_digest(
|
||||
|
||||
class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
|
||||
def __init__(self, *, path: str, state: str, force: bool, check_mode: bool) -> None:
|
||||
self.path = path
|
||||
self.state = state
|
||||
self.force = force
|
||||
@@ -380,7 +388,7 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
self.changed = False
|
||||
self.check_mode = check_mode
|
||||
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
def check(self, module: AnsibleModule, *, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
def _check_state() -> bool:
|
||||
@@ -420,3 +428,20 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
raise OpenSSLObjectError(exc)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = (
|
||||
"get_fingerprint_of_bytes",
|
||||
"get_fingerprint_of_privatekey",
|
||||
"get_fingerprint",
|
||||
"load_privatekey",
|
||||
"load_certificate_privatekey",
|
||||
"load_certificate_issuer_privatekey",
|
||||
"load_publickey",
|
||||
"load_certificate",
|
||||
"load_certificate_request",
|
||||
"parse_name_field",
|
||||
"parse_ordered_name_field",
|
||||
"select_message_digest",
|
||||
"OpenSSLObject",
|
||||
)
|
||||
|
||||
@@ -385,3 +385,11 @@ def ECSClient(
|
||||
entrust_api_cert_key=entrust_api_cert_key,
|
||||
entrust_api_specification_path=entrust_api_specification_path,
|
||||
).client()
|
||||
|
||||
|
||||
__all__ = (
|
||||
"ecs_client_argument_spec",
|
||||
"SessionConfigurationException",
|
||||
"RestOperationException",
|
||||
"ECSClient",
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ class GPGError(Exception):
|
||||
class GPGRunner(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def run_command(
|
||||
self, command: list[str], check_rc: bool = True, data: bytes | None = None
|
||||
self, command: list[str], *, check_rc: bool = True, data: bytes | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
"""
|
||||
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
|
||||
@@ -34,7 +34,7 @@ class GPGRunner(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
|
||||
def get_fingerprint_from_stdout(stdout: str) -> str:
|
||||
def get_fingerprint_from_stdout(*, stdout: str) -> str:
|
||||
lines = stdout.splitlines(False)
|
||||
for line in lines:
|
||||
if line.startswith("fpr:"):
|
||||
@@ -47,7 +47,7 @@ def get_fingerprint_from_stdout(stdout: str) -> str:
|
||||
raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"')
|
||||
|
||||
|
||||
def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
|
||||
def get_fingerprint_from_file(*, gpg_runner: GPGRunner, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise GPGError(f"{path} does not exist")
|
||||
stdout = gpg_runner.run_command(
|
||||
@@ -61,10 +61,10 @@ def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
|
||||
],
|
||||
check_rc=True,
|
||||
)[1]
|
||||
return get_fingerprint_from_stdout(stdout)
|
||||
return get_fingerprint_from_stdout(stdout=stdout)
|
||||
|
||||
|
||||
def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
|
||||
def get_fingerprint_from_bytes(*, gpg_runner: GPGRunner, content: bytes) -> str:
|
||||
stdout = gpg_runner.run_command(
|
||||
[
|
||||
"--no-keyring",
|
||||
@@ -77,4 +77,13 @@ def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
|
||||
data=content,
|
||||
check_rc=True,
|
||||
)[1]
|
||||
return get_fingerprint_from_stdout(stdout)
|
||||
return get_fingerprint_from_stdout(stdout=stdout)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"GPGError",
|
||||
"GPGRunner",
|
||||
"get_fingerprint_from_stdout",
|
||||
"get_fingerprint_from_file",
|
||||
"get_fingerprint_from_bytes",
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
|
||||
def load_file(*, path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
|
||||
"""
|
||||
Load the file as a bytes string.
|
||||
"""
|
||||
@@ -31,6 +31,7 @@ def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> b
|
||||
|
||||
|
||||
def load_file_if_exists(
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
module: AnsibleModule | None = None,
|
||||
ignore_errors: bool = False,
|
||||
@@ -62,6 +63,7 @@ def load_file_if_exists(
|
||||
|
||||
|
||||
def write_file(
|
||||
*,
|
||||
module: AnsibleModule,
|
||||
content: bytes,
|
||||
default_mode: str | int | None = None,
|
||||
@@ -117,3 +119,6 @@ def write_file(
|
||||
except Exception:
|
||||
pass
|
||||
module.fail_json(msg=f"Error while writing result: {e}")
|
||||
|
||||
|
||||
__all__ = ("load_file", "load_file_if_exists", "write_file")
|
||||
|
||||
@@ -99,7 +99,7 @@ def _restore_all_on_failure(
|
||||
|
||||
|
||||
class OpensshModule(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.changed: bool = False
|
||||
@@ -210,6 +210,7 @@ class KeygenCommand:
|
||||
|
||||
def generate_certificate(
|
||||
self,
|
||||
*,
|
||||
certificate_path: str,
|
||||
identifier: str,
|
||||
options: list[str] | None,
|
||||
@@ -247,7 +248,13 @@ class KeygenCommand:
|
||||
return self._run_command(args, **kwargs)
|
||||
|
||||
def generate_keypair(
|
||||
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
|
||||
self,
|
||||
*,
|
||||
private_key_path: str,
|
||||
size: int,
|
||||
type: str,
|
||||
comment: str | None,
|
||||
**kwargs,
|
||||
) -> tuple[int, str, str]:
|
||||
args = [
|
||||
self._bin_path,
|
||||
@@ -270,26 +277,29 @@ class KeygenCommand:
|
||||
return self._run_command(args, data=data, **kwargs)
|
||||
|
||||
def get_certificate_info(
|
||||
self, certificate_path: str, **kwargs
|
||||
self, *, certificate_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-L", "-f", certificate_path], **kwargs
|
||||
)
|
||||
|
||||
def get_matching_public_key(
|
||||
self, private_key_path: str, **kwargs
|
||||
self, *, private_key_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
|
||||
def get_private_key(
|
||||
self, *, private_key_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-l", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def update_comment(
|
||||
self,
|
||||
*,
|
||||
private_key_path: str,
|
||||
comment: str,
|
||||
force_new_format: bool = True,
|
||||
@@ -317,7 +327,7 @@ _PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
|
||||
|
||||
class PrivateKey:
|
||||
def __init__(
|
||||
self, size: int, key_type: str, fingerprint: str, format: str = ""
|
||||
self, *, size: int, key_type: str, fingerprint: str, format: str = ""
|
||||
) -> None:
|
||||
self._size = size
|
||||
self._type = key_type
|
||||
@@ -363,7 +373,7 @@ _PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
|
||||
|
||||
|
||||
class PublicKey:
|
||||
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
|
||||
def __init__(self, *, type_string: str, data: str, comment: str | None) -> None:
|
||||
self._type_string = type_string
|
||||
self._data = data
|
||||
self._comment = comment
|
||||
@@ -441,6 +451,7 @@ class PublicKey:
|
||||
|
||||
|
||||
def parse_private_key_format(
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
|
||||
with open(path, "r") as file:
|
||||
@@ -454,3 +465,14 @@ def parse_private_key_format(
|
||||
return "PKCS1"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
__all__ = (
|
||||
"restore_on_failure",
|
||||
"safe_atomic_move",
|
||||
"OpensshModule",
|
||||
"KeygenCommand",
|
||||
"PrivateKey",
|
||||
"PublicKey",
|
||||
"parse_private_key_format",
|
||||
)
|
||||
|
||||
@@ -53,8 +53,8 @@ if t.TYPE_CHECKING:
|
||||
|
||||
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackend, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(KeypairBackend, self).__init__(module=module)
|
||||
|
||||
self.comment: str | None = self.module.params["comment"]
|
||||
self.private_key_path: str = self.module.params["path"]
|
||||
@@ -296,9 +296,9 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
try:
|
||||
secure_write(
|
||||
temp_public_key,
|
||||
existing_permissions or default_permissions,
|
||||
to_bytes(content),
|
||||
path=temp_public_key,
|
||||
mode=existing_permissions or default_permissions,
|
||||
content=to_bytes(content),
|
||||
)
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
@@ -357,8 +357,8 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class KeypairBackendOpensshBin(KeypairBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendOpensshBin, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendOpensshBin, self).__init__(module=module)
|
||||
|
||||
if self.module.params["private_key_format"] != "auto":
|
||||
self.module.fail_json(
|
||||
@@ -369,12 +369,16 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
self.ssh_keygen.generate_keypair(
|
||||
private_key_path, self.size, self.type, self.comment, check_rc=True
|
||||
private_key_path=private_key_path,
|
||||
size=self.size,
|
||||
type=self.type,
|
||||
comment=self.comment,
|
||||
check_rc=True,
|
||||
)
|
||||
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
rc, private_key_content, err = self.ssh_keygen.get_private_key(
|
||||
self.private_key_path, check_rc=False
|
||||
private_key_path=self.private_key_path, check_rc=False
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(err)
|
||||
@@ -382,13 +386,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
public_key_content = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=True
|
||||
private_key_path=self.private_key_path, check_rc=True
|
||||
)[1]
|
||||
return PublicKey.from_string(public_key_content)
|
||||
|
||||
def _private_key_readable(self) -> bool:
|
||||
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=False
|
||||
private_key_path=self.private_key_path, check_rc=False
|
||||
)
|
||||
return not (
|
||||
rc == 255
|
||||
@@ -407,8 +411,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
|
||||
)
|
||||
self.ssh_keygen.update_comment(
|
||||
self.private_key_path,
|
||||
self.comment or "",
|
||||
private_key_path=self.private_key_path,
|
||||
comment=self.comment or "",
|
||||
force_new_format=force_new_format,
|
||||
check_rc=True,
|
||||
)
|
||||
@@ -420,8 +424,8 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
|
||||
class KeypairBackendCryptography(KeypairBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendCryptography, self).__init__(module)
|
||||
def __init__(self, *, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendCryptography, self).__init__(module=module)
|
||||
|
||||
if self.type == "rsa1":
|
||||
self.module.fail_json(
|
||||
@@ -469,12 +473,12 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
)
|
||||
|
||||
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
|
||||
keypair.asymmetric_keypair, self.private_key_format
|
||||
asym_keypair=keypair.asymmetric_keypair, key_format=self.private_key_format
|
||||
)
|
||||
secure_write(private_key_path, 0o600, encoded_private_key)
|
||||
secure_write(path=private_key_path, mode=0o600, content=encoded_private_key)
|
||||
|
||||
public_key_path = private_key_path + ".pub"
|
||||
secure_write(public_key_path, 0o644, keypair.public_key)
|
||||
secure_write(path=public_key_path, mode=0o644, content=keypair.public_key)
|
||||
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
keypair = OpensshKeypair.load(
|
||||
@@ -485,7 +489,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
size=keypair.size,
|
||||
key_type=keypair.key_type,
|
||||
fingerprint=keypair.fingerprint,
|
||||
format=parse_private_key_format(self.private_key_path),
|
||||
format=parse_private_key_format(path=self.private_key_path),
|
||||
)
|
||||
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
@@ -550,7 +554,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
|
||||
def select_backend(
|
||||
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
|
||||
*, module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
|
||||
) -> KeypairBackend:
|
||||
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
|
||||
CRYPTOGRAPHY_VERSION
|
||||
@@ -573,7 +577,7 @@ def select_backend(
|
||||
if backend == "opensshbin":
|
||||
if not can_use_opensshbin:
|
||||
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
|
||||
return KeypairBackendOpensshBin(module)
|
||||
return KeypairBackendOpensshBin(module=module)
|
||||
if backend == "cryptography":
|
||||
if not can_use_cryptography:
|
||||
module.fail_json(
|
||||
@@ -581,5 +585,8 @@ def select_backend(
|
||||
f"cryptography >= {COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION}"
|
||||
)
|
||||
)
|
||||
return KeypairBackendCryptography(module)
|
||||
return KeypairBackendCryptography(module=module)
|
||||
raise ValueError(f"Unsupported value for backend: {backend}")
|
||||
|
||||
|
||||
__all__ = ("KeypairBackend", "select_backend")
|
||||
|
||||
@@ -111,7 +111,7 @@ _EXTENSIONS = (
|
||||
|
||||
class OpensshCertificateTimeParameters:
|
||||
def __init__(
|
||||
self, valid_from: str | bytes | int, valid_to: str | bytes | int
|
||||
self, *, valid_from: str | bytes | int, valid_to: str | bytes | int
|
||||
) -> None:
|
||||
self._valid_from = self.to_datetime(valid_from)
|
||||
self._valid_to = self.to_datetime(valid_to)
|
||||
@@ -149,7 +149,7 @@ class OpensshCertificateTimeParameters:
|
||||
def valid_from(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_from(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_from, date_format)
|
||||
return self.format_datetime(self._valid_from, date_format=date_format)
|
||||
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormatStr) -> str: ...
|
||||
@@ -161,7 +161,7 @@ class OpensshCertificateTimeParameters:
|
||||
def valid_to(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_to(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_to, date_format)
|
||||
return self.format_datetime(self._valid_to, date_format=date_format)
|
||||
|
||||
def within_range(self, valid_at: str | bytes | int | None) -> bool:
|
||||
if valid_at is not None:
|
||||
@@ -171,18 +171,18 @@ class OpensshCertificateTimeParameters:
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
|
||||
def format_datetime(dt: datetime, *, date_format: DateFormat) -> str | int:
|
||||
if date_format in ("human_readable", "openssh"):
|
||||
if dt == _ALWAYS:
|
||||
return "always"
|
||||
@@ -264,6 +264,7 @@ _OpensshCertificateOption = t.TypeVar(
|
||||
class OpensshCertificateOption:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
option_type: t.Literal["critical", "extension"],
|
||||
name: str | bytes,
|
||||
data: str | bytes,
|
||||
@@ -350,6 +351,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nonce: bytes | None = None,
|
||||
serial: int | None = None,
|
||||
cert_type: int | None = None,
|
||||
@@ -409,7 +411,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
|
||||
def __init__(self, *, e: int | None = None, n: int | None = None, **kwargs) -> None:
|
||||
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
|
||||
self.e = e
|
||||
@@ -435,6 +437,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
p: int | None = None,
|
||||
q: int | None = None,
|
||||
g: int | None = None,
|
||||
@@ -471,7 +474,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(
|
||||
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
|
||||
self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
|
||||
):
|
||||
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
|
||||
self._curve = None
|
||||
@@ -515,7 +518,7 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
|
||||
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
|
||||
def __init__(self, *, pk: bytes | None = None, **kwargs) -> None:
|
||||
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
|
||||
self.pk = pk
|
||||
@@ -541,8 +544,7 @@ _OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate
|
||||
class OpensshCertificate:
|
||||
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
|
||||
|
||||
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
|
||||
|
||||
def __init__(self, *, cert_info: OpensshCertificateInfo, signature: bytes):
|
||||
self._cert_info = cert_info
|
||||
self.signature = signature
|
||||
|
||||
@@ -574,7 +576,7 @@ class OpensshCertificate:
|
||||
f"Invalid certificate format identifier: {format_identifier!r}"
|
||||
)
|
||||
|
||||
parser = OpensshParser(cert)
|
||||
parser = OpensshParser(data=cert)
|
||||
|
||||
if format_identifier != parser.string():
|
||||
raise ValueError("Certificate formats do not match")
|
||||
@@ -649,7 +651,9 @@ class OpensshCertificate:
|
||||
if self._cert_info.critical_options is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("critical", to_text(n), to_text(d))
|
||||
OpensshCertificateOption(
|
||||
option_type="critical", name=to_text(n), data=to_text(d)
|
||||
)
|
||||
for n, d in self._cert_info.critical_options
|
||||
]
|
||||
|
||||
@@ -658,7 +662,9 @@ class OpensshCertificate:
|
||||
if self._cert_info.extensions is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("extension", to_text(n), to_text(d))
|
||||
OpensshCertificateOption(
|
||||
option_type="extension", name=to_text(n), data=to_text(d)
|
||||
)
|
||||
for n, d in self._cert_info.extensions
|
||||
]
|
||||
|
||||
@@ -674,7 +680,7 @@ class OpensshCertificate:
|
||||
|
||||
@property
|
||||
def signature_type(self) -> str:
|
||||
signature_data = OpensshParser.signature_data(self.signature)
|
||||
signature_data = OpensshParser.signature_data(signature_string=self.signature)
|
||||
return to_text(signature_data["signature_type"])
|
||||
|
||||
@staticmethod
|
||||
@@ -727,16 +733,20 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
|
||||
|
||||
directive_to_option = {
|
||||
"no-x11-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-x11-forwarding", ""
|
||||
option_type="extension", name="permit-x11-forwarding", data=""
|
||||
),
|
||||
"no-agent-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-agent-forwarding", ""
|
||||
option_type="extension", name="permit-agent-forwarding", data=""
|
||||
),
|
||||
"no-port-forwarding": OpensshCertificateOption(
|
||||
"extension", "permit-port-forwarding", ""
|
||||
option_type="extension", name="permit-port-forwarding", data=""
|
||||
),
|
||||
"no-pty": OpensshCertificateOption(
|
||||
option_type="extension", name="permit-pty", data=""
|
||||
),
|
||||
"no-user-rc": OpensshCertificateOption(
|
||||
option_type="extension", name="permit-user-rc", data=""
|
||||
),
|
||||
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
|
||||
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
|
||||
}
|
||||
|
||||
if "clear" in directives:
|
||||
@@ -748,7 +758,10 @@ def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOpti
|
||||
|
||||
|
||||
def default_options() -> list[OpensshCertificateOption]:
|
||||
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
|
||||
return [
|
||||
OpensshCertificateOption(option_type="extension", name=name, data="")
|
||||
for name in _EXTENSIONS
|
||||
]
|
||||
|
||||
|
||||
def fingerprint(public_key: bytes) -> bytes:
|
||||
@@ -803,3 +816,22 @@ def parse_option_list(
|
||||
extensions.append(option_object)
|
||||
|
||||
return critical_options, list(set(extensions + apply_directives(directives)))
|
||||
|
||||
|
||||
__all__ = (
|
||||
"OpensshCertificateTimeParameters",
|
||||
"OpensshCertificateOption",
|
||||
"OpensshCertificateInfo",
|
||||
"OpensshRSACertificateInfo",
|
||||
"OpensshDSACertificateInfo",
|
||||
"OpensshECDSACertificateInfo",
|
||||
"OpensshED25519CertificateInfo",
|
||||
"OpensshCertificate",
|
||||
"apply_directives",
|
||||
"default_options",
|
||||
"fingerprint",
|
||||
"get_cert_info_object",
|
||||
"get_option_type",
|
||||
"is_relative_time_string",
|
||||
"parse_option_list",
|
||||
)
|
||||
|
||||
@@ -145,6 +145,7 @@ class AsymmetricKeypair:
|
||||
@classmethod
|
||||
def generate(
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
*,
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
@@ -208,6 +209,7 @@ class AsymmetricKeypair:
|
||||
@classmethod
|
||||
def load(
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
private_key_format: KeySerializationFormat = "PEM",
|
||||
@@ -228,13 +230,15 @@ class AsymmetricKeypair:
|
||||
else:
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
|
||||
privatekey = load_privatekey(path, passphrase, private_key_format)
|
||||
privatekey = load_privatekey(
|
||||
path=path, passphrase=passphrase, key_format=private_key_format
|
||||
)
|
||||
if no_public_key:
|
||||
publickey = privatekey.public_key()
|
||||
else:
|
||||
# TODO: BUG: load_publickey() can return unsupported key types
|
||||
# (Also we should check whether the public key fits the private key...)
|
||||
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
|
||||
publickey = load_publickey(path=path + ".pub", key_format=public_key_format) # type: ignore
|
||||
|
||||
# Ed25519 keys are always of size 256 and do not have a key_size attribute
|
||||
if isinstance(privatekey, Ed25519PrivateKey):
|
||||
@@ -264,6 +268,7 @@ class AsymmetricKeypair:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
keytype: KeyType,
|
||||
size: int,
|
||||
privatekey: PrivateKeyTypes,
|
||||
@@ -285,7 +290,7 @@ class AsymmetricKeypair:
|
||||
self.__encryption_algorithm = encryption_algorithm
|
||||
|
||||
try:
|
||||
self.verify(self.sign(b"message"), b"message")
|
||||
self.verify(signature=self.sign(b"message"), data=b"message")
|
||||
except InvalidSignatureError:
|
||||
raise InvalidPublicKeyFileError(
|
||||
"The private key and public key of this keypair do not match"
|
||||
@@ -347,7 +352,7 @@ class AsymmetricKeypair:
|
||||
except TypeError as e:
|
||||
raise InvalidDataError(e)
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
def verify(self, *, signature: bytes, data: bytes) -> None:
|
||||
"""Verifies that the signature associated with the provided data was signed
|
||||
by the private key of this key pair.
|
||||
|
||||
@@ -384,6 +389,7 @@ class OpensshKeypair:
|
||||
@classmethod
|
||||
def generate(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
*,
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
@@ -400,9 +406,15 @@ class OpensshKeypair:
|
||||
if comment is None:
|
||||
comment = f"{getuser()}@{gethostname()}"
|
||||
|
||||
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
|
||||
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
|
||||
asym_keypair = AsymmetricKeypair.generate(
|
||||
keytype=keytype, size=size, passphrase=passphrase
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(
|
||||
asym_keypair=asym_keypair, key_format="SSH"
|
||||
)
|
||||
openssh_publickey = cls.encode_openssh_publickey(
|
||||
asym_keypair=asym_keypair, comment=comment
|
||||
)
|
||||
fingerprint = calculate_fingerprint(openssh_publickey)
|
||||
|
||||
return cls(
|
||||
@@ -416,6 +428,7 @@ class OpensshKeypair:
|
||||
@classmethod
|
||||
def load(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
no_public_key: bool = False,
|
||||
@@ -433,10 +446,18 @@ class OpensshKeypair:
|
||||
comment = extract_comment(str(path) + ".pub")
|
||||
|
||||
asym_keypair = AsymmetricKeypair.load(
|
||||
path, passphrase, "SSH", "SSH", no_public_key
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
private_key_format="SSH",
|
||||
public_key_format="SSH",
|
||||
no_public_key=no_public_key,
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(
|
||||
asym_keypair=asym_keypair, key_format="SSH"
|
||||
)
|
||||
openssh_publickey = cls.encode_openssh_publickey(
|
||||
asym_keypair=asym_keypair, comment=comment
|
||||
)
|
||||
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
|
||||
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
|
||||
fingerprint = calculate_fingerprint(openssh_publickey)
|
||||
|
||||
return cls(
|
||||
@@ -449,7 +470,7 @@ class OpensshKeypair:
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_privatekey(
|
||||
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
|
||||
*, asym_keypair: AsymmetricKeypair, key_format: KeyFormat
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded private key for a given keypair
|
||||
|
||||
@@ -482,7 +503,7 @@ class OpensshKeypair:
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_publickey(
|
||||
asym_keypair: AsymmetricKeypair, comment: str
|
||||
*, asym_keypair: AsymmetricKeypair, comment: str
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded public key for a given keypair
|
||||
|
||||
@@ -504,6 +525,7 @@ class OpensshKeypair:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
asym_keypair: AsymmetricKeypair,
|
||||
openssh_privatekey: bytes,
|
||||
openssh_publickey: bytes,
|
||||
@@ -603,11 +625,12 @@ class OpensshKeypair:
|
||||
|
||||
self.__asym_keypair.update_passphrase(passphrase)
|
||||
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
|
||||
self.__asym_keypair, "SSH"
|
||||
asym_keypair=self.__asym_keypair, key_format="SSH"
|
||||
)
|
||||
|
||||
|
||||
def load_privatekey(
|
||||
*,
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None,
|
||||
key_format: KeySerializationFormat,
|
||||
@@ -662,7 +685,7 @@ def load_privatekey(
|
||||
|
||||
|
||||
def load_publickey(
|
||||
path: str | os.PathLike, key_format: KeySerializationFormat
|
||||
*, path: str | os.PathLike, key_format: KeySerializationFormat
|
||||
) -> AllPublicKeyTypes:
|
||||
publickey_loaders = {
|
||||
"PEM": serialization.load_pem_public_key,
|
||||
@@ -767,3 +790,30 @@ def calculate_fingerprint(openssh_publickey: bytes) -> str:
|
||||
|
||||
value = b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip("=")
|
||||
return f"SHA256:{value}"
|
||||
|
||||
|
||||
__all__ = (
|
||||
"HAS_OPENSSH_SUPPORT",
|
||||
"CRYPTOGRAPHY_VERSION",
|
||||
"OpenSSHError",
|
||||
"InvalidAlgorithmError",
|
||||
"InvalidCommentError",
|
||||
"InvalidDataError",
|
||||
"InvalidPrivateKeyFileError",
|
||||
"InvalidPublicKeyFileError",
|
||||
"InvalidKeyFormatError",
|
||||
"InvalidKeySizeError",
|
||||
"InvalidKeyTypeError",
|
||||
"InvalidPassphraseError",
|
||||
"InvalidSignatureError",
|
||||
"AsymmetricKeypair",
|
||||
"OpensshKeypair",
|
||||
"load_privatekey",
|
||||
"load_publickey",
|
||||
"compare_publickeys",
|
||||
"compare_encryption_algorithms",
|
||||
"get_encryption_algorithm",
|
||||
"validate_comment",
|
||||
"extract_comment",
|
||||
"calculate_fingerprint",
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ def parse_openssh_version(version_string: str) -> str | None:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
def secure_open(*, path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
|
||||
try:
|
||||
yield fd
|
||||
@@ -78,8 +78,8 @@ def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
os.close(fd)
|
||||
|
||||
|
||||
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
|
||||
with secure_open(path, mode) as fd:
|
||||
def secure_write(*, path: str | os.PathLike, mode: int, content: bytes) -> None:
|
||||
with secure_open(path=path, mode=mode) as fd:
|
||||
os.write(fd, content)
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class OpensshParser:
|
||||
UINT32_OFFSET = 4
|
||||
UINT64_OFFSET = 8
|
||||
|
||||
def __init__(self, data: bytes | bytearray) -> None:
|
||||
def __init__(self, *, data: bytes | bytearray) -> None:
|
||||
if not isinstance(data, (bytes, bytearray)):
|
||||
raise TypeError(f"Data must be bytes-like not {type(data)}")
|
||||
|
||||
@@ -142,7 +142,7 @@ class OpensshParser:
|
||||
raw_string = self.string()
|
||||
|
||||
if raw_string:
|
||||
parser = OpensshParser(raw_string)
|
||||
parser = OpensshParser(data=raw_string)
|
||||
while parser.remaining_bytes():
|
||||
result.append(parser.string())
|
||||
|
||||
@@ -154,14 +154,14 @@ class OpensshParser:
|
||||
raw_string = self.string()
|
||||
|
||||
if raw_string:
|
||||
parser = OpensshParser(raw_string)
|
||||
parser = OpensshParser(data=raw_string)
|
||||
|
||||
while parser.remaining_bytes():
|
||||
name = parser.string()
|
||||
data = parser.string()
|
||||
if data:
|
||||
# data is doubly-encoded
|
||||
data = OpensshParser(data).string()
|
||||
data = OpensshParser(data=data).string()
|
||||
result.append((name, data))
|
||||
|
||||
return result
|
||||
@@ -183,14 +183,14 @@ class OpensshParser:
|
||||
return self._pos + offset
|
||||
|
||||
@classmethod
|
||||
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
|
||||
def signature_data(cls, *, signature_string: bytes) -> dict[str, bytes | int]:
|
||||
signature_data: dict[str, bytes | int] = {}
|
||||
|
||||
parser = cls(signature_string)
|
||||
parser = cls(data=signature_string)
|
||||
signature_type = parser.string()
|
||||
signature_blob = parser.string()
|
||||
|
||||
blob_parser = cls(signature_blob)
|
||||
blob_parser = cls(data=signature_blob)
|
||||
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
|
||||
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
|
||||
@@ -242,7 +242,7 @@ class _OpensshWriter:
|
||||
in validating parsed material.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer: bytearray | None = None):
|
||||
def __init__(self, *, buffer: bytearray | None = None):
|
||||
if buffer is not None:
|
||||
if not isinstance(buffer, bytearray):
|
||||
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
|
||||
@@ -347,3 +347,13 @@ class _OpensshWriter:
|
||||
|
||||
def bytes(self) -> bytes:
|
||||
return bytes(self._buff)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"any_in",
|
||||
"file_mode",
|
||||
"parse_openssh_version",
|
||||
"secure_open",
|
||||
"secure_write",
|
||||
"OpensshParser",
|
||||
)
|
||||
|
||||
@@ -54,3 +54,6 @@ def to_serial(value: int) -> str:
|
||||
if len(value_str) % 2 != 0:
|
||||
value_str = f"0{value_str}"
|
||||
return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2))
|
||||
|
||||
|
||||
__all__ = ("parse_serial", "to_serial")
|
||||
|
||||
@@ -19,7 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils._crypto.basic imp
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
def get_now_datetime(with_timezone: bool) -> datetime.datetime:
|
||||
def get_now_datetime(*, with_timezone: bool) -> datetime.datetime:
|
||||
if with_timezone:
|
||||
return datetime.datetime.now(tz=UTC)
|
||||
return datetime.datetime.utcnow()
|
||||
@@ -44,7 +44,7 @@ def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime:
|
||||
|
||||
|
||||
def add_or_remove_timezone(
|
||||
timestamp: datetime.datetime, with_timezone: bool
|
||||
timestamp: datetime.datetime, *, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
return (
|
||||
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
|
||||
@@ -59,7 +59,7 @@ def get_epoch_seconds(timestamp: datetime.datetime) -> float:
|
||||
|
||||
|
||||
def from_epoch_seconds(
|
||||
timestamp: int | float, with_timezone: bool
|
||||
timestamp: int | float, *, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
if with_timezone:
|
||||
return datetime.datetime.fromtimestamp(timestamp, UTC)
|
||||
@@ -68,6 +68,7 @@ def from_epoch_seconds(
|
||||
|
||||
def convert_relative_to_datetime(
|
||||
relative_time_string: str,
|
||||
*,
|
||||
with_timezone: bool = False,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime | None:
|
||||
@@ -107,6 +108,7 @@ def convert_relative_to_datetime(
|
||||
|
||||
def get_relative_time_option(
|
||||
input_string: str,
|
||||
*,
|
||||
input_name: str,
|
||||
with_timezone: bool = False,
|
||||
now: datetime.datetime | None = None,
|
||||
@@ -155,3 +157,15 @@ def get_relative_time_option(
|
||||
raise OpenSSLObjectError(
|
||||
f'The time spec "{input_string}" for {input_name} is invalid'
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"get_now_datetime",
|
||||
"ensure_utc_timezone",
|
||||
"remove_timezone",
|
||||
"add_or_remove_timezone",
|
||||
"get_epoch_seconds",
|
||||
"from_epoch_seconds",
|
||||
"convert_relative_to_datetime",
|
||||
"get_relative_time_option",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user