mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-05-06 13:22:58 +00:00
Add type hints and type checking (#885)
* Enable basic type checking. * Fix first errors. * Add changelog fragment. * Add types to module_utils and plugin_utils (without module backends). * Add typing hints for acme_* modules. * Add typing to X.509 certificate modules, and add more helpers. * Add typing to remaining module backends. * Add typing for action, filter, and lookup plugins. * Bump ansible-core 2.19 beta requirement for typing. * Add more typing definitions. * Add typing to some unit tests.
This commit is contained in:
@@ -5,6 +5,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common._collections_compat import Mapping
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ACMEProtocolException,
|
||||
@@ -12,26 +14,29 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
class ACMEAccount:
|
||||
"""
|
||||
ACME account object. Allows to create new accounts, check for existence of accounts,
|
||||
retrieve account data.
|
||||
"""
|
||||
|
||||
def __init__(self, client):
|
||||
def __init__(self, client: ACMEClient) -> None:
|
||||
# Set to true to enable logging of all signed requests
|
||||
self._debug = False
|
||||
self._debug: bool = False
|
||||
|
||||
self.client = client
|
||||
|
||||
def _new_reg(
|
||||
self,
|
||||
contact=None,
|
||||
agreement=None,
|
||||
terms_agreed=False,
|
||||
allow_creation=True,
|
||||
external_account_binding=None,
|
||||
):
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any] | None]:
|
||||
"""
|
||||
Registers a new ACME account. Returns a pair ``(created, data)``.
|
||||
Here, ``created`` is ``True`` if the account was created and
|
||||
@@ -63,7 +68,7 @@ class ACMEAccount:
|
||||
return created, data
|
||||
# An account does not yet exist. Try to create one next.
|
||||
|
||||
new_reg = {"contact": contact}
|
||||
new_reg: dict[str, t.Any] = {"contact": contact}
|
||||
if not allow_creation:
|
||||
# https://tools.ietf.org/html/rfc8555#section-7.3.1
|
||||
new_reg["onlyReturnExisting"] = True
|
||||
@@ -99,7 +104,7 @@ class ACMEAccount:
|
||||
self.client.module,
|
||||
msg="Invalid account creation reply from ACME server",
|
||||
info=info,
|
||||
content=result,
|
||||
content_json=result,
|
||||
)
|
||||
|
||||
if info["status"] == 201:
|
||||
@@ -152,7 +157,7 @@ class ACMEAccount:
|
||||
content_json=result,
|
||||
)
|
||||
|
||||
def get_account_data(self):
|
||||
def get_account_data(self) -> dict[str, t.Any] | None:
|
||||
"""
|
||||
Retrieve account information. Can only be called when the account
|
||||
URI is already known (such as after calling setup_account).
|
||||
@@ -161,7 +166,7 @@ class ACMEAccount:
|
||||
if self.client.account_uri is None:
|
||||
raise ModuleFailException("Account URI unknown")
|
||||
# try POST-as-GET first (draft-15 or newer)
|
||||
data = None
|
||||
data: dict[str, t.Any] | None = None
|
||||
result, info = self.client.send_signed_request(
|
||||
self.client.account_uri, data, fail_on_error=False
|
||||
)
|
||||
@@ -180,7 +185,7 @@ class ACMEAccount:
|
||||
self.client.module,
|
||||
msg="Invalid account data retrieved from ACME server",
|
||||
info=info,
|
||||
content=result,
|
||||
content_json=result,
|
||||
)
|
||||
if (
|
||||
info["status"] in (400, 403)
|
||||
@@ -203,15 +208,34 @@ class ACMEAccount:
|
||||
)
|
||||
return result
|
||||
|
||||
@t.overload
|
||||
def setup_account(
|
||||
self,
|
||||
contact=None,
|
||||
agreement=None,
|
||||
terms_agreed=False,
|
||||
allow_creation=True,
|
||||
remove_account_uri_if_not_exists=False,
|
||||
external_account_binding=None,
|
||||
):
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: t.Literal[True] = True,
|
||||
remove_account_uri_if_not_exists: bool = False,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any]]: ...
|
||||
|
||||
@t.overload
|
||||
def setup_account(
|
||||
self,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
remove_account_uri_if_not_exists: bool = False,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any] | None]: ...
|
||||
|
||||
def setup_account(
|
||||
self,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
remove_account_uri_if_not_exists: bool = False,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any] | None]:
|
||||
"""
|
||||
Detect or create an account on the ACME server. For ACME v1,
|
||||
as the only way (without knowing an account URI) to test if an
|
||||
@@ -253,7 +277,6 @@ class ACMEAccount:
|
||||
else:
|
||||
created, account_data = self._new_reg(
|
||||
contact,
|
||||
agreement=agreement,
|
||||
terms_agreed=terms_agreed,
|
||||
allow_creation=allow_creation and not self.client.module.check_mode,
|
||||
external_account_binding=external_account_binding,
|
||||
@@ -267,7 +290,9 @@ class ACMEAccount:
|
||||
account_data = {"contact": contact or []}
|
||||
return created, account_data
|
||||
|
||||
def update_account(self, account_data, contact=None):
|
||||
def update_account(
|
||||
self, account_data: dict[str, t.Any], contact: list[str] | None = None
|
||||
) -> tuple[bool, dict[str, t.Any]]:
|
||||
"""
|
||||
Update an account on the ACME server. Check mode is fully respected.
|
||||
|
||||
@@ -280,8 +305,11 @@ class ACMEAccount:
|
||||
|
||||
https://tools.ietf.org/html/rfc8555#section-7.3.2
|
||||
"""
|
||||
if self.client.account_uri is None:
|
||||
raise ModuleFailException("Cannot update account without account URI")
|
||||
|
||||
# Create request
|
||||
update_request = {}
|
||||
update_request: dict[str, t.Any] = {}
|
||||
if contact is not None and account_data.get("contact", []) != contact:
|
||||
update_request["contact"] = list(contact)
|
||||
|
||||
@@ -302,7 +330,7 @@ class ACMEAccount:
|
||||
self.client.module,
|
||||
msg="Invalid account updating reply from ACME server",
|
||||
info=info,
|
||||
content=account_data,
|
||||
content_json=account_data,
|
||||
)
|
||||
|
||||
return True, account_data
|
||||
|
||||
@@ -10,6 +10,7 @@ import datetime
|
||||
import json
|
||||
import locale
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -41,13 +42,24 @@ from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import os
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .account import ACMEAccount
|
||||
from .backends import CertificateInformation, CryptoBackend
|
||||
|
||||
|
||||
# -1 usually means connection problems
|
||||
RETRY_STATUS_CODES = (-1, 408, 429, 503)
|
||||
|
||||
RETRY_COUNT = 10
|
||||
|
||||
|
||||
def _decode_retry(module, response, info, retry_count):
|
||||
def _decode_retry(
|
||||
module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
|
||||
) -> bool:
|
||||
if info["status"] not in RETRY_STATUS_CODES:
|
||||
return False
|
||||
|
||||
@@ -61,7 +73,8 @@ def _decode_retry(module, response, info, retry_count):
|
||||
|
||||
# 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
|
||||
try:
|
||||
retry_after = min(max(1, int(info.get("retry-after"))), 60)
|
||||
# TODO: use utils.parse_retry_after()
|
||||
retry_after = min(max(1, int(info.get("retry-after", "10"))), 60)
|
||||
except (TypeError, ValueError):
|
||||
retry_after = 10
|
||||
module.log(
|
||||
@@ -73,13 +86,13 @@ def _decode_retry(module, response, info, retry_count):
|
||||
|
||||
|
||||
def _assert_fetch_url_success(
|
||||
module,
|
||||
response,
|
||||
info,
|
||||
allow_redirect=False,
|
||||
allow_client_error=True,
|
||||
allow_server_error=True,
|
||||
):
|
||||
module: AnsibleModule,
|
||||
response: t.Any,
|
||||
info: dict[str, t.Any],
|
||||
allow_redirect: bool = False,
|
||||
allow_client_error: bool = True,
|
||||
allow_server_error: bool = True,
|
||||
) -> None:
|
||||
if info["status"] < 0:
|
||||
raise NetworkException(msg=f"Failure downloading {info['url']}, {info['msg']}")
|
||||
|
||||
@@ -91,7 +104,9 @@ def _assert_fetch_url_success(
|
||||
raise ACMEProtocolException(module, info=info, response=response)
|
||||
|
||||
|
||||
def _is_failed(info, expected_status_codes=None):
|
||||
def _is_failed(
|
||||
info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
|
||||
) -> bool:
|
||||
if info["status"] < 200 or info["status"] >= 400:
|
||||
return True
|
||||
if (
|
||||
@@ -111,12 +126,12 @@ class ACMEDirectory:
|
||||
https://tools.ietf.org/html/rfc8555#section-7.1.1
|
||||
"""
|
||||
|
||||
def __init__(self, module, account):
|
||||
def __init__(self, module: AnsibleModule, client: ACMEClient) -> None:
|
||||
self.module = module
|
||||
self.directory_root = module.params["acme_directory"]
|
||||
self.version = module.params["acme_version"]
|
||||
|
||||
self.directory, dummy = account.get_request(self.directory_root, get_only=True)
|
||||
self.directory, dummy = client.get_request(self.directory_root, get_only=True)
|
||||
|
||||
self.request_timeout = module.params["request_timeout"]
|
||||
|
||||
@@ -131,16 +146,16 @@ class ACMEDirectory:
|
||||
if "meta" not in self.directory:
|
||||
self.directory["meta"] = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
return self.directory[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.directory
|
||||
|
||||
def get(self, key, default_value=None):
|
||||
def get(self, key: str, default_value: t.Any = None) -> t.Any:
|
||||
return self.directory.get(key, default_value)
|
||||
|
||||
def get_nonce(self, resource=None):
|
||||
def get_nonce(self, resource: str | None = None) -> str:
|
||||
url = self.directory["newNonce"]
|
||||
if resource is not None:
|
||||
url = resource
|
||||
@@ -170,7 +185,7 @@ class ACMEDirectory:
|
||||
)
|
||||
retry_count += 1
|
||||
|
||||
def has_renewal_info_endpoint(self):
|
||||
def has_renewal_info_endpoint(self) -> bool:
|
||||
return "renewalInfo" in self.directory
|
||||
|
||||
|
||||
@@ -180,7 +195,7 @@ class ACMEClient:
|
||||
ACME server.
|
||||
"""
|
||||
|
||||
def __init__(self, module, backend):
|
||||
def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None:
|
||||
# Set to true to enable logging of all signed requests
|
||||
self._debug = False
|
||||
|
||||
@@ -221,16 +236,22 @@ class ACMEClient:
|
||||
|
||||
self.directory = ACMEDirectory(module, self)
|
||||
|
||||
def set_account_uri(self, uri):
|
||||
def set_account_uri(self, uri: str) -> None:
|
||||
"""
|
||||
Set account URI. For ACME v2, it needs to be used to sending signed
|
||||
requests.
|
||||
"""
|
||||
self.account_uri = uri
|
||||
self.account_jws_header.pop("jwk")
|
||||
self.account_jws_header["kid"] = self.account_uri
|
||||
if self.account_jws_header:
|
||||
self.account_jws_header.pop("jwk", None)
|
||||
self.account_jws_header["kid"] = self.account_uri
|
||||
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
In case of an error, raises KeyParsingError.
|
||||
@@ -239,7 +260,13 @@ class ACMEClient:
|
||||
raise AssertionError("One of key_file and key_content must be specified!")
|
||||
return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
|
||||
|
||||
def sign_request(self, protected, payload, key_data, encode_payload=True):
|
||||
def sign_request(
|
||||
self,
|
||||
protected: dict[str, t.Any],
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
key_data: dict[str, t.Any],
|
||||
encode_payload: bool = True,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Signs an ACME request.
|
||||
"""
|
||||
@@ -260,7 +287,7 @@ class ACMEClient:
|
||||
|
||||
return self.backend.sign(payload64, protected64, key_data)
|
||||
|
||||
def _log(self, msg, data=None):
|
||||
def _log(self, msg: str, data: t.Any = None) -> None:
|
||||
"""
|
||||
Write arguments to acme.log when logging is enabled.
|
||||
"""
|
||||
@@ -275,18 +302,49 @@ class ACMEClient:
|
||||
)
|
||||
)
|
||||
|
||||
@t.overload
|
||||
def send_signed_request(
|
||||
self,
|
||||
url,
|
||||
payload,
|
||||
key_data=None,
|
||||
jws_header=None,
|
||||
parse_json_result=True,
|
||||
encode_payload=True,
|
||||
fail_on_error=True,
|
||||
error_msg=None,
|
||||
expected_status_codes=None,
|
||||
):
|
||||
url: str,
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
*,
|
||||
key_data: dict[str, t.Any] | None = None,
|
||||
jws_header: dict[str, t.Any] | None = None,
|
||||
parse_json_result: t.Literal[True] = True,
|
||||
encode_payload: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
|
||||
|
||||
@t.overload
|
||||
def send_signed_request(
|
||||
self,
|
||||
url: str,
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
*,
|
||||
key_data: dict[str, t.Any] | None = None,
|
||||
jws_header: dict[str, t.Any] | None = None,
|
||||
parse_json_result: t.Literal[False],
|
||||
encode_payload: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[bytes, dict[str, t.Any]]: ...
|
||||
|
||||
def send_signed_request(
|
||||
self,
|
||||
url: str,
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
*,
|
||||
key_data: dict[str, t.Any] | None = None,
|
||||
jws_header: dict[str, t.Any] | None = None,
|
||||
parse_json_result: bool = True,
|
||||
encode_payload: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
|
||||
"""
|
||||
Sends a JWS signed HTTP POST request to the ACME server and returns
|
||||
the response as dictionary (if parse_json_result is True) or in raw form
|
||||
@@ -297,7 +355,11 @@ class ACMEClient:
|
||||
(https://tools.ietf.org/html/rfc8555#section-6.3)
|
||||
"""
|
||||
key_data = key_data or self.account_key_data
|
||||
if key_data is None:
|
||||
raise ModuleFailException("Missing key data")
|
||||
jws_header = jws_header or self.account_jws_header
|
||||
if jws_header is None:
|
||||
raise ModuleFailException("Missing JWS header")
|
||||
failed_tries = 0
|
||||
while True:
|
||||
protected = copy.deepcopy(jws_header)
|
||||
@@ -382,16 +444,43 @@ class ACMEClient:
|
||||
)
|
||||
return result, info
|
||||
|
||||
@t.overload
|
||||
def get_request(
|
||||
self,
|
||||
uri,
|
||||
parse_json_result=True,
|
||||
headers=None,
|
||||
get_only=False,
|
||||
fail_on_error=True,
|
||||
error_msg=None,
|
||||
expected_status_codes=None,
|
||||
):
|
||||
uri: str,
|
||||
*,
|
||||
parse_json_result: t.Literal[True] = True,
|
||||
headers: dict[str, str] | None = None,
|
||||
get_only: bool = False,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
|
||||
|
||||
@t.overload
|
||||
def get_request(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
parse_json_result: t.Literal[False],
|
||||
headers: dict[str, str] | None = None,
|
||||
get_only: bool = False,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[bytes, dict[str, t.Any]]: ...
|
||||
|
||||
def get_request(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
parse_json_result: bool = True,
|
||||
headers: dict[str, str] | None = None,
|
||||
get_only: bool = False,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
|
||||
"""
|
||||
Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback
|
||||
to GET if server replies with a status code of 405.
|
||||
@@ -436,6 +525,7 @@ class ACMEClient:
|
||||
|
||||
# Process result
|
||||
parsed_json_result = False
|
||||
result: dict[str, t.Any] | bytes
|
||||
if parse_json_result:
|
||||
result = {}
|
||||
if content:
|
||||
@@ -445,7 +535,7 @@ class ACMEClient:
|
||||
parsed_json_result = True
|
||||
except ValueError:
|
||||
raise NetworkException(
|
||||
f"Failed to parse the ACME response: {uri} {content}"
|
||||
f"Failed to parse the ACME response: {uri} {content!r}"
|
||||
)
|
||||
else:
|
||||
result = content
|
||||
@@ -460,19 +550,21 @@ class ACMEClient:
|
||||
msg=error_msg,
|
||||
info=info,
|
||||
content=content,
|
||||
content_json=result if parsed_json_result else None,
|
||||
content_json=(
|
||||
t.cast(dict[str, t.Any], result) if parsed_json_result else None
|
||||
),
|
||||
)
|
||||
return result, info
|
||||
|
||||
def get_renewal_info(
|
||||
self,
|
||||
cert_id=None,
|
||||
cert_info=None,
|
||||
cert_filename=None,
|
||||
cert_content=None,
|
||||
include_retry_after=False,
|
||||
retry_after_relative_with_timezone=True,
|
||||
):
|
||||
cert_id: str | None = None,
|
||||
cert_info: CertificateInformation | None = None,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
include_retry_after: bool = False,
|
||||
retry_after_relative_with_timezone: bool = True,
|
||||
) -> dict[str, t.Any]:
|
||||
if not self.directory.has_renewal_info_endpoint():
|
||||
raise ModuleFailException(
|
||||
"The ACME endpoint does not support ACME Renewal Information retrieval"
|
||||
@@ -504,10 +596,10 @@ class ACMEClient:
|
||||
|
||||
|
||||
def create_default_argspec(
|
||||
with_account=True,
|
||||
require_account_key=True,
|
||||
with_certificate=False,
|
||||
):
|
||||
with_account: bool = True,
|
||||
require_account_key: bool = True,
|
||||
with_certificate: bool = False,
|
||||
) -> ArgumentSpec:
|
||||
"""
|
||||
Provides default argument spec for the options documented in the acme doc fragment.
|
||||
"""
|
||||
@@ -544,7 +636,7 @@ def create_default_argspec(
|
||||
return result
|
||||
|
||||
|
||||
def create_backend(module, needs_acme_v2=True):
|
||||
def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend:
|
||||
backend = module.params["select_crypto_backend"]
|
||||
|
||||
# Backend autodetect
|
||||
@@ -552,6 +644,7 @@ def create_backend(module, needs_acme_v2=True):
|
||||
backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl"
|
||||
|
||||
# Create backend object
|
||||
module_backend: CryptoBackend
|
||||
if backend == "cryptography":
|
||||
if CRYPTOGRAPHY_ERROR is not None:
|
||||
# Either we could not import cryptography at all, or there was an unexpected error
|
||||
|
||||
@@ -9,6 +9,7 @@ import base64
|
||||
import binascii
|
||||
import os
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
@@ -75,10 +76,19 @@ else:
|
||||
CRYPTOGRAPHY_MINIMAL_VERSION
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .certificates import CertificateChain, Criterium
|
||||
|
||||
|
||||
class CryptographyChainMatcher(ChainMatcher):
|
||||
@staticmethod
|
||||
def _parse_key_identifier(key_identifier, name, criterium_idx, module):
|
||||
def _parse_key_identifier(
|
||||
key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule
|
||||
) -> bytes | None:
|
||||
if key_identifier:
|
||||
try:
|
||||
return binascii.unhexlify(key_identifier.replace(":", ""))
|
||||
@@ -94,11 +104,11 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
)
|
||||
return None
|
||||
|
||||
def __init__(self, criterium, module):
|
||||
def __init__(self, criterium: Criterium, module: AnsibleModule) -> None:
|
||||
self.criterium = criterium
|
||||
self.test_certificates = criterium.test_certificates
|
||||
self.subject = []
|
||||
self.issuer = []
|
||||
self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
|
||||
self.issuer: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
|
||||
if criterium.subject:
|
||||
self.subject = [
|
||||
(cryptography_name_to_oid(k), to_native(v))
|
||||
@@ -121,8 +131,13 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
criterium.index,
|
||||
module,
|
||||
)
|
||||
self.module = module
|
||||
|
||||
def _match_subject(self, x509_subject, match_subject):
|
||||
def _match_subject(
|
||||
self,
|
||||
x509_subject: cryptography.x509.Name,
|
||||
match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]],
|
||||
) -> bool:
|
||||
for oid, value in match_subject:
|
||||
found = False
|
||||
for attribute in x509_subject:
|
||||
@@ -133,7 +148,7 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
return False
|
||||
return True
|
||||
|
||||
def match(self, certificate):
|
||||
def match(self, certificate: CertificateChain) -> bool:
|
||||
"""
|
||||
Check whether an alternate chain matches the specified criterium.
|
||||
"""
|
||||
@@ -152,19 +167,22 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
matches = False
|
||||
if self.subject_key_identifier:
|
||||
try:
|
||||
ext = x509.extensions.get_extension_for_class(
|
||||
ext_ski = x509.extensions.get_extension_for_class(
|
||||
cryptography.x509.SubjectKeyIdentifier
|
||||
)
|
||||
if self.subject_key_identifier != ext.value.digest:
|
||||
if self.subject_key_identifier != ext_ski.value.digest:
|
||||
matches = False
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
matches = False
|
||||
if self.authority_key_identifier:
|
||||
try:
|
||||
ext = x509.extensions.get_extension_for_class(
|
||||
ext_aki = x509.extensions.get_extension_for_class(
|
||||
cryptography.x509.AuthorityKeyIdentifier
|
||||
)
|
||||
if self.authority_key_identifier != ext.value.key_identifier:
|
||||
if (
|
||||
self.authority_key_identifier
|
||||
!= ext_aki.value.key_identifier
|
||||
):
|
||||
matches = False
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
matches = False
|
||||
@@ -176,59 +194,68 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
|
||||
|
||||
class CryptographyBackend(CryptoBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CryptographyBackend, self).__init__(
|
||||
module, with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
)
|
||||
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
Raises KeyParsingError in case of errors.
|
||||
"""
|
||||
# If key_content is not given, read key_file
|
||||
if key_content is None:
|
||||
key_content = read_file(key_file)
|
||||
if key_file is None:
|
||||
raise KeyParsingError(
|
||||
"one of key_file and key_content must be specified"
|
||||
)
|
||||
b_key_content = read_file(key_file)
|
||||
else:
|
||||
key_content = to_bytes(key_content)
|
||||
b_key_content = to_bytes(key_content)
|
||||
# Parse key
|
||||
try:
|
||||
key = cryptography.hazmat.primitives.serialization.load_pem_private_key(
|
||||
key_content,
|
||||
b_key_content,
|
||||
password=to_bytes(passphrase) if passphrase is not None else None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise KeyParsingError(f"error while loading key: {e}")
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
|
||||
pk = key.public_key().public_numbers()
|
||||
rsa_pk = key.public_key().public_numbers()
|
||||
return {
|
||||
"key_obj": key,
|
||||
"type": "rsa",
|
||||
"alg": "RS256",
|
||||
"jwk": {
|
||||
"kty": "RSA",
|
||||
"e": nopad_b64(convert_int_to_bytes(pk.e)),
|
||||
"n": nopad_b64(convert_int_to_bytes(pk.n)),
|
||||
"e": nopad_b64(convert_int_to_bytes(rsa_pk.e)),
|
||||
"n": nopad_b64(convert_int_to_bytes(rsa_pk.n)),
|
||||
},
|
||||
"hash": "sha256",
|
||||
}
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
|
||||
):
|
||||
pk = key.public_key().public_numbers()
|
||||
if pk.curve.name == "secp256r1":
|
||||
ec_pk = key.public_key().public_numbers()
|
||||
if ec_pk.curve.name == "secp256r1":
|
||||
bits = 256
|
||||
alg = "ES256"
|
||||
hashalg = "sha256"
|
||||
point_size = 32
|
||||
curve = "P-256"
|
||||
elif pk.curve.name == "secp384r1":
|
||||
elif ec_pk.curve.name == "secp384r1":
|
||||
bits = 384
|
||||
alg = "ES384"
|
||||
hashalg = "sha384"
|
||||
point_size = 48
|
||||
curve = "P-384"
|
||||
elif pk.curve.name == "secp521r1":
|
||||
elif ec_pk.curve.name == "secp521r1":
|
||||
# Not yet supported on Let's Encrypt side, see
|
||||
# https://github.com/letsencrypt/boulder/issues/2217
|
||||
bits = 521
|
||||
@@ -237,7 +264,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
point_size = 66
|
||||
curve = "P-521"
|
||||
else:
|
||||
raise KeyParsingError(f"unknown elliptic curve: {pk.curve.name}")
|
||||
raise KeyParsingError(f"unknown elliptic curve: {ec_pk.curve.name}")
|
||||
num_bytes = (bits + 7) // 8
|
||||
return {
|
||||
"key_obj": key,
|
||||
@@ -246,8 +273,8 @@ class CryptographyBackend(CryptoBackend):
|
||||
"jwk": {
|
||||
"kty": "EC",
|
||||
"crv": curve,
|
||||
"x": nopad_b64(convert_int_to_bytes(pk.x, count=num_bytes)),
|
||||
"y": nopad_b64(convert_int_to_bytes(pk.y, count=num_bytes)),
|
||||
"x": nopad_b64(convert_int_to_bytes(ec_pk.x, count=num_bytes)),
|
||||
"y": nopad_b64(convert_int_to_bytes(ec_pk.y, count=num_bytes)),
|
||||
},
|
||||
"hash": hashalg,
|
||||
"point_size": point_size,
|
||||
@@ -255,8 +282,11 @@ class CryptographyBackend(CryptoBackend):
|
||||
else:
|
||||
raise KeyParsingError(f'unknown key type "{type(key)}"')
|
||||
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
sign_payload = f"{protected64}.{payload64}".encode("utf8")
|
||||
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
|
||||
if "mac_obj" in key_data:
|
||||
mac = key_data["mac_obj"]()
|
||||
mac.update(sign_payload)
|
||||
@@ -292,8 +322,9 @@ class CryptographyBackend(CryptoBackend):
|
||||
"signature": nopad_b64(signature),
|
||||
}
|
||||
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
|
||||
if alg == "HS256":
|
||||
hashalg = cryptography.hazmat.primitives.hashes.SHA256
|
||||
hashbytes = 32
|
||||
@@ -324,7 +355,11 @@ class CryptographyBackend(CryptoBackend):
|
||||
},
|
||||
}
|
||||
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -334,15 +369,19 @@ class CryptographyBackend(CryptoBackend):
|
||||
as the first element in the result.
|
||||
"""
|
||||
if csr_content is None:
|
||||
csr_content = read_file(csr_filename)
|
||||
if csr_filename is None:
|
||||
raise BackendException(
|
||||
"One of csr_content and csr_filename has to be provided"
|
||||
)
|
||||
b_csr_content = read_file(csr_filename)
|
||||
else:
|
||||
csr_content = to_bytes(csr_content)
|
||||
csr = cryptography.x509.load_pem_x509_csr(csr_content)
|
||||
b_csr_content = to_bytes(csr_content)
|
||||
csr = cryptography.x509.load_pem_x509_csr(b_csr_content)
|
||||
|
||||
identifiers = set()
|
||||
result = []
|
||||
|
||||
def add_identifier(identifier):
|
||||
def add_identifier(identifier: tuple[str, str]) -> None:
|
||||
if identifier in identifiers:
|
||||
return
|
||||
identifiers.add(identifier)
|
||||
@@ -350,7 +389,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
for sub in csr.subject:
|
||||
if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME:
|
||||
add_identifier(("dns", sub.value))
|
||||
add_identifier(("dns", t.cast(str, sub.value)))
|
||||
for extension in csr.extensions:
|
||||
if (
|
||||
extension.oid
|
||||
@@ -367,7 +406,11 @@ class CryptographyBackend(CryptoBackend):
|
||||
)
|
||||
return result
|
||||
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""
|
||||
Return a set of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -379,7 +422,12 @@ class CryptographyBackend(CryptoBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Return the days the certificate in cert_filename remains valid and -1
|
||||
if the file was not found. If cert_filename contains more than one
|
||||
@@ -398,10 +446,10 @@ class CryptographyBackend(CryptoBackend):
|
||||
return -1
|
||||
|
||||
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
|
||||
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
|
||||
try:
|
||||
cert = cryptography.x509.load_pem_x509_certificate(cert_content)
|
||||
cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
|
||||
except Exception as e:
|
||||
if cert_filename is None:
|
||||
raise BackendException(f"Cannot parse certificate: {e}")
|
||||
@@ -413,13 +461,17 @@ class CryptographyBackend(CryptoBackend):
|
||||
now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE)
|
||||
return (get_not_valid_after(cert) - now).days
|
||||
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
return CryptographyChainMatcher(criterium, self.module)
|
||||
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
@@ -429,10 +481,10 @@ class CryptographyBackend(CryptoBackend):
|
||||
cert_content = to_bytes(cert_content)
|
||||
|
||||
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
|
||||
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
|
||||
try:
|
||||
cert = cryptography.x509.load_pem_x509_certificate(cert_content)
|
||||
cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
|
||||
except Exception as e:
|
||||
if cert_filename is None:
|
||||
raise BackendException(f"Cannot parse certificate: {e}")
|
||||
@@ -440,19 +492,19 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
ski = None
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(
|
||||
ext_ski = cert.extensions.get_extension_for_class(
|
||||
cryptography.x509.SubjectKeyIdentifier
|
||||
)
|
||||
ski = ext.value.digest
|
||||
ski = ext_ski.value.digest
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
pass
|
||||
|
||||
aki = None
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(
|
||||
ext_aki = cert.extensions.get_extension_for_class(
|
||||
cryptography.x509.AuthorityKeyIdentifier
|
||||
)
|
||||
aki = ext.value.key_identifier
|
||||
aki = ext_aki.value.key_identifier
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
pass
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
@@ -34,12 +35,23 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .certificates import Criterium
|
||||
|
||||
|
||||
_OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C")
|
||||
|
||||
|
||||
def _extract_date(out_text, name, cert_filename_suffix=""):
|
||||
def _extract_date(
|
||||
out_text: str, name: str, cert_filename_suffix: str = ""
|
||||
) -> datetime.datetime:
|
||||
matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text)
|
||||
if matcher is None:
|
||||
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
|
||||
date_str = matcher.group(1)
|
||||
try:
|
||||
date_str = re.search(rf"\s+{name}\s*:\s+(.*)", out_text).group(1)
|
||||
# For some reason Python's strptime() does not return any timezone information,
|
||||
# even though the information is there and a supported timezone for all supported
|
||||
# Python implementations (GMT). So we have to modify the datetime object by
|
||||
@@ -47,19 +59,40 @@ def _extract_date(out_text, name, cert_filename_suffix=""):
|
||||
return ensure_utc_timezone(
|
||||
datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
|
||||
)
|
||||
except AttributeError:
|
||||
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
|
||||
except ValueError as exc:
|
||||
raise BackendException(
|
||||
f"Failed to parse '{name}' date{cert_filename_suffix}: {exc}"
|
||||
)
|
||||
|
||||
|
||||
def _decode_octets(octets_text):
|
||||
def _decode_octets(octets_text: str) -> bytes:
|
||||
return binascii.unhexlify(re.sub(r"(\s|:)", "", octets_text).encode("utf-8"))
|
||||
|
||||
|
||||
def _extract_octets(out_text, name, required=True, potential_prefixes=None):
|
||||
@t.overload
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
name: str,
|
||||
required: t.Literal[False],
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
) -> bytes | None: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
name: str,
|
||||
required: t.Literal[True],
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
) -> bytes: ...
|
||||
|
||||
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
name: str,
|
||||
required: bool = True,
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
) -> bytes | None:
|
||||
part = (
|
||||
f"(?:{'|'.join(re.escape(pp) for pp in potential_prefixes)})"
|
||||
if potential_prefixes
|
||||
@@ -75,13 +108,20 @@ def _extract_octets(out_text, name, required=True, potential_prefixes=None):
|
||||
|
||||
|
||||
class OpenSSLCLIBackend(CryptoBackend):
|
||||
def __init__(self, module, openssl_binary=None):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, openssl_binary: str | None = None
|
||||
) -> None:
|
||||
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
|
||||
if openssl_binary is None:
|
||||
openssl_binary = module.get_bin_path("openssl", True)
|
||||
self.openssl_binary = openssl_binary
|
||||
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
Raises KeyParsingError in case of errors.
|
||||
@@ -90,6 +130,10 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
raise KeyParsingError("openssl backend does not support key passphrases")
|
||||
# If key_file is not given, but key_content, write that to a temporary file
|
||||
if key_file is None:
|
||||
if key_content is None:
|
||||
raise KeyParsingError(
|
||||
"one of key_file and key_content must be specified"
|
||||
)
|
||||
fd, tmpsrc = tempfile.mkstemp()
|
||||
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
|
||||
f = os.fdopen(fd, "wb")
|
||||
@@ -108,8 +152,8 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
f.close()
|
||||
# Parse key
|
||||
account_key_type = None
|
||||
with open(key_file, "rt") as f:
|
||||
for line in f:
|
||||
with open(key_file, "rt") as fi:
|
||||
for line in fi:
|
||||
m = re.match(
|
||||
r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line
|
||||
)
|
||||
@@ -129,38 +173,44 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
self.openssl_binary,
|
||||
account_key_type,
|
||||
"-in",
|
||||
key_file,
|
||||
str(key_file),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
rc, out, err = self.module.run_command(
|
||||
rc, out, stderr = self.module.run_command(
|
||||
openssl_keydump_cmd,
|
||||
check_rc=False,
|
||||
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
|
||||
)
|
||||
if rc != 0:
|
||||
raise BackendException(
|
||||
f"Error while running {' '.join(openssl_keydump_cmd)}: {err}"
|
||||
f"Error while running {' '.join(openssl_keydump_cmd)}: {stderr}"
|
||||
)
|
||||
|
||||
out_text = to_text(out, errors="surrogate_or_strict")
|
||||
|
||||
if account_key_type == "rsa":
|
||||
pub_hex = re.search(
|
||||
matcher = re.search(
|
||||
r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent",
|
||||
out_text,
|
||||
re.MULTILINE | re.DOTALL,
|
||||
).group(1)
|
||||
)
|
||||
if matcher is None:
|
||||
raise KeyParsingError("cannot parse RSA key: modulus not found")
|
||||
pub_hex = matcher.group(1)
|
||||
|
||||
pub_exp = re.search(
|
||||
matcher = re.search(
|
||||
r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL
|
||||
).group(1)
|
||||
)
|
||||
if matcher is None:
|
||||
raise KeyParsingError("cannot parse RSA key: public exponent not found")
|
||||
pub_exp = matcher.group(1)
|
||||
pub_exp = f"{int(pub_exp):x}"
|
||||
if len(pub_exp) % 2:
|
||||
pub_exp = f"0{pub_exp}"
|
||||
|
||||
return {
|
||||
"key_file": key_file,
|
||||
"key_file": str(key_file),
|
||||
"type": "rsa",
|
||||
"alg": "RS256",
|
||||
"jwk": {
|
||||
@@ -223,8 +273,13 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
"hash": hashalg,
|
||||
"point_size": point_size,
|
||||
}
|
||||
raise KeyParsingError(
|
||||
f"Internal error: unexpected account_key_type = {account_key_type!r}"
|
||||
)
|
||||
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
sign_payload = f"{protected64}.{payload64}".encode("utf8")
|
||||
if key_data["type"] == "hmac":
|
||||
hex_key = (
|
||||
@@ -284,7 +339,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
"signature": nopad_b64(to_bytes(out)),
|
||||
}
|
||||
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
if alg == "HS256":
|
||||
hashalg = "sha256"
|
||||
@@ -315,14 +370,18 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_ip(ip):
|
||||
def _normalize_ip(ip: str) -> str:
|
||||
try:
|
||||
return ipaddress.ip_address(to_text(ip)).compressed
|
||||
return ipaddress.ip_address(ip).compressed
|
||||
except ValueError:
|
||||
# We do not want to error out on something IPAddress() cannot parse
|
||||
return ip
|
||||
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -335,13 +394,13 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
data = None
|
||||
if csr_content is not None:
|
||||
filename = "/dev/stdin"
|
||||
data = csr_content.encode("utf-8")
|
||||
data = to_bytes(csr_content)
|
||||
|
||||
openssl_csr_cmd = [
|
||||
self.openssl_binary,
|
||||
"req",
|
||||
"-in",
|
||||
filename,
|
||||
str(filename),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
@@ -360,7 +419,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
identifiers = set()
|
||||
result = []
|
||||
|
||||
def add_identifier(identifier):
|
||||
def add_identifier(identifier: tuple[str, str]) -> None:
|
||||
if identifier in identifiers:
|
||||
return
|
||||
identifiers.add(identifier)
|
||||
@@ -389,7 +448,11 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
raise BackendException(f'Found unsupported SAN identifier "{san}"')
|
||||
return result
|
||||
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""
|
||||
Return a set of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -401,7 +464,12 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Return the days the certificate in cert_filename remains valid and -1
|
||||
if the file was not found. If cert_filename contains more than one
|
||||
@@ -413,7 +481,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
data = None
|
||||
if cert_content is not None:
|
||||
filename = "/dev/stdin"
|
||||
data = cert_content.encode("utf-8")
|
||||
data = to_bytes(cert_content)
|
||||
cert_filename_suffix = ""
|
||||
elif cert_filename is not None:
|
||||
if not os.path.exists(cert_filename):
|
||||
@@ -426,7 +494,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
self.openssl_binary,
|
||||
"x509",
|
||||
"-in",
|
||||
filename,
|
||||
str(filename),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
@@ -452,7 +520,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
now = ensure_utc_timezone(now)
|
||||
return (not_after - now).days
|
||||
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
@@ -460,7 +528,11 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
'Alternate chain matching can only be used with the "cryptography" backend.'
|
||||
)
|
||||
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
@@ -477,7 +549,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
self.openssl_binary,
|
||||
"x509",
|
||||
"-in",
|
||||
filename,
|
||||
str(filename),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import datetime
|
||||
import re
|
||||
from collections import namedtuple
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
BackendException,
|
||||
@@ -27,16 +27,20 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
CertificateInformation = namedtuple(
|
||||
"CertificateInformation",
|
||||
(
|
||||
"not_valid_after",
|
||||
"not_valid_before",
|
||||
"serial_number",
|
||||
"subject_key_identifier",
|
||||
"authority_key_identifier",
|
||||
),
|
||||
)
|
||||
if t.TYPE_CHECKING:
|
||||
import os
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .certificates import ChainMatcher, Criterium
|
||||
|
||||
|
||||
class CertificateInformation(t.NamedTuple):
|
||||
not_valid_after: datetime.datetime
|
||||
not_valid_before: datetime.datetime
|
||||
serial_number: int
|
||||
subject_key_identifier: bytes | None
|
||||
authority_key_identifier: bytes | None
|
||||
|
||||
|
||||
_FRACTIONAL_MATCHER = re.compile(
|
||||
@@ -44,7 +48,7 @@ _FRACTIONAL_MATCHER = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _reduce_fractional_digits(timestamp_str):
|
||||
def _reduce_fractional_digits(timestamp_str: str) -> str:
|
||||
"""
|
||||
Given a RFC 3339 timestamp that includes too many digits for the fractional seconds part, reduces these to at most 6.
|
||||
"""
|
||||
@@ -60,7 +64,7 @@ def _reduce_fractional_digits(timestamp_str):
|
||||
return f"{timestamp}{fractional}{timezone}"
|
||||
|
||||
|
||||
def _parse_acme_timestamp(timestamp_str, with_timezone):
|
||||
def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime:
|
||||
"""
|
||||
Parses a RFC 3339 timestamp.
|
||||
"""
|
||||
@@ -86,34 +90,42 @@ def _parse_acme_timestamp(timestamp_str, with_timezone):
|
||||
|
||||
|
||||
class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, with_timezone=False):
|
||||
def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None:
|
||||
self.module = module
|
||||
self._with_timezone = with_timezone
|
||||
|
||||
def get_now(self):
|
||||
def get_now(self) -> datetime.datetime:
|
||||
return get_now_datetime(with_timezone=self._with_timezone)
|
||||
|
||||
def parse_acme_timestamp(self, timestamp_str):
|
||||
def parse_acme_timestamp(self, timestamp_str: str) -> datetime.datetime:
|
||||
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
|
||||
return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone)
|
||||
|
||||
def parse_module_parameter(self, value, name):
|
||||
def parse_module_parameter(self, value: str, name: str) -> datetime.datetime:
|
||||
try:
|
||||
return get_relative_time_option(
|
||||
result = get_relative_time_option(
|
||||
value, name, with_timezone=self._with_timezone
|
||||
)
|
||||
if result is None:
|
||||
raise BackendException(f"Invalid value for {name}: {value!r}")
|
||||
return result
|
||||
except OpenSSLObjectError as exc:
|
||||
raise BackendException(str(exc))
|
||||
|
||||
def interpolate_timestamp(self, timestamp_start, timestamp_end, percentage):
|
||||
def interpolate_timestamp(
|
||||
self,
|
||||
timestamp_start: datetime.datetime,
|
||||
timestamp_end: datetime.datetime,
|
||||
percentage: float,
|
||||
) -> datetime.datetime:
|
||||
start = get_epoch_seconds(timestamp_start)
|
||||
end = get_epoch_seconds(timestamp_end)
|
||||
return from_epoch_seconds(
|
||||
start + percentage * (end - start), with_timezone=self._with_timezone
|
||||
)
|
||||
|
||||
def get_utc_datetime(self, *args, **kwargs):
|
||||
kwargs_ext = dict(kwargs)
|
||||
def get_utc_datetime(self, *args, **kwargs) -> datetime.datetime:
|
||||
kwargs_ext: dict[str, t.Any] = dict(kwargs)
|
||||
if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8):
|
||||
kwargs_ext["tzinfo"] = UTC
|
||||
result = datetime.datetime(*args, **kwargs_ext)
|
||||
@@ -122,22 +134,33 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
Raises KeyParsingError in case of errors.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -148,7 +171,11 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""
|
||||
Return a set of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -156,7 +183,12 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Return the days the certificate in cert_filename remains valid and -1
|
||||
if the file was not found. If cert_filename contains more than one
|
||||
@@ -166,13 +198,17 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
@@ -30,6 +31,14 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .backends import CryptoBackend
|
||||
from .certificates import ChainMatcher
|
||||
from .challenges import Challenge
|
||||
|
||||
|
||||
class ACMECertificateClient:
|
||||
"""
|
||||
ACME v2 client class. Uses an ACME account object and a CSR to
|
||||
@@ -37,7 +46,13 @@ class ACMECertificateClient:
|
||||
certificates.
|
||||
"""
|
||||
|
||||
def __init__(self, module, backend, client=None, account=None):
|
||||
def __init__(
|
||||
self,
|
||||
module: AnsibleModule,
|
||||
backend: CryptoBackend,
|
||||
client: ACMEClient | None = None,
|
||||
account: ACMEAccount | None = None,
|
||||
) -> None:
|
||||
self.module = module
|
||||
self.version = module.params["acme_version"]
|
||||
self.csr = module.params.get("csr")
|
||||
@@ -66,13 +81,17 @@ class ACMECertificateClient:
|
||||
|
||||
# Extract list of identifiers from CSR
|
||||
if self.csr is not None or self.csr_content is not None:
|
||||
self.identifiers = self.client.backend.get_ordered_csr_identifiers(
|
||||
csr_filename=self.csr, csr_content=self.csr_content
|
||||
self.identifiers: list[tuple[str, str]] | None = (
|
||||
self.client.backend.get_ordered_csr_identifiers(
|
||||
csr_filename=self.csr, csr_content=self.csr_content
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.identifiers = None
|
||||
|
||||
def parse_select_chain(self, select_chain):
|
||||
def parse_select_chain(
|
||||
self, select_chain: list[dict[str, t.Any]] | None
|
||||
) -> list[ChainMatcher]:
|
||||
select_chain_matcher = []
|
||||
if select_chain:
|
||||
for criterium_idx, criterium in enumerate(select_chain):
|
||||
@@ -88,14 +107,16 @@ class ACMECertificateClient:
|
||||
)
|
||||
return select_chain_matcher
|
||||
|
||||
def load_order(self):
|
||||
def load_order(self) -> Order:
|
||||
if not self.order_uri:
|
||||
raise ModuleFailException("The order URI has not been provided")
|
||||
order = Order.from_url(self.client, self.order_uri)
|
||||
order.load_authorizations(self.client)
|
||||
return order
|
||||
|
||||
def create_order(self, replaces_cert_id=None, profile=None):
|
||||
def create_order(
|
||||
self, replaces_cert_id: str | None = None, profile: str | None = None
|
||||
) -> Order:
|
||||
"""
|
||||
Create a new order.
|
||||
"""
|
||||
@@ -114,31 +135,31 @@ class ACMECertificateClient:
|
||||
order.load_authorizations(self.client)
|
||||
return order
|
||||
|
||||
def get_challenges_data(self, order):
|
||||
def get_challenges_data(
|
||||
self, order: Order
|
||||
) -> tuple[list[dict[str, t.Any]], dict[str, list[str]]]:
|
||||
"""
|
||||
Get challenge details.
|
||||
|
||||
Return a tuple of generic challenge details, and specialized DNS challenge details.
|
||||
"""
|
||||
# Get general challenge data
|
||||
data = []
|
||||
data: list[dict[str, t.Any]] = []
|
||||
data_dns: dict[str, list[str]] = {}
|
||||
dns_challenge_type = "dns-01"
|
||||
for authz in order.authorizations.values():
|
||||
# Skip valid authentications: their challenges are already valid
|
||||
# and do not need to be returned
|
||||
if authz.status == "valid":
|
||||
continue
|
||||
challenge_data = authz.get_challenge_data(self.client)
|
||||
data.append(
|
||||
dict(
|
||||
identifier=authz.identifier,
|
||||
identifier_type=authz.identifier_type,
|
||||
challenges=authz.get_challenge_data(self.client),
|
||||
challenges=challenge_data,
|
||||
)
|
||||
)
|
||||
# Get DNS challenge data
|
||||
data_dns = {}
|
||||
dns_challenge_type = "dns-01"
|
||||
for entry in data:
|
||||
dns_challenge = entry["challenges"].get(dns_challenge_type)
|
||||
dns_challenge = challenge_data.get(dns_challenge_type)
|
||||
if dns_challenge:
|
||||
values = data_dns.get(dns_challenge["record"])
|
||||
if values is None:
|
||||
@@ -147,7 +168,7 @@ class ACMECertificateClient:
|
||||
values.append(dns_challenge["resource_value"])
|
||||
return data, data_dns
|
||||
|
||||
def check_that_authorizations_can_be_used(self, order):
|
||||
def check_that_authorizations_can_be_used(self, order: Order) -> None:
|
||||
bad_authzs = []
|
||||
for authz in order.authorizations.values():
|
||||
if authz.status not in ("valid", "pending"):
|
||||
@@ -155,27 +176,32 @@ class ACMECertificateClient:
|
||||
f"{authz.combined_identifier} (status={authz.status!r})"
|
||||
)
|
||||
if bad_authzs:
|
||||
bad_authzs = ", ".join(sorted(bad_authzs))
|
||||
bad_authzs_str = ", ".join(sorted(bad_authzs))
|
||||
raise ModuleFailException(
|
||||
"Some of the authorizations for the order are in a bad state, so the order"
|
||||
f" can no longer be satisfied: {bad_authzs}",
|
||||
f" can no longer be satisfied: {bad_authzs_str}",
|
||||
)
|
||||
|
||||
def collect_invalid_authzs(self, order):
|
||||
def collect_invalid_authzs(self, order: Order) -> list[Authorization]:
|
||||
return [
|
||||
authz
|
||||
for authz in order.authorizations.values()
|
||||
if authz.status == "invalid"
|
||||
]
|
||||
|
||||
def collect_pending_authzs(self, order):
|
||||
def collect_pending_authzs(self, order: Order) -> list[Authorization]:
|
||||
return [
|
||||
authz
|
||||
for authz in order.authorizations.values()
|
||||
if authz.status == "pending"
|
||||
]
|
||||
|
||||
def call_validate(self, pending_authzs, get_challenge, wait=True):
|
||||
def call_validate(
|
||||
self,
|
||||
pending_authzs: list[Authorization],
|
||||
get_challenge: t.Callable[[Authorization], str],
|
||||
wait: bool = True,
|
||||
) -> list[tuple[Authorization, str, Challenge | None]]:
|
||||
authzs_with_challenges_to_wait_for = []
|
||||
for authz in pending_authzs:
|
||||
challenge_type = get_challenge(authz)
|
||||
@@ -185,10 +211,12 @@ class ACMECertificateClient:
|
||||
)
|
||||
return authzs_with_challenges_to_wait_for
|
||||
|
||||
def wait_for_validation(self, authzs_to_wait_for):
|
||||
def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None:
|
||||
wait_for_validation(authzs_to_wait_for, self.client)
|
||||
|
||||
def _download_alternate_chains(self, cert):
|
||||
def _download_alternate_chains(
|
||||
self, cert: CertificateChain
|
||||
) -> list[CertificateChain]:
|
||||
alternate_chains = []
|
||||
for alternate in cert.alternates:
|
||||
try:
|
||||
@@ -206,13 +234,30 @@ class ACMECertificateClient:
|
||||
)
|
||||
return alternate_chains
|
||||
|
||||
def download_certificate(self, order, download_all_chains=True):
|
||||
@t.overload
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[True] = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain]]: ...
|
||||
|
||||
@t.overload
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[False]
|
||||
) -> tuple[CertificateChain, None]: ...
|
||||
|
||||
@t.overload
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]:
|
||||
"""
|
||||
Download certificate from a valid oder.
|
||||
"""
|
||||
if order.status != "valid":
|
||||
raise ModuleFailException(
|
||||
f"The order must be valid, but has state {order.state!r}!"
|
||||
f"The order must be valid, but has state {order.status!r}!"
|
||||
)
|
||||
|
||||
if not order.certificate_uri:
|
||||
@@ -232,7 +277,24 @@ class ACMECertificateClient:
|
||||
|
||||
return cert, alternate_chains
|
||||
|
||||
def get_certificate(self, order, download_all_chains=True):
|
||||
@t.overload
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[True] = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
@t.overload
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[False]
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
@t.overload
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]:
|
||||
"""
|
||||
Request a new certificate and downloads it, and optionally all certificate chains.
|
||||
First verifies whether all authorizations are valid; if not, aborts with an error.
|
||||
@@ -250,7 +312,11 @@ class ACMECertificateClient:
|
||||
|
||||
return self.download_certificate(order, download_all_chains=download_all_chains)
|
||||
|
||||
def find_matching_chain(self, chains, select_chain_matcher):
|
||||
def find_matching_chain(
|
||||
self,
|
||||
chains: list[CertificateChain],
|
||||
select_chain_matcher: t.Iterable[ChainMatcher],
|
||||
) -> CertificateChain | None:
|
||||
for criterium_idx, matcher in enumerate(select_chain_matcher):
|
||||
for chain in chains:
|
||||
if matcher.match(chain):
|
||||
@@ -261,9 +327,15 @@ class ACMECertificateClient:
|
||||
return None
|
||||
|
||||
def write_cert_chain(
|
||||
self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None
|
||||
):
|
||||
self,
|
||||
cert: CertificateChain,
|
||||
cert_dest: str | os.PathLike | None = None,
|
||||
fullchain_dest: str | os.PathLike | None = None,
|
||||
chain_dest: str | os.PathLike | None = None,
|
||||
) -> bool:
|
||||
changed = False
|
||||
if cert.cert is None:
|
||||
raise ValueError("Certificate is not present")
|
||||
|
||||
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
|
||||
changed = True
|
||||
@@ -282,7 +354,7 @@ class ACMECertificateClient:
|
||||
|
||||
return changed
|
||||
|
||||
def deactivate_authzs(self, order):
|
||||
def deactivate_authzs(self, order: Order) -> None:
|
||||
"""
|
||||
Deactivates all valid authz's. Does not raise exceptions.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ModuleFailException,
|
||||
@@ -19,20 +20,29 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
_CertificateChain = t.TypeVar("_CertificateChain", bound="CertificateChain")
|
||||
|
||||
|
||||
class CertificateChain:
|
||||
"""
|
||||
Download and parse the certificate chain.
|
||||
https://tools.ietf.org/html/rfc8555#section-7.4.2
|
||||
"""
|
||||
|
||||
def __init__(self, url):
|
||||
def __init__(self, url: str):
|
||||
self.url = url
|
||||
self.cert = None
|
||||
self.chain = []
|
||||
self.alternates = []
|
||||
self.cert: str | None = None
|
||||
self.chain: list[str] = []
|
||||
self.alternates: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def download(cls, client, url):
|
||||
def download(
|
||||
cls: t.Type[_CertificateChain], client: ACMEClient, url: str
|
||||
) -> _CertificateChain:
|
||||
content, info = client.get_request(
|
||||
url,
|
||||
parse_json_result=False,
|
||||
@@ -43,7 +53,7 @@ class CertificateChain:
|
||||
"application/pem-certificate-chain"
|
||||
):
|
||||
raise ModuleFailException(
|
||||
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content} (headers: {info})"
|
||||
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content!r} (headers: {info})"
|
||||
)
|
||||
|
||||
result = cls(url)
|
||||
@@ -60,12 +70,12 @@ class CertificateChain:
|
||||
|
||||
if result.cert is None:
|
||||
raise ModuleFailException(
|
||||
f"Failed to parse certificate chain download from {url}: {content} (headers: {info})"
|
||||
f"Failed to parse certificate chain download from {url}: {content!r} (headers: {info})"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _process_links(self, client, link, relation):
|
||||
def _process_links(self, client: ACMEClient, link: str, relation: str) -> None:
|
||||
if relation == "up":
|
||||
# Process link-up headers if there was no chain in reply
|
||||
if not self.chain:
|
||||
@@ -77,7 +87,9 @@ class CertificateChain:
|
||||
elif relation == "alternate":
|
||||
self.alternates.append(link)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> dict[str, bytes]:
|
||||
if self.cert is None:
|
||||
raise ValueError("Has no certificate")
|
||||
cert = self.cert.encode("utf8")
|
||||
chain = ("\n".join(self.chain)).encode("utf8")
|
||||
return {
|
||||
@@ -88,18 +100,22 @@ class CertificateChain:
|
||||
|
||||
|
||||
class Criterium:
|
||||
def __init__(self, criterium, index=None):
|
||||
def __init__(self, criterium: dict[str, t.Any], index: int):
|
||||
self.index = index
|
||||
self.test_certificates = criterium["test_certificates"]
|
||||
self.subject = criterium["subject"]
|
||||
self.issuer = criterium["issuer"]
|
||||
self.subject_key_identifier = criterium["subject_key_identifier"]
|
||||
self.authority_key_identifier = criterium["authority_key_identifier"]
|
||||
self.test_certificates: t.Literal["first", "last", "all"] = criterium[
|
||||
"test_certificates"
|
||||
]
|
||||
self.subject: dict[str, t.Any] | None = criterium["subject"]
|
||||
self.issuer: dict[str, t.Any] | None = criterium["issuer"]
|
||||
self.subject_key_identifier: str | None = criterium["subject_key_identifier"]
|
||||
self.authority_key_identifier: str | None = criterium[
|
||||
"authority_key_identifier"
|
||||
]
|
||||
|
||||
|
||||
class ChainMatcher(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def match(self, certificate):
|
||||
def match(self, certificate: CertificateChain) -> bool:
|
||||
"""
|
||||
Check whether a certificate chain (CertificateChain instance) matches.
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ import ipaddress
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
@@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def create_key_authorization(client, token):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
def create_key_authorization(client: ACMEClient, token: str) -> str:
|
||||
"""
|
||||
Returns the key authorization for the given token
|
||||
https://tools.ietf.org/html/rfc8555#section-8.1
|
||||
@@ -35,41 +42,49 @@ def create_key_authorization(client, token):
|
||||
return f"{token}.{thumbprint}"
|
||||
|
||||
|
||||
def combine_identifier(identifier_type, identifier):
|
||||
def combine_identifier(identifier_type: str, identifier: str) -> str:
|
||||
return f"{identifier_type}:{identifier}"
|
||||
|
||||
|
||||
def normalize_combined_identifier(identifier):
|
||||
def normalize_combined_identifier(identifier: str) -> str:
|
||||
identifier_type, identifier = split_identifier(identifier)
|
||||
# Normalize DNS names and IPs
|
||||
identifier = identifier.lower()
|
||||
return combine_identifier(identifier_type, identifier)
|
||||
|
||||
|
||||
def split_identifier(identifier):
|
||||
def split_identifier(identifier: str) -> tuple[str, str]:
|
||||
parts = identifier.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
raise ModuleFailException(
|
||||
f'Identifier "{identifier}" is not of the form <type>:<identifier>'
|
||||
)
|
||||
return parts
|
||||
return parts[0], parts[1]
|
||||
|
||||
|
||||
_Challenge = t.TypeVar("_Challenge", bound="Challenge")
|
||||
|
||||
|
||||
class Challenge:
|
||||
def __init__(self, data, url):
|
||||
def __init__(self, data: dict[str, t.Any], url: str) -> None:
|
||||
self.data = data
|
||||
|
||||
self.type = data["type"]
|
||||
self.type: str = data["type"]
|
||||
self.url = url
|
||||
self.status = data["status"]
|
||||
self.token = data.get("token")
|
||||
self.status: str = data["status"]
|
||||
self.token: str | None = data.get("token")
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, client, data, url=None):
|
||||
def from_json(
|
||||
cls: t.Type[_Challenge],
|
||||
client: ACMEClient,
|
||||
data: dict[str, t.Any],
|
||||
url: str | None = None,
|
||||
) -> _Challenge:
|
||||
return cls(data, url or data["url"])
|
||||
|
||||
def call_validate(self, client):
|
||||
challenge_response = {}
|
||||
def call_validate(self, client: ACMEClient) -> None:
|
||||
challenge_response: dict[str, t.Any] = {}
|
||||
client.send_signed_request(
|
||||
self.url,
|
||||
challenge_response,
|
||||
@@ -77,10 +92,15 @@ class Challenge:
|
||||
expected_status_codes=[200, 202],
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> dict[str, t.Any]:
|
||||
return self.data.copy()
|
||||
|
||||
def get_validation_data(self, client, identifier_type, identifier):
|
||||
def get_validation_data(
|
||||
self, client: ACMEClient, identifier_type: str, identifier: str
|
||||
) -> dict[str, t.Any] | None:
|
||||
if self.token is None:
|
||||
return None
|
||||
|
||||
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
|
||||
key_authorization = create_key_authorization(client, token)
|
||||
|
||||
@@ -113,21 +133,33 @@ class Challenge:
|
||||
resource += "."
|
||||
else:
|
||||
resource = identifier
|
||||
value = base64.b64encode(
|
||||
b_value = base64.b64encode(
|
||||
hashlib.sha256(to_bytes(key_authorization)).digest()
|
||||
)
|
||||
return {
|
||||
"resource": resource,
|
||||
"resource_original": combine_identifier(identifier_type, identifier),
|
||||
"resource_value": value,
|
||||
"resource_value": b_value,
|
||||
}
|
||||
|
||||
# Unknown challenge type: ignore
|
||||
return None
|
||||
|
||||
|
||||
_Authorization = t.TypeVar("_Authorization", bound="Authorization")
|
||||
|
||||
|
||||
class Authorization:
|
||||
def _setup(self, client, data):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
self.data: dict[str, t.Any] | None = None
|
||||
self.challenges: list[Challenge] = []
|
||||
self.status: str | None = None
|
||||
self.identifier_type: str | None = None
|
||||
self.identifier: str | None = None
|
||||
|
||||
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
data["uri"] = self.url
|
||||
self.data = data
|
||||
# While 'challenges' is a required field, apparently not every CA cares
|
||||
@@ -145,29 +177,32 @@ class Authorization:
|
||||
if data.get("wildcard", False):
|
||||
self.identifier = f"*.{self.identifier}"
|
||||
|
||||
def __init__(self, url):
|
||||
self.url = url
|
||||
|
||||
self.data = None
|
||||
self.challenges = []
|
||||
self.status = None
|
||||
self.identifier_type = None
|
||||
self.identifier = None
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, client, data, url):
|
||||
def from_json(
|
||||
cls: t.Type[_Authorization],
|
||||
client: ACMEClient,
|
||||
data: dict[str, t.Any],
|
||||
url: str,
|
||||
) -> _Authorization:
|
||||
result = cls(url)
|
||||
result._setup(client, data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, client, url):
|
||||
def from_url(
|
||||
cls: t.Type[_Authorization], client: ACMEClient, url: str
|
||||
) -> _Authorization:
|
||||
result = cls(url)
|
||||
result.refresh(client)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def create(cls, client, identifier_type, identifier):
|
||||
def create(
|
||||
cls: t.Type[_Authorization],
|
||||
client: ACMEClient,
|
||||
identifier_type: str,
|
||||
identifier: str,
|
||||
) -> _Authorization:
|
||||
"""
|
||||
Create a new authorization for the given identifier.
|
||||
Return the authorization object of the new authorization
|
||||
@@ -194,23 +229,29 @@ class Authorization:
|
||||
return cls.from_json(client, result, info["location"])
|
||||
|
||||
@property
|
||||
def combined_identifier(self):
|
||||
def combined_identifier(self) -> str:
|
||||
if self.identifier_type is None or self.identifier is None:
|
||||
raise ValueError("Data not present")
|
||||
return combine_identifier(self.identifier_type, self.identifier)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> dict[str, t.Any]:
|
||||
if self.data is None:
|
||||
raise ValueError("Data not present")
|
||||
return self.data.copy()
|
||||
|
||||
def refresh(self, client):
|
||||
def refresh(self, client: ACMEClient) -> bool:
|
||||
result, dummy = client.get_request(self.url)
|
||||
changed = self.data != result
|
||||
self._setup(client, result)
|
||||
return changed
|
||||
|
||||
def get_challenge_data(self, client):
|
||||
def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]:
|
||||
"""
|
||||
Returns a dict with the data for all proposed (and supported) challenges
|
||||
of the given authorization.
|
||||
"""
|
||||
if self.identifier_type is None or self.identifier is None:
|
||||
raise ValueError("Data not present")
|
||||
data = {}
|
||||
for challenge in self.challenges:
|
||||
validation_data = challenge.get_validation_data(
|
||||
@@ -220,7 +261,7 @@ class Authorization:
|
||||
data[challenge.type] = validation_data
|
||||
return data
|
||||
|
||||
def raise_error(self, error_msg, module=None):
|
||||
def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn:
|
||||
"""
|
||||
Aborts with a specific error for a challenge.
|
||||
"""
|
||||
@@ -246,13 +287,13 @@ class Authorization:
|
||||
),
|
||||
)
|
||||
|
||||
def find_challenge(self, challenge_type):
|
||||
def find_challenge(self, challenge_type: str) -> Challenge | None:
|
||||
for challenge in self.challenges:
|
||||
if challenge_type == challenge.type:
|
||||
return challenge
|
||||
return None
|
||||
|
||||
def wait_for_validation(self, client, callenge_type):
|
||||
def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool:
|
||||
while True:
|
||||
self.refresh(client)
|
||||
if self.status in ["valid", "invalid", "revoked"]:
|
||||
@@ -264,7 +305,9 @@ class Authorization:
|
||||
|
||||
return self.status == "valid"
|
||||
|
||||
def call_validate(self, client, challenge_type, wait=True):
|
||||
def call_validate(
|
||||
self, client: ACMEClient, challenge_type: str, wait: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the authorization provided in the auth dict. Returns True
|
||||
when the validation was successful and False when it was not.
|
||||
@@ -281,7 +324,7 @@ class Authorization:
|
||||
return self.status == "valid"
|
||||
return self.wait_for_validation(client, challenge_type)
|
||||
|
||||
def can_deactivate(self):
|
||||
def can_deactivate(self) -> bool:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
@@ -289,14 +332,14 @@ class Authorization:
|
||||
"""
|
||||
return self.status in ("valid", "pending")
|
||||
|
||||
def deactivate(self, client):
|
||||
def deactivate(self, client: ACMEClient) -> bool | None:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
https://tools.ietf.org/html/rfc8555#section-7.5.2
|
||||
"""
|
||||
if not self.can_deactivate():
|
||||
return
|
||||
return None
|
||||
authz_deactivate = {"status": "deactivated"}
|
||||
result, info = client.send_signed_request(
|
||||
self.url, authz_deactivate, fail_on_error=False
|
||||
@@ -307,7 +350,9 @@ class Authorization:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def deactivate_url(cls, client, url):
|
||||
def deactivate_url(
|
||||
cls: t.Type[_Authorization], client: ACMEClient, url: str
|
||||
) -> _Authorization:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
@@ -322,7 +367,7 @@ class Authorization:
|
||||
return authz
|
||||
|
||||
|
||||
def wait_for_validation(authzs, client):
|
||||
def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None:
|
||||
"""
|
||||
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
|
||||
"""
|
||||
|
||||
@@ -5,19 +5,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from http.client import responses as http_responses
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
|
||||
|
||||
def format_http_status(status_code):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def format_http_status(status_code: int) -> str:
|
||||
expl = http_responses.get(status_code)
|
||||
if not expl:
|
||||
return str(status_code)
|
||||
return f"{status_code} {expl}"
|
||||
|
||||
|
||||
def format_error_problem(problem, subproblem_prefix=""):
|
||||
def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str:
|
||||
error_type = problem.get(
|
||||
"type", "about:blank"
|
||||
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
|
||||
@@ -32,8 +37,10 @@ def format_error_problem(problem, subproblem_prefix=""):
|
||||
msg = f"{msg} Subproblems:"
|
||||
for index, problem in enumerate(subproblems):
|
||||
index_str = f"{subproblem_prefix}{index}"
|
||||
problem = format_error_problem(problem, subproblem_prefix=f"{index_str}.")
|
||||
msg = f"{msg}\n({index_str}) {problem}"
|
||||
problem_str = format_error_problem(
|
||||
problem, subproblem_prefix=f"{index_str}."
|
||||
)
|
||||
msg = f"{msg}\n({index_str}) {problem_str}"
|
||||
return msg
|
||||
|
||||
|
||||
@@ -42,25 +49,25 @@ class ModuleFailException(Exception):
|
||||
If raised, module.fail_json() will be called with the given parameters after cleanup.
|
||||
"""
|
||||
|
||||
def __init__(self, msg, **args):
|
||||
def __init__(self, msg: str, **args: t.Any) -> None:
|
||||
super(ModuleFailException, self).__init__(self, msg)
|
||||
self.msg = msg
|
||||
self.module_fail_args = args
|
||||
|
||||
def do_fail(self, module, **arguments):
|
||||
def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn:
|
||||
module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments)
|
||||
|
||||
|
||||
class ACMEProtocolException(ModuleFailException):
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
msg=None,
|
||||
info=None,
|
||||
module: AnsibleModule,
|
||||
msg: str | None = None,
|
||||
info: dict[str, t.Any] | None = None,
|
||||
response=None,
|
||||
content=None,
|
||||
content_json=None,
|
||||
extras=None,
|
||||
content: bytes | None = None,
|
||||
content_json: dict[str, t.Any] | None = None,
|
||||
extras: dict[str, t.Any] | None = None,
|
||||
):
|
||||
# Try to get hold of content, if response is given and content is not provided
|
||||
if content is None and content_json is None and response is not None:
|
||||
@@ -71,7 +78,8 @@ class ACMEProtocolException(ModuleFailException):
|
||||
raise TypeError
|
||||
content = response.read()
|
||||
except (AttributeError, TypeError):
|
||||
content = info.pop("body", None)
|
||||
if info is not None:
|
||||
content = info.pop("body", None)
|
||||
|
||||
# Make sure that content_json is None or a dictionary
|
||||
if content_json is not None and not isinstance(content_json, dict):
|
||||
@@ -139,8 +147,8 @@ class ACMEProtocolException(ModuleFailException):
|
||||
add_msg = f" The raw result: {to_text(content)}"
|
||||
|
||||
super(ACMEProtocolException, self).__init__(f"{msg}.{add_msg}", **extras)
|
||||
self.problem = {}
|
||||
self.subproblems = []
|
||||
self.problem: dict[str, t.Any] = {}
|
||||
self.subproblems: list[dict[str, t.Any]] = []
|
||||
self.error_code = error_code
|
||||
self.error_type = error_type
|
||||
for k, v in extras.items():
|
||||
|
||||
@@ -10,22 +10,27 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ModuleFailException,
|
||||
)
|
||||
|
||||
|
||||
def read_file(fn, mode="b"):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def read_file(fn: str | os.PathLike) -> bytes:
|
||||
try:
|
||||
with open(fn, "r" + mode) as f:
|
||||
with open(fn, "rb") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
raise ModuleFailException(f'Error while reading file "{fn}": {e}')
|
||||
|
||||
|
||||
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
|
||||
def write_file(module, dest, content):
|
||||
def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool:
|
||||
"""
|
||||
Write content to destination file dest, only if the content
|
||||
has changed.
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
|
||||
Authorization,
|
||||
@@ -13,14 +14,35 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.challenges i
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ACMEProtocolException,
|
||||
ModuleFailException,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.utils import (
|
||||
nopad_b64,
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
_Order = t.TypeVar("_Order", bound="Order")
|
||||
|
||||
|
||||
class Order:
|
||||
def _setup(self, client, data):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
self.data: dict[str, t.Any] | None = None
|
||||
|
||||
self.status = None
|
||||
self.identifiers: list[tuple[str, str]] = []
|
||||
self.replaces_cert_id = None
|
||||
self.finalize_uri = None
|
||||
self.certificate_uri = None
|
||||
self.authorization_uris: list[str] = []
|
||||
self.authorizations: dict[str, Authorization] = {}
|
||||
|
||||
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
self.data = data
|
||||
|
||||
self.status = data["status"]
|
||||
@@ -33,33 +55,28 @@ class Order:
|
||||
self.authorization_uris = data["authorizations"]
|
||||
self.authorizations = {}
|
||||
|
||||
def __init__(self, url):
|
||||
self.url = url
|
||||
|
||||
self.data = None
|
||||
|
||||
self.status = None
|
||||
self.identifiers = []
|
||||
self.replaces_cert_id = None
|
||||
self.finalize_uri = None
|
||||
self.certificate_uri = None
|
||||
self.authorization_uris = []
|
||||
self.authorizations = {}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, client, data, url):
|
||||
def from_json(
|
||||
cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str
|
||||
) -> _Order:
|
||||
result = cls(url)
|
||||
result._setup(client, data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, client, url):
|
||||
def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order:
|
||||
result = cls(url)
|
||||
result.refresh(client)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def create(cls, client, identifiers, replaces_cert_id=None, profile=None):
|
||||
def create(
|
||||
cls: t.Type[_Order],
|
||||
client: ACMEClient,
|
||||
identifiers: list[tuple[str, str]],
|
||||
replaces_cert_id: str | None = None,
|
||||
profile: str | None = None,
|
||||
) -> _Order:
|
||||
"""
|
||||
Start a new certificate order (ACME v2 protocol).
|
||||
https://tools.ietf.org/html/rfc8555#section-7.4
|
||||
@@ -72,7 +89,7 @@ class Order:
|
||||
"value": identifier,
|
||||
}
|
||||
)
|
||||
new_order = {"identifiers": acme_identifiers}
|
||||
new_order: dict[str, t.Any] = {"identifiers": acme_identifiers}
|
||||
if replaces_cert_id is not None:
|
||||
new_order["replaces"] = replaces_cert_id
|
||||
if profile is not None:
|
||||
@@ -87,15 +104,17 @@ class Order:
|
||||
|
||||
@classmethod
|
||||
def create_with_error_handling(
|
||||
cls,
|
||||
client,
|
||||
identifiers,
|
||||
error_strategy="auto",
|
||||
error_max_retries=3,
|
||||
replaces_cert_id=None,
|
||||
profile=None,
|
||||
message_callback=None,
|
||||
):
|
||||
cls: t.Type[_Order],
|
||||
client: ACMEClient,
|
||||
identifiers: list[tuple[str, str]],
|
||||
error_strategy: t.Literal[
|
||||
"auto", "fail", "always", "retry_without_replaces_cert_id"
|
||||
] = "auto",
|
||||
error_max_retries: int = 3,
|
||||
replaces_cert_id: str | None = None,
|
||||
profile: str | None = None,
|
||||
message_callback: t.Callable[[str], None] | None = None,
|
||||
) -> _Order:
|
||||
"""
|
||||
error_strategy can be one of the following strings:
|
||||
|
||||
@@ -140,20 +159,20 @@ class Order:
|
||||
|
||||
raise
|
||||
|
||||
def refresh(self, client):
|
||||
def refresh(self, client: ACMEClient) -> bool:
|
||||
result, dummy = client.get_request(self.url)
|
||||
changed = self.data != result
|
||||
self._setup(client, result)
|
||||
return changed
|
||||
|
||||
def load_authorizations(self, client):
|
||||
def load_authorizations(self, client: ACMEClient) -> None:
|
||||
for auth_uri in self.authorization_uris:
|
||||
authz = Authorization.from_url(client, auth_uri)
|
||||
self.authorizations[
|
||||
normalize_combined_identifier(authz.combined_identifier)
|
||||
] = authz
|
||||
|
||||
def wait_for_finalization(self, client):
|
||||
def wait_for_finalization(self, client: ACMEClient) -> None:
|
||||
while True:
|
||||
self.refresh(client)
|
||||
if self.status in ["valid", "invalid", "pending", "ready"]:
|
||||
@@ -167,12 +186,14 @@ class Order:
|
||||
content_json=self.data,
|
||||
)
|
||||
|
||||
def finalize(self, client, csr_der, wait=True):
|
||||
def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None:
|
||||
"""
|
||||
Create a new certificate based on the csr.
|
||||
Return the certificate object as dict
|
||||
https://tools.ietf.org/html/rfc8555#section-7.4
|
||||
"""
|
||||
if self.finalize_uri is None:
|
||||
raise ModuleFailException("finalize_uri must be set")
|
||||
new_cert = {
|
||||
"csr": nopad_b64(csr_der),
|
||||
}
|
||||
|
||||
@@ -7,9 +7,11 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import traceback
|
||||
import typing as t
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
@@ -23,11 +25,15 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
def nopad_b64(data):
|
||||
if t.TYPE_CHECKING:
|
||||
from .backends import CertificateInformation, CryptoBackend
|
||||
|
||||
|
||||
def nopad_b64(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "")
|
||||
|
||||
|
||||
def der_to_pem(der_cert):
|
||||
def der_to_pem(der_cert: bytes) -> str:
|
||||
"""
|
||||
Convert the DER format certificate in der_cert to a PEM format certificate and return it.
|
||||
"""
|
||||
@@ -35,7 +41,9 @@ def der_to_pem(der_cert):
|
||||
return f"-----BEGIN CERTIFICATE-----\n{content}\n-----END CERTIFICATE-----\n"
|
||||
|
||||
|
||||
def pem_to_der(pem_filename=None, pem_content=None):
|
||||
def pem_to_der(
|
||||
pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Load PEM file, or use PEM file's content, and convert to DER.
|
||||
|
||||
@@ -70,7 +78,9 @@ def pem_to_der(pem_filename=None, pem_content=None):
|
||||
return base64.b64decode("".join(certificate_lines))
|
||||
|
||||
|
||||
def process_links(info, callback):
|
||||
def process_links(
|
||||
info: dict[str, t.Any], callback: t.Callable[[str, str], None]
|
||||
) -> None:
|
||||
"""
|
||||
Process link header, calls callback for every link header with the URL and relation as options.
|
||||
|
||||
@@ -82,7 +92,11 @@ def process_links(info, callback):
|
||||
callback(unquote(url), relation)
|
||||
|
||||
|
||||
def parse_retry_after(value, relative_with_timezone=True, now=None):
|
||||
def parse_retry_after(
|
||||
value: str,
|
||||
relative_with_timezone: bool = True,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime:
|
||||
"""
|
||||
Parse the value of a Retry-After header and return a timestamp.
|
||||
|
||||
@@ -106,12 +120,12 @@ def parse_retry_after(value, relative_with_timezone=True, now=None):
|
||||
|
||||
|
||||
def compute_cert_id(
|
||||
backend,
|
||||
cert_info=None,
|
||||
cert_filename=None,
|
||||
cert_content=None,
|
||||
none_if_required_information_is_missing=False,
|
||||
):
|
||||
backend: CryptoBackend,
|
||||
cert_info: CertificateInformation | None = None,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
none_if_required_information_is_missing: bool = False,
|
||||
) -> str | None:
|
||||
# Obtain certificate info if not provided
|
||||
if cert_info is None:
|
||||
cert_info = backend.get_cert_information(
|
||||
|
||||
@@ -4,10 +4,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def _ensure_list(value):
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
def _ensure_list(value: list[_T] | tuple[_T] | None) -> list[_T]:
|
||||
if value is None:
|
||||
return []
|
||||
return list(value)
|
||||
@@ -16,13 +21,19 @@ def _ensure_list(value):
|
||||
class ArgumentSpec:
|
||||
def __init__(
|
||||
self,
|
||||
argument_spec=None,
|
||||
mutually_exclusive=None,
|
||||
required_together=None,
|
||||
required_one_of=None,
|
||||
required_if=None,
|
||||
required_by=None,
|
||||
):
|
||||
argument_spec: dict[str, t.Any] | None = None,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_if: (
|
||||
list[
|
||||
tuple[str, t.Any, list[str] | tuple[str, ...]]
|
||||
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
|
||||
) -> None:
|
||||
self.argument_spec = argument_spec or {}
|
||||
self.mutually_exclusive = _ensure_list(mutually_exclusive)
|
||||
self.required_together = _ensure_list(required_together)
|
||||
@@ -30,17 +41,23 @@ class ArgumentSpec:
|
||||
self.required_if = _ensure_list(required_if)
|
||||
self.required_by = required_by or {}
|
||||
|
||||
def update_argspec(self, **kwargs):
|
||||
def update_argspec(self, **kwargs) -> t.Self:
|
||||
self.argument_spec.update(kwargs)
|
||||
return self
|
||||
|
||||
def update(
|
||||
self,
|
||||
mutually_exclusive=None,
|
||||
required_together=None,
|
||||
required_one_of=None,
|
||||
required_if=None,
|
||||
required_by=None,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_if: (
|
||||
list[
|
||||
tuple[str, t.Any, list[str] | tuple[str, ...]]
|
||||
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
|
||||
):
|
||||
if mutually_exclusive:
|
||||
self.mutually_exclusive.extend(mutually_exclusive)
|
||||
@@ -57,7 +74,7 @@ class ArgumentSpec:
|
||||
self.required_by[k] = v
|
||||
return self
|
||||
|
||||
def merge(self, other):
|
||||
def merge(self, other: t.Self) -> t.Self:
|
||||
self.update_argspec(**other.argument_spec)
|
||||
self.update(
|
||||
mutually_exclusive=other.mutually_exclusive,
|
||||
@@ -68,8 +85,22 @@ class ArgumentSpec:
|
||||
)
|
||||
return self
|
||||
|
||||
def create_ansible_module_helper(self, clazz, args, **kwargs):
|
||||
return clazz(
|
||||
def create_ansible_module_helper(
|
||||
self, clazz: type[_T], args: tuple, **kwargs: t.Any
|
||||
) -> _T:
|
||||
for forbidden_name in (
|
||||
"argument_spec",
|
||||
"mutually_exclusive",
|
||||
"required_together",
|
||||
"required_one_of",
|
||||
"required_if",
|
||||
"required_by",
|
||||
):
|
||||
if forbidden_name in kwargs:
|
||||
raise ValueError(
|
||||
f"You must not provide a {forbidden_name} keyword parameter to create_ansible_module_helper()"
|
||||
)
|
||||
instance = clazz( # type: ignore
|
||||
*args,
|
||||
argument_spec=self.argument_spec,
|
||||
mutually_exclusive=self.mutually_exclusive,
|
||||
@@ -79,8 +110,9 @@ class ArgumentSpec:
|
||||
required_by=self.required_by,
|
||||
**kwargs,
|
||||
)
|
||||
return instance
|
||||
|
||||
def create_ansible_module(self, **kwargs):
|
||||
def create_ansible_module(self, **kwargs: t.Any) -> AnsibleModule:
|
||||
return self.create_ansible_module_helper(AnsibleModule, (), **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import re
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -32,7 +33,7 @@ ASN1_STRING_REGEX = re.compile(
|
||||
)
|
||||
|
||||
|
||||
class TagClass:
|
||||
class TagClass(enum.Enum):
|
||||
universal = 0
|
||||
application = 1
|
||||
context_specific = 2
|
||||
@@ -40,11 +41,11 @@ class TagClass:
|
||||
|
||||
|
||||
# Universal tag numbers that can be encoded.
|
||||
class TagNumber:
|
||||
class TagNumber(enum.Enum):
|
||||
utf8_string = 12
|
||||
|
||||
|
||||
def _pack_octet_integer(value):
|
||||
def _pack_octet_integer(value: int) -> bytes:
|
||||
"""Packs an integer value into 1 or multiple octets."""
|
||||
# NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value.
|
||||
octets = bytearray()
|
||||
@@ -66,7 +67,7 @@ def _pack_octet_integer(value):
|
||||
return bytes(octets)
|
||||
|
||||
|
||||
def serialize_asn1_string_as_der(value):
|
||||
def serialize_asn1_string_as_der(value: str) -> bytes:
|
||||
"""Deserializes an ASN.1 string to a DER encoded byte string."""
|
||||
asn1_match = ASN1_STRING_REGEX.match(value)
|
||||
if not asn1_match:
|
||||
@@ -92,7 +93,7 @@ def serialize_asn1_string_as_der(value):
|
||||
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
|
||||
|
||||
if tag_type:
|
||||
tag_class = {
|
||||
tag_class_enum = {
|
||||
"U": TagClass.universal,
|
||||
"A": TagClass.application,
|
||||
"P": TagClass.private,
|
||||
@@ -100,13 +101,15 @@ def serialize_asn1_string_as_der(value):
|
||||
}[tag_class]
|
||||
|
||||
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
|
||||
constructed = tag_type == "EXPLICIT" and tag_class != TagClass.universal
|
||||
b_value = pack_asn1(tag_class, constructed, int(tag_number), b_value)
|
||||
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
|
||||
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
|
||||
|
||||
return b_value
|
||||
|
||||
|
||||
def pack_asn1(tag_class, constructed, tag_number, b_data):
|
||||
def pack_asn1(
|
||||
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
|
||||
) -> bytes:
|
||||
"""Pack the value into an ASN.1 data structure.
|
||||
|
||||
The structure for an ASN.1 element is
|
||||
@@ -115,16 +118,15 @@ def pack_asn1(tag_class, constructed, tag_number, b_data):
|
||||
"""
|
||||
b_asn1_data = bytearray()
|
||||
|
||||
if tag_class < 0 or tag_class > 3:
|
||||
raise ValueError(f"tag_class must be between 0 and 3 not {tag_class}")
|
||||
|
||||
# Bit 8 and 7 denotes the class.
|
||||
identifier_octets = tag_class << 6
|
||||
identifier_octets = tag_class.value << 6
|
||||
# Bit 6 denotes whether the value is primitive or constructed.
|
||||
identifier_octets |= (1 if constructed else 0) << 5
|
||||
|
||||
# Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits
|
||||
# then they are set and another octet(s) is used to denote the tag number.
|
||||
if isinstance(tag_number, TagNumber):
|
||||
tag_number = tag_number.value
|
||||
if tag_number < 31:
|
||||
identifier_octets |= tag_number
|
||||
b_asn1_data.append(identifier_octets)
|
||||
|
||||
@@ -34,7 +34,7 @@ from __future__ import annotations
|
||||
# cryptography versions!
|
||||
|
||||
|
||||
def obj2txt(openssl_lib, openssl_ffi, obj):
|
||||
def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
|
||||
# Set to 80 on the recommendation of
|
||||
# https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values
|
||||
#
|
||||
|
||||
@@ -7,9 +7,9 @@ from __future__ import annotations
|
||||
from ._objects_data import OID_MAP
|
||||
|
||||
|
||||
OID_LOOKUP = dict()
|
||||
NORMALIZE_NAMES = dict()
|
||||
NORMALIZE_NAMES_SHORT = dict()
|
||||
OID_LOOKUP: dict[str, str] = dict()
|
||||
NORMALIZE_NAMES: dict[str, str] = dict()
|
||||
NORMALIZE_NAMES_SHORT: dict[str, str] = dict()
|
||||
|
||||
for dotted, names in OID_MAP.items():
|
||||
for name in names:
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
LooseVersion as _LooseVersion,
|
||||
)
|
||||
@@ -21,6 +23,10 @@ from .basic import HAS_CRYPTOGRAPHY
|
||||
from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_name
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
|
||||
# TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this
|
||||
# to True and adjust get_invalidity_date() accordingly.
|
||||
# (https://github.com/pyca/cryptography/issues/10818)
|
||||
@@ -55,7 +61,9 @@ else:
|
||||
REVOCATION_REASON_MAP_INVERSE = dict()
|
||||
|
||||
|
||||
def cryptography_decode_revoked_certificate(cert):
|
||||
def cryptography_decode_revoked_certificate(
|
||||
cert: x509.RevokedCertificate,
|
||||
) -> dict[str, t.Any]:
|
||||
result = {
|
||||
"serial_number": cert.serial_number,
|
||||
"revocation_date": get_revocation_date(cert),
|
||||
@@ -67,27 +75,30 @@ def cryptography_decode_revoked_certificate(cert):
|
||||
"invalidity_date_critical": False,
|
||||
}
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
|
||||
result["issuer"] = list(ext.value)
|
||||
result["issuer_critical"] = ext.critical
|
||||
ext_ci = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
|
||||
result["issuer"] = list(ext_ci.value)
|
||||
result["issuer_critical"] = ext_ci.critical
|
||||
except x509.ExtensionNotFound:
|
||||
pass
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(x509.CRLReason)
|
||||
result["reason"] = ext.value.reason
|
||||
result["reason_critical"] = ext.critical
|
||||
ext_cr = cert.extensions.get_extension_for_class(x509.CRLReason)
|
||||
result["reason"] = ext_cr.value.reason
|
||||
result["reason_critical"] = ext_cr.critical
|
||||
except x509.ExtensionNotFound:
|
||||
pass
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(x509.InvalidityDate)
|
||||
result["invalidity_date"] = get_invalidity_date(ext.value)
|
||||
result["invalidity_date_critical"] = ext.critical
|
||||
ext_id = cert.extensions.get_extension_for_class(x509.InvalidityDate)
|
||||
result["invalidity_date"] = get_invalidity_date(ext_id.value)
|
||||
result["invalidity_date_critical"] = ext_id.critical
|
||||
except x509.ExtensionNotFound:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
|
||||
def cryptography_dump_revoked(
|
||||
entry: dict[str, t.Any],
|
||||
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
) -> dict[str, t.Any]:
|
||||
return {
|
||||
"serial_number": entry["serial_number"],
|
||||
"revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT),
|
||||
@@ -115,48 +126,56 @@ def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
|
||||
}
|
||||
|
||||
|
||||
def cryptography_get_signature_algorithm_oid_from_crl(crl):
|
||||
def cryptography_get_signature_algorithm_oid_from_crl(
|
||||
crl: x509.CertificateRevocationList,
|
||||
) -> x509.oid.ObjectIdentifier:
|
||||
try:
|
||||
return crl.signature_algorithm_oid
|
||||
except AttributeError:
|
||||
# Older cryptography versions do not have signature_algorithm_oid yet
|
||||
dotted = obj2txt(
|
||||
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm
|
||||
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm # type: ignore
|
||||
)
|
||||
return x509.oid.ObjectIdentifier(dotted)
|
||||
|
||||
|
||||
def get_next_update(obj):
|
||||
def get_next_update(obj: x509.CertificateRevocationList) -> datetime.datetime | None:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.next_update_utc
|
||||
return obj.next_update
|
||||
|
||||
|
||||
def get_last_update(obj):
|
||||
def get_last_update(obj: x509.CertificateRevocationList) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.last_update_utc
|
||||
return obj.last_update
|
||||
|
||||
|
||||
def get_revocation_date(obj):
|
||||
def get_revocation_date(obj: x509.RevokedCertificate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.revocation_date_utc
|
||||
return obj.revocation_date
|
||||
|
||||
|
||||
def get_invalidity_date(obj):
|
||||
def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE:
|
||||
return obj.invalidity_date_utc
|
||||
return obj.invalidity_date
|
||||
|
||||
|
||||
def set_next_update(builder, value):
|
||||
def set_next_update(
|
||||
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateRevocationListBuilder:
|
||||
return builder.next_update(value)
|
||||
|
||||
|
||||
def set_last_update(builder, value):
|
||||
def set_last_update(
|
||||
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateRevocationListBuilder:
|
||||
return builder.last_update(value)
|
||||
|
||||
|
||||
def set_revocation_date(builder, value):
|
||||
def set_revocation_date(
|
||||
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
|
||||
) -> x509.RevokedCertificateBuilder:
|
||||
return builder.revocation_date(value)
|
||||
|
||||
@@ -9,6 +9,7 @@ import binascii
|
||||
import ipaddress
|
||||
import re
|
||||
import traceback
|
||||
import typing as t
|
||||
from urllib.parse import (
|
||||
ParseResult,
|
||||
urlparse,
|
||||
@@ -40,6 +41,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import cryptography.hazmat.primitives.asymmetric.dh
|
||||
import cryptography.hazmat.primitives.asymmetric.ed448
|
||||
import cryptography.hazmat.primitives.asymmetric.ed25519
|
||||
import cryptography.hazmat.primitives.asymmetric.rsa
|
||||
@@ -55,7 +57,7 @@ try:
|
||||
)
|
||||
except ImportError:
|
||||
# Error handled in the calling module.
|
||||
_load_pkcs12 = None
|
||||
_load_pkcs12 = None # type: ignore
|
||||
|
||||
try:
|
||||
import idna
|
||||
@@ -74,6 +76,50 @@ from .basic import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateKey, DHPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import (
|
||||
DSAPrivateKey,
|
||||
DSAPublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||
RSAPrivateKey,
|
||||
RSAPublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
CertificateIssuerPublicKeyTypes,
|
||||
CertificatePublicKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
PublicKeyTypes,
|
||||
)
|
||||
from cryptography.hazmat.primitives.serialization.pkcs12 import (
|
||||
PKCS12KeyAndCertificates,
|
||||
)
|
||||
|
||||
CertificatePrivateKeyTypes = (
|
||||
CertificateIssuerPrivateKeyTypes
|
||||
| cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey
|
||||
| cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
|
||||
)
|
||||
PublicKeyTypesWOEdwards = (
|
||||
DHPublicKey | DSAPublicKey | EllipticCurvePublicKey | RSAPublicKey
|
||||
)
|
||||
PrivateKeyTypesWOEdwards = (
|
||||
DHPrivateKey | DSAPrivateKey | EllipticCurvePrivateKey | RSAPrivateKey
|
||||
)
|
||||
else:
|
||||
PublicKeyTypesWOEdwards = None
|
||||
PrivateKeyTypesWOEdwards = None
|
||||
|
||||
|
||||
CRYPTOGRAPHY_TIMEZONE = False
|
||||
_CRYPTOGRAPHY_36_0_OR_NEWER = False
|
||||
if _HAS_CRYPTOGRAPHY:
|
||||
@@ -88,7 +134,9 @@ if _HAS_CRYPTOGRAPHY:
|
||||
DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$")
|
||||
|
||||
|
||||
def cryptography_get_extensions_from_cert(cert):
|
||||
def cryptography_get_extensions_from_cert(
|
||||
cert: x509.Certificate,
|
||||
) -> dict[str, dict[str, bool | str]]:
|
||||
result = dict()
|
||||
|
||||
if _CRYPTOGRAPHY_36_0_OR_NEWER:
|
||||
@@ -105,7 +153,7 @@ def cryptography_get_extensions_from_cert(cert):
|
||||
|
||||
backend = default_backend()
|
||||
|
||||
x509_obj = cert._x509
|
||||
x509_obj = cert._x509 # type: ignore
|
||||
# With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does
|
||||
# not allow to get the raw value of an extension, so we have to use this ugly hack:
|
||||
exts = list(cert.extensions)
|
||||
@@ -135,7 +183,9 @@ def cryptography_get_extensions_from_cert(cert):
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_get_extensions_from_csr(csr):
|
||||
def cryptography_get_extensions_from_csr(
|
||||
csr: x509.CertificateSigningRequest,
|
||||
) -> dict[str, dict[str, bool | str]]:
|
||||
result = dict()
|
||||
|
||||
if _CRYPTOGRAPHY_36_0_OR_NEWER:
|
||||
@@ -153,7 +203,7 @@ def cryptography_get_extensions_from_csr(csr):
|
||||
|
||||
backend = default_backend()
|
||||
|
||||
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req)
|
||||
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) # type: ignore
|
||||
extensions = backend._ffi.gc(
|
||||
extensions,
|
||||
lambda ext: backend._lib.sk_X509_EXTENSION_pop_free(
|
||||
@@ -175,7 +225,7 @@ def cryptography_get_extensions_from_csr(csr):
|
||||
crit = backend._lib.X509_EXTENSION_get_critical(ext)
|
||||
data = backend._lib.X509_EXTENSION_get_data(ext)
|
||||
backend.openssl_assert(data != backend._ffi.NULL)
|
||||
der = backend._ffi.buffer(data.data, data.length)[:]
|
||||
der: bytes = backend._ffi.buffer(data.data, data.length)[:] # type: ignore
|
||||
entry = dict(
|
||||
critical=(crit == 1),
|
||||
value=base64.b64encode(der).decode("ascii"),
|
||||
@@ -193,7 +243,7 @@ def cryptography_get_extensions_from_csr(csr):
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_name_to_oid(name):
|
||||
def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
|
||||
dotted = OID_LOOKUP.get(name)
|
||||
if dotted is None:
|
||||
if DOTTED_OID.match(name):
|
||||
@@ -202,7 +252,9 @@ def cryptography_name_to_oid(name):
|
||||
return x509.oid.ObjectIdentifier(dotted)
|
||||
|
||||
|
||||
def cryptography_oid_to_name(oid, short=False):
|
||||
def cryptography_oid_to_name(
|
||||
oid: x509.oid.ObjectIdentifier, short: bool = False
|
||||
) -> str:
|
||||
dotted_string = oid.dotted_string
|
||||
names = OID_MAP.get(dotted_string)
|
||||
if names:
|
||||
@@ -217,15 +269,22 @@ def cryptography_oid_to_name(oid, short=False):
|
||||
return NORMALIZE_NAMES.get(name, name)
|
||||
|
||||
|
||||
def _get_hex(bytesstr):
|
||||
def _get_hex(bytesstr: bytes) -> str:
|
||||
if bytesstr is None:
|
||||
return bytesstr
|
||||
data = binascii.hexlify(bytesstr)
|
||||
data = to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
|
||||
return data
|
||||
return to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
|
||||
|
||||
|
||||
def _parse_hex(bytesstr):
|
||||
@t.overload
|
||||
def _parse_hex(bytesstr: bytes | str) -> bytes: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None: ...
|
||||
|
||||
|
||||
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None:
|
||||
if bytesstr is None:
|
||||
return bytesstr
|
||||
data = "".join(
|
||||
@@ -234,19 +293,20 @@ def _parse_hex(bytesstr):
|
||||
for p in to_text(bytesstr).split(":")
|
||||
]
|
||||
)
|
||||
data = binascii.unhexlify(data)
|
||||
return data
|
||||
return binascii.unhexlify(data)
|
||||
|
||||
|
||||
DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *")
|
||||
DN_HEX_LETTER = b"0123456789abcdef"
|
||||
|
||||
|
||||
def _int_to_byte(value):
|
||||
def _int_to_byte(value: int) -> bytes:
|
||||
return bytes((value,))
|
||||
|
||||
|
||||
def _parse_dn_component(name, sep=b",", decode_remainder=True):
|
||||
def _parse_dn_component(
|
||||
name: bytes, sep: bytes = b",", decode_remainder: bool = True
|
||||
) -> tuple[x509.NameAttribute, bytes]:
|
||||
m = DN_COMPONENT_START_RE.match(name)
|
||||
if not m:
|
||||
raise OpenSSLObjectError(f'cannot start part in "{to_text(name)}"')
|
||||
@@ -305,7 +365,7 @@ def _parse_dn_component(name, sep=b",", decode_remainder=True):
|
||||
return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:]
|
||||
|
||||
|
||||
def _parse_dn(name):
|
||||
def _parse_dn(name: bytes) -> list[x509.NameAttribute]:
|
||||
"""
|
||||
Parse a Distinguished Name.
|
||||
|
||||
@@ -323,31 +383,33 @@ def _parse_dn(name):
|
||||
attribute, name = _parse_dn_component(name, sep=sep)
|
||||
except OpenSSLObjectError as e:
|
||||
raise OpenSSLObjectError(
|
||||
f'Error while parsing distinguished name "{to_text(original_name)}": {e}'
|
||||
f"Error while parsing distinguished name {to_text(original_name)!r}: {e}"
|
||||
)
|
||||
result.append(attribute)
|
||||
if name:
|
||||
if name[0:1] != sep or len(name) < 2:
|
||||
raise OpenSSLObjectError(
|
||||
f'Error while parsing distinguished name "{to_text(original_name)}": unexpected end of string'
|
||||
f"Error while parsing distinguished name {to_text(original_name)!r}: unexpected end of string"
|
||||
)
|
||||
name = name[1:]
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_parse_relative_distinguished_name(rdn):
|
||||
def cryptography_parse_relative_distinguished_name(
|
||||
rdn: list[str | bytes],
|
||||
) -> cryptography.x509.RelativeDistinguishedName:
|
||||
names = []
|
||||
for part in rdn:
|
||||
try:
|
||||
names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0])
|
||||
except OpenSSLObjectError as e:
|
||||
raise OpenSSLObjectError(
|
||||
f'Error while parsing relative distinguished name "{part}": {e}'
|
||||
f"Error while parsing relative distinguished name {to_text(part)!r}: {e}"
|
||||
)
|
||||
return cryptography.x509.RelativeDistinguishedName(names)
|
||||
|
||||
|
||||
def _is_ascii(value):
|
||||
def _is_ascii(value: str) -> bool:
|
||||
"""Check whether the Unicode string `value` contains only ASCII characters."""
|
||||
try:
|
||||
value.encode("ascii")
|
||||
@@ -356,7 +418,7 @@ def _is_ascii(value):
|
||||
return False
|
||||
|
||||
|
||||
def _adjust_idn(value, idn_rewrite):
|
||||
def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
|
||||
if idn_rewrite == "ignore" or not value:
|
||||
return value
|
||||
if idn_rewrite == "idna" and _is_ascii(value):
|
||||
@@ -399,16 +461,20 @@ def _adjust_idn(value, idn_rewrite):
|
||||
return ".".join(parts)
|
||||
|
||||
|
||||
def _adjust_idn_email(value, idn_rewrite):
|
||||
def _adjust_idn_email(
|
||||
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
idx = value.find("@")
|
||||
if idx < 0:
|
||||
return value
|
||||
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
|
||||
|
||||
|
||||
def _adjust_idn_url(value, idn_rewrite):
|
||||
def _adjust_idn_url(
|
||||
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
url = urlparse(value)
|
||||
host = _adjust_idn(url.hostname, idn_rewrite)
|
||||
host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
|
||||
if url.username is not None and url.password is not None:
|
||||
host = f"{url.username}:{url.password}@{host}"
|
||||
elif url.username is not None:
|
||||
@@ -418,7 +484,7 @@ def _adjust_idn_url(value, idn_rewrite):
|
||||
return urlunparse(
|
||||
ParseResult(
|
||||
scheme=url.scheme,
|
||||
netloc=host,
|
||||
netloc=host or "",
|
||||
path=url.path,
|
||||
params=url.params,
|
||||
query=url.query,
|
||||
@@ -427,7 +493,9 @@ def _adjust_idn_url(value, idn_rewrite):
|
||||
)
|
||||
|
||||
|
||||
def cryptography_get_name(name, what="Subject Alternative Name"):
|
||||
def cryptography_get_name(
|
||||
name: str, what: str = "Subject Alternative Name"
|
||||
) -> x509.GeneralName:
|
||||
"""
|
||||
Given a name string, returns a cryptography x509.GeneralName object.
|
||||
Raises an OpenSSLObjectError if the name is unknown or cannot be parsed.
|
||||
@@ -490,7 +558,7 @@ def cryptography_get_name(name, what="Subject Alternative Name"):
|
||||
)
|
||||
|
||||
|
||||
def _dn_escape_value(value):
|
||||
def _dn_escape_value(value: str) -> str:
|
||||
"""
|
||||
Escape Distinguished Name's attribute value.
|
||||
"""
|
||||
@@ -505,7 +573,10 @@ def _dn_escape_value(value):
|
||||
return value
|
||||
|
||||
|
||||
def cryptography_decode_name(name, idn_rewrite="ignore"):
|
||||
def cryptography_decode_name(
|
||||
name: x509.GeneralName,
|
||||
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
) -> str:
|
||||
"""
|
||||
Given a cryptography x509.GeneralName object, returns a string.
|
||||
Raises an OpenSSLObjectError if the name is not supported.
|
||||
@@ -529,7 +600,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
|
||||
# list needs to be reversed, and joined by commas
|
||||
return "dirName:" + ",".join(
|
||||
[
|
||||
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(attribute.value)}"
|
||||
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(to_text(attribute.value))}"
|
||||
for attribute in reversed(list(name.value))
|
||||
]
|
||||
)
|
||||
@@ -540,7 +611,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
|
||||
raise OpenSSLObjectError(f'Cannot decode name "{name}"')
|
||||
|
||||
|
||||
def _cryptography_get_keyusage(usage):
|
||||
def _cryptography_get_keyusage(usage: str) -> str:
|
||||
"""
|
||||
Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage().
|
||||
Raises an OpenSSLObjectError if the identifier is unknown.
|
||||
@@ -566,7 +637,7 @@ def _cryptography_get_keyusage(usage):
|
||||
raise OpenSSLObjectError(f'Unknown key usage "{usage}"')
|
||||
|
||||
|
||||
def cryptography_parse_key_usage_params(usages):
|
||||
def cryptography_parse_key_usage_params(usages: t.Iterable[str]) -> dict[str, bool]:
|
||||
"""
|
||||
Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage().
|
||||
Raises an OpenSSLObjectError if an identifier is unknown.
|
||||
@@ -587,13 +658,15 @@ def cryptography_parse_key_usage_params(usages):
|
||||
return params
|
||||
|
||||
|
||||
def cryptography_get_basic_constraints(constraints):
|
||||
def cryptography_get_basic_constraints(
|
||||
constraints: t.Iterable[str] | None,
|
||||
) -> tuple[bool, int | None]:
|
||||
"""
|
||||
Given a list of constraints, returns a tuple (ca, path_length).
|
||||
Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed.
|
||||
"""
|
||||
ca = False
|
||||
path_length = None
|
||||
path_length: int | None = None
|
||||
if constraints:
|
||||
for constraint in constraints:
|
||||
if constraint.startswith("CA:"):
|
||||
@@ -618,7 +691,9 @@ def cryptography_get_basic_constraints(constraints):
|
||||
return ca, path_length
|
||||
|
||||
|
||||
def cryptography_key_needs_digest_for_signing(key):
|
||||
def cryptography_key_needs_digest_for_signing(
|
||||
key: CertificateIssuerPrivateKeyTypes,
|
||||
) -> bool:
|
||||
"""Tests whether the given private key requires a digest algorithm for signing.
|
||||
|
||||
Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm.
|
||||
@@ -632,19 +707,27 @@ def cryptography_key_needs_digest_for_signing(key):
|
||||
return True
|
||||
|
||||
|
||||
def _compare_public_keys(key1, key2, clazz):
|
||||
def _compare_public_keys(
|
||||
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
|
||||
) -> bool | None:
|
||||
a = isinstance(key1, clazz)
|
||||
b = isinstance(key2, clazz)
|
||||
if not (a or b):
|
||||
return None
|
||||
if not a or not b:
|
||||
return False
|
||||
a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
return a == b
|
||||
a_bytes = key1.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
b_bytes = key2.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
return a_bytes == b_bytes
|
||||
|
||||
|
||||
def cryptography_compare_public_keys(key1, key2):
|
||||
def cryptography_compare_public_keys(
|
||||
key1: PublicKeyTypes, key2: PublicKeyTypes
|
||||
) -> bool:
|
||||
"""Tests whether two public keys are the same.
|
||||
|
||||
Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers().
|
||||
@@ -654,6 +737,13 @@ def cryptography_compare_public_keys(key1, key2):
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
key1,
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
@@ -661,10 +751,20 @@ def cryptography_compare_public_keys(key1, key2):
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
return key1.public_numbers() == key2.public_numbers()
|
||||
res = _compare_public_keys(
|
||||
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
return (
|
||||
t.cast(PublicKeyTypesWOEdwards, key1).public_numbers()
|
||||
== t.cast(PublicKeyTypesWOEdwards, key2).public_numbers()
|
||||
)
|
||||
|
||||
|
||||
def _compare_private_keys(key1, key2, clazz):
|
||||
def _compare_private_keys(
|
||||
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
|
||||
) -> bool | None:
|
||||
a = isinstance(key1, clazz)
|
||||
b = isinstance(key2, clazz)
|
||||
if not (a or b):
|
||||
@@ -672,20 +772,22 @@ def _compare_private_keys(key1, key2, clazz):
|
||||
if not a or not b:
|
||||
return False
|
||||
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
a = key1.private_bytes(
|
||||
a_bytes = key1.private_bytes(
|
||||
serialization.Encoding.Raw,
|
||||
serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
b = key2.private_bytes(
|
||||
b_bytes = key2.private_bytes(
|
||||
serialization.Encoding.Raw,
|
||||
serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
return a == b
|
||||
return a_bytes == b_bytes
|
||||
|
||||
|
||||
def cryptography_compare_private_keys(key1, key2):
|
||||
def cryptography_compare_private_keys(
|
||||
key1: PrivateKeyTypes, key2: PrivateKeyTypes
|
||||
) -> bool:
|
||||
"""Tests whether two private keys are the same.
|
||||
|
||||
Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers().
|
||||
@@ -714,25 +816,39 @@ def cryptography_compare_private_keys(key1, key2):
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
return key1.private_numbers() == key2.private_numbers()
|
||||
return (
|
||||
t.cast(PrivateKeyTypesWOEdwards, key1).private_numbers()
|
||||
== t.cast(PrivateKeyTypesWOEdwards, key2).private_numbers()
|
||||
)
|
||||
|
||||
|
||||
def parse_pkcs12(pkcs12_bytes, passphrase=None):
|
||||
def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
"""Returns a tuple (private_key, certificate, additional_certificates, friendly_name)."""
|
||||
passphrase_bytes = None
|
||||
if passphrase is not None:
|
||||
passphrase = to_bytes(passphrase)
|
||||
passphrase_bytes = to_bytes(passphrase)
|
||||
|
||||
# Main code for cryptography 36.0.0 and forward
|
||||
if _load_pkcs12 is not None:
|
||||
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase)
|
||||
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
|
||||
|
||||
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
|
||||
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase)
|
||||
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
|
||||
|
||||
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase)
|
||||
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
|
||||
|
||||
|
||||
def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
|
||||
def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
# Requires cryptography 36.0.0 or newer
|
||||
pkcs12 = _load_pkcs12(pkcs12_bytes, passphrase)
|
||||
additional_certificates = [cert.certificate for cert in pkcs12.additional_certs]
|
||||
@@ -745,7 +861,12 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
|
||||
def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
# Backwards compatibility code for cryptography 35.x
|
||||
private_key, certificate, additional_certificates = _load_key_and_certificates(
|
||||
pkcs12_bytes, passphrase
|
||||
@@ -787,7 +908,12 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
|
||||
def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
# Backwards compatibility code for cryptography < 35.0.0
|
||||
private_key, certificate, additional_certificates = _load_key_and_certificates(
|
||||
pkcs12_bytes, passphrase
|
||||
@@ -796,14 +922,19 @@ def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
|
||||
friendly_name = None
|
||||
if certificate:
|
||||
# See https://github.com/pyca/cryptography/issues/5760#issuecomment-842687238
|
||||
backend = certificate._backend
|
||||
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL)
|
||||
backend = certificate._backend # type: ignore
|
||||
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) # type: ignore
|
||||
if maybe_name != backend._ffi.NULL:
|
||||
friendly_name = backend._ffi.string(maybe_name)
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def cryptography_verify_signature(signature, data, hash_algorithm, signer_public_key):
|
||||
def cryptography_verify_signature(
|
||||
signature: bytes,
|
||||
data: bytes,
|
||||
hash_algorithm: hashes.HashAlgorithm | None,
|
||||
signer_public_key: PublicKeyTypes,
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether the given signature of the given data was signed by the given public key object.
|
||||
"""
|
||||
@@ -812,6 +943,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
signer_public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey,
|
||||
):
|
||||
if hash_algorithm is None:
|
||||
raise OpenSSLObjectError("Need hash_algorithm for RSA keys")
|
||||
signer_public_key.verify(
|
||||
signature, data, padding.PKCS1v15(), hash_algorithm
|
||||
)
|
||||
@@ -820,6 +953,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
signer_public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
|
||||
):
|
||||
if hash_algorithm is None:
|
||||
raise OpenSSLObjectError("Need hash_algorithm for ECC keys")
|
||||
signer_public_key.verify(
|
||||
signature,
|
||||
data,
|
||||
@@ -830,6 +965,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
signer_public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey,
|
||||
):
|
||||
if hash_algorithm is None:
|
||||
raise OpenSSLObjectError("Need hash_algorithm for DSA keys")
|
||||
signer_public_key.verify(signature, data, hash_algorithm)
|
||||
return True
|
||||
if isinstance(
|
||||
@@ -851,7 +988,9 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
return False
|
||||
|
||||
|
||||
def cryptography_verify_certificate_signature(certificate, signer_public_key):
|
||||
def cryptography_verify_certificate_signature(
|
||||
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether the given X509 certificate object was signed by the given public key object.
|
||||
"""
|
||||
@@ -863,21 +1002,65 @@ def cryptography_verify_certificate_signature(certificate, signer_public_key):
|
||||
)
|
||||
|
||||
|
||||
def get_not_valid_after(obj):
|
||||
def get_not_valid_after(obj: x509.Certificate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.not_valid_after_utc
|
||||
return obj.not_valid_after
|
||||
|
||||
|
||||
def get_not_valid_before(obj):
|
||||
def get_not_valid_before(obj: x509.Certificate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.not_valid_before_utc
|
||||
return obj.not_valid_before
|
||||
|
||||
|
||||
def set_not_valid_after(builder, value):
|
||||
def set_not_valid_after(
|
||||
builder: x509.CertificateBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateBuilder:
|
||||
return builder.not_valid_after(value)
|
||||
|
||||
|
||||
def set_not_valid_before(builder, value):
|
||||
def set_not_valid_before(
|
||||
builder: x509.CertificateBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateBuilder:
|
||||
return builder.not_valid_before(value)
|
||||
|
||||
|
||||
def is_potential_certificate_private_key(
|
||||
key: PrivateKeyTypes,
|
||||
) -> t.TypeGuard[CertificatePrivateKeyTypes]:
|
||||
return not isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey
|
||||
)
|
||||
|
||||
|
||||
def is_potential_certificate_issuer_private_key(
|
||||
key: PrivateKeyTypes,
|
||||
) -> t.TypeGuard[CertificateIssuerPrivateKeyTypes]:
|
||||
return not isinstance(
|
||||
key,
|
||||
(
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def is_potential_certificate_public_key(
|
||||
key: PublicKeyTypes,
|
||||
) -> t.TypeGuard[CertificatePublicKeyTypes]:
|
||||
return not isinstance(key, DHPublicKey)
|
||||
|
||||
|
||||
def is_potential_certificate_issuer_public_key(
|
||||
key: PublicKeyTypes,
|
||||
) -> t.TypeGuard[CertificateIssuerPublicKeyTypes]:
|
||||
return not isinstance(
|
||||
key,
|
||||
(
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey,
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def binary_exp_mod(f, e, m):
|
||||
def binary_exp_mod(f: int, e: int, m: int) -> int:
|
||||
"""Computes f^e mod m in O(log e) multiplications modulo m."""
|
||||
# Compute len_e = floor(log_2(e))
|
||||
len_e = -1
|
||||
@@ -22,14 +22,14 @@ def binary_exp_mod(f, e, m):
|
||||
return result
|
||||
|
||||
|
||||
def simple_gcd(a, b):
|
||||
def simple_gcd(a: int, b: int) -> int:
|
||||
"""Compute GCD of its two inputs."""
|
||||
while b != 0:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
|
||||
|
||||
def quick_is_not_prime(n):
|
||||
def quick_is_not_prime(n: int) -> bool:
|
||||
"""Does some quick checks to see if we can poke a hole into the primality of n.
|
||||
|
||||
A result of `False` does **not** mean that the number is prime; it just means
|
||||
@@ -97,7 +97,7 @@ def quick_is_not_prime(n):
|
||||
return False
|
||||
|
||||
|
||||
def count_bytes(no):
|
||||
def count_bytes(no: int) -> int:
|
||||
"""
|
||||
Given an integer, compute the number of bytes necessary to store its absolute value.
|
||||
"""
|
||||
@@ -107,7 +107,7 @@ def count_bytes(no):
|
||||
return (no.bit_length() + 7) // 8
|
||||
|
||||
|
||||
def count_bits(no):
|
||||
def count_bits(no: int) -> int:
|
||||
"""
|
||||
Given an integer, compute the number of bits necessary to store its absolute value.
|
||||
"""
|
||||
@@ -117,19 +117,7 @@ def count_bits(no):
|
||||
return no.bit_length()
|
||||
|
||||
|
||||
def _convert_int_to_bytes(count, no):
|
||||
return no.to_bytes(count, byteorder="big")
|
||||
|
||||
|
||||
def _convert_bytes_to_int(data):
|
||||
return int.from_bytes(data, byteorder="big", signed=False)
|
||||
|
||||
|
||||
def _to_hex(no):
|
||||
return f"{no:x}"
|
||||
|
||||
|
||||
def convert_int_to_bytes(no, count=None):
|
||||
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
|
||||
"""
|
||||
Convert the absolute value of an integer to a byte string in network byte order.
|
||||
|
||||
@@ -142,10 +130,10 @@ def convert_int_to_bytes(no, count=None):
|
||||
no = abs(no)
|
||||
if count is None:
|
||||
count = count_bytes(no)
|
||||
return _convert_int_to_bytes(count, no)
|
||||
return no.to_bytes(count, byteorder="big")
|
||||
|
||||
|
||||
def convert_int_to_hex(no, digits=None):
|
||||
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
|
||||
"""
|
||||
Convert the absolute value of an integer to a string of hexadecimal digits.
|
||||
|
||||
@@ -154,14 +142,14 @@ def convert_int_to_hex(no, digits=None):
|
||||
the string will be longer.
|
||||
"""
|
||||
no = abs(no)
|
||||
value = _to_hex(no)
|
||||
value = f"{no:x}"
|
||||
if digits is not None and len(value) < digits:
|
||||
value = "0" * (digits - len(value)) + value
|
||||
return value
|
||||
|
||||
|
||||
def convert_bytes_to_int(data):
|
||||
def convert_bytes_to_int(data: bytes) -> int:
|
||||
"""
|
||||
Convert a byte string to an unsigned integer in network byte order.
|
||||
"""
|
||||
return _convert_bytes_to_int(data)
|
||||
return int.from_bytes(data, byteorder="big", signed=False)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
ArgumentSpec,
|
||||
@@ -24,8 +25,8 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
load_certificate,
|
||||
load_certificate_privatekey,
|
||||
load_certificate_request,
|
||||
load_privatekey,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ..cryptography_support import CertificatePrivateKeyTypes
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -47,41 +59,45 @@ class CertificateError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.force = module.params["force"]
|
||||
self.ignore_timestamps = module.params["ignore_timestamps"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.csr_path = module.params["csr_path"]
|
||||
self.csr_content = module.params["csr_content"]
|
||||
if self.csr_content is not None:
|
||||
self.csr_content = self.csr_content.encode("utf-8")
|
||||
self.force: bool = module.params["force"]
|
||||
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.csr_path: str | None = module.params["csr_path"]
|
||||
csr_content = module.params["csr_content"]
|
||||
if csr_content is not None:
|
||||
self.csr_content: bytes | None = csr_content.encode("utf-8")
|
||||
else:
|
||||
self.csr_content = None
|
||||
|
||||
# The following are default values which make sure check() works as
|
||||
# before if providers do not explicitly change these properties.
|
||||
self.create_subject_key_identifier = "never_create"
|
||||
self.create_authority_key_identifier = False
|
||||
self.create_subject_key_identifier: str = "never_create"
|
||||
self.create_authority_key_identifier: bool = False
|
||||
|
||||
self.privatekey = None
|
||||
self.csr = None
|
||||
self.cert = None
|
||||
self.existing_certificate = None
|
||||
self.existing_certificate_bytes = None
|
||||
self.privatekey: CertificatePrivateKeyTypes | None = None
|
||||
self.csr: x509.CertificateSigningRequest | None = None
|
||||
self.cert: x509.Certificate | None = None
|
||||
self.existing_certificate: x509.Certificate | None = None
|
||||
self.existing_certificate_bytes: bytes | None = None
|
||||
|
||||
self.check_csr_subject = True
|
||||
self.check_csr_extensions = True
|
||||
self.check_csr_subject: bool = True
|
||||
self.check_csr_extensions: bool = True
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
return {}
|
||||
try:
|
||||
result = get_certificate_info(
|
||||
self.module, data, prefer_one_fingerprint=True
|
||||
@@ -92,34 +108,34 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return dict(can_parse_certificate=False)
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
pass
|
||||
|
||||
def set_existing(self, certificate_bytes):
|
||||
def set_existing(self, certificate_bytes: bytes | None) -> None:
|
||||
"""Set existing certificate bytes. None indicates that the key does not exist."""
|
||||
self.existing_certificate_bytes = certificate_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(
|
||||
self.existing_certificate_bytes
|
||||
)
|
||||
|
||||
def has_existing(self):
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing certificate is/has been there."""
|
||||
return self.existing_certificate_bytes is not None
|
||||
|
||||
def _ensure_private_key_loaded(self):
|
||||
def _ensure_private_key_loaded(self) -> None:
|
||||
"""Load the provided private key into self.privatekey."""
|
||||
if self.privatekey is not None:
|
||||
return
|
||||
if self.privatekey_path is None and self.privatekey_content is None:
|
||||
return
|
||||
try:
|
||||
self.privatekey = load_privatekey(
|
||||
self.privatekey = load_certificate_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
@@ -127,7 +143,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
except OpenSSLBadPassphraseError as exc:
|
||||
raise CertificateError(exc)
|
||||
|
||||
def _ensure_csr_loaded(self):
|
||||
def _ensure_csr_loaded(self) -> None:
|
||||
"""Load the CSR into self.csr."""
|
||||
if self.csr is not None:
|
||||
return
|
||||
@@ -138,7 +154,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
content=self.csr_content,
|
||||
)
|
||||
|
||||
def _ensure_existing_certificate_loaded(self):
|
||||
def _ensure_existing_certificate_loaded(self) -> None:
|
||||
"""Load the existing certificate into self.existing_certificate."""
|
||||
if self.existing_certificate is not None:
|
||||
return
|
||||
@@ -149,14 +165,28 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
content=self.existing_certificate_bytes,
|
||||
)
|
||||
|
||||
def _check_privatekey(self):
|
||||
def _check_privatekey(self) -> bool:
|
||||
"""Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated."""
|
||||
if self.existing_certificate is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: existing_certificate has not been populated"
|
||||
)
|
||||
if self.privatekey is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: privatekey has not been populated"
|
||||
)
|
||||
return cryptography_compare_public_keys(
|
||||
self.existing_certificate.public_key(), self.privatekey.public_key()
|
||||
)
|
||||
|
||||
def _check_csr(self):
|
||||
def _check_csr(self) -> bool:
|
||||
"""Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated."""
|
||||
if self.existing_certificate is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: existing_certificate has not been populated"
|
||||
)
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
# Verify that CSR is signed by certificate's private key
|
||||
if not self.csr.is_signature_valid:
|
||||
return False
|
||||
@@ -214,8 +244,14 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_subject_key_identifier(self):
|
||||
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate has been populated."""
|
||||
def _check_subject_key_identifier(self) -> bool:
|
||||
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate and self.csr have been populated."""
|
||||
if self.existing_certificate is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: existing_certificate has not been populated"
|
||||
)
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
# Get hold of certificate's SKI
|
||||
try:
|
||||
ext = self.existing_certificate.extensions.get_extension_for_class(
|
||||
@@ -247,7 +283,11 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return False
|
||||
return True
|
||||
|
||||
def needs_regeneration(self, not_before=None, not_after=None):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
"""Check whether a regeneration is necessary."""
|
||||
if self.force or self.existing_certificate_bytes is None:
|
||||
return True
|
||||
@@ -256,6 +296,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
self._ensure_existing_certificate_loaded()
|
||||
except Exception:
|
||||
return True
|
||||
assert self.existing_certificate is not None
|
||||
|
||||
# Check whether private key matches
|
||||
self._ensure_private_key_loaded()
|
||||
@@ -285,9 +326,12 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return True
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = {"privatekey": self.privatekey_path, "csr": self.csr_path}
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
"csr": self.csr_path,
|
||||
}
|
||||
# Get hold of certificate bytes
|
||||
certificate_bytes = self.existing_certificate_bytes
|
||||
if self.cert is not None:
|
||||
@@ -299,35 +343,33 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
certificate_bytes.decode("utf-8") if certificate_bytes else None
|
||||
)
|
||||
|
||||
result["diff"] = dict(
|
||||
before=self.diff_before,
|
||||
after=self.diff_after,
|
||||
)
|
||||
result["diff"] = {
|
||||
"before": self.diff_before,
|
||||
"after": self.diff_after,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class CertificateProvider(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
"""Check module arguments"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
"""Whether the provider needs to create a version 2 certificate."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_backend(self, module):
|
||||
def create_backend(self, module: AnsibleModule) -> CertificateBackend:
|
||||
"""Create an implementation for a backend.
|
||||
|
||||
Return value must be instance of CertificateBackend.
|
||||
"""
|
||||
|
||||
|
||||
def select_backend(module, provider):
|
||||
"""
|
||||
:type module: AnsibleModule
|
||||
:type provider: CertificateProvider
|
||||
"""
|
||||
def select_backend(
|
||||
module: AnsibleModule, provider: CertificateProvider
|
||||
) -> CertificateBackend:
|
||||
provider.validate_module_args(module)
|
||||
|
||||
assert_required_cryptography_version(
|
||||
@@ -343,7 +385,7 @@ def select_backend(module, provider):
|
||||
return provider.create_backend(module)
|
||||
|
||||
|
||||
def get_certificate_argument_spec():
|
||||
def get_certificate_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
provider=dict(
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
|
||||
@@ -17,22 +18,30 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
class AcmeCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
super(AcmeCertificateBackend, self).__init__(module)
|
||||
self.accountkey_path = module.params["acme_accountkey_path"]
|
||||
self.challenge_path = module.params["acme_challenge_path"]
|
||||
self.use_chain = module.params["acme_chain"]
|
||||
self.acme_directory = module.params["acme_directory"]
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
if self.csr_content is None and self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if self.csr_content is None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
class AcmeCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(AcmeCertificateBackend, self).__init__(module)
|
||||
self.accountkey_path: str = module.params["acme_accountkey_path"]
|
||||
self.challenge_path: str = module.params["acme_challenge_path"]
|
||||
self.use_chain: bool = module.params["acme_chain"]
|
||||
self.acme_directory: str = module.params["acme_directory"]
|
||||
self.cert_bytes: bytes | None = None
|
||||
|
||||
if self.csr_content is None:
|
||||
if self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
|
||||
if not os.path.exists(self.accountkey_path):
|
||||
raise CertificateError(
|
||||
@@ -46,7 +55,7 @@ class AcmeCertificateBackend(CertificateBackend):
|
||||
|
||||
self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True)
|
||||
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
|
||||
command = [self.acme_tiny_path]
|
||||
@@ -77,22 +86,26 @@ class AcmeCertificateBackend(CertificateBackend):
|
||||
command.extend(["--directory-url", self.acme_directory])
|
||||
|
||||
try:
|
||||
self.cert = to_bytes(self.module.run_command(command, check_rc=True)[1])
|
||||
self.cert_bytes = to_bytes(
|
||||
self.module.run_command(command, check_rc=True)[1]
|
||||
)
|
||||
except OSError as exc:
|
||||
raise CertificateError(exc)
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
return self.cert
|
||||
if self.cert_bytes is None:
|
||||
raise AssertionError("Contract violation: cert_bytes is None")
|
||||
return self.cert_bytes
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(AcmeCertificateBackend, self).dump(include_certificate)
|
||||
result["accountkey"] = self.accountkey_path
|
||||
return result
|
||||
|
||||
|
||||
class AcmeCertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
if module.params["acme_accountkey_path"] is None:
|
||||
module.fail_json(
|
||||
msg="The acme_accountkey_path option must be specified for the acme provider."
|
||||
@@ -102,14 +115,14 @@ class AcmeCertificateProvider(CertificateProvider):
|
||||
msg="The acme_challenge_path option must be specified for the acme provider."
|
||||
)
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
return False
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
|
||||
return AcmeCertificateBackend(module)
|
||||
|
||||
|
||||
def add_acme_provider_to_argument_spec(argument_spec):
|
||||
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("acme")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -32,6 +33,12 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
try:
|
||||
from cryptography.x509.oid import NameOID
|
||||
except ImportError:
|
||||
@@ -39,7 +46,7 @@ except ImportError:
|
||||
|
||||
|
||||
class EntrustCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(EntrustCertificateBackend, self).__init__(module)
|
||||
self.trackingId = None
|
||||
self.notAfter = get_relative_time_option(
|
||||
@@ -48,16 +55,19 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
|
||||
if self.csr_content is None and self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for entrust provider"
|
||||
)
|
||||
if self.csr_content is None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.csr_content is None:
|
||||
if self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for entrust provider"
|
||||
)
|
||||
if not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
|
||||
self._ensure_csr_loaded()
|
||||
if self.csr is None:
|
||||
raise CertificateError("CSR not provided")
|
||||
|
||||
# ECS API defaults to using the validated organization tied to the account.
|
||||
# We want to always force behavior of trying to use the organization provided in the CSR.
|
||||
@@ -93,9 +103,9 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
],
|
||||
)
|
||||
except SessionConfigurationException as e:
|
||||
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e.message}")
|
||||
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
|
||||
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
body = {}
|
||||
|
||||
@@ -104,6 +114,7 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
# csr_content contains bytes
|
||||
body["csr"] = to_native(self.csr_content)
|
||||
else:
|
||||
assert self.csr_path is not None
|
||||
with open(self.csr_path, "r") as csr_file:
|
||||
body["csr"] = csr_file.read()
|
||||
|
||||
@@ -138,11 +149,15 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
content=self.cert_bytes,
|
||||
)
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
return self.cert_bytes
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
parent_check = super(EntrustCertificateBackend, self).needs_regeneration()
|
||||
|
||||
try:
|
||||
@@ -167,12 +182,12 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
|
||||
return parent_check
|
||||
|
||||
def _get_cert_details(self):
|
||||
cert_details = {}
|
||||
def _get_cert_details(self) -> dict[str, t.Any]:
|
||||
cert_details: dict[str, t.Any] = {}
|
||||
try:
|
||||
self._ensure_existing_certificate_loaded()
|
||||
except Exception:
|
||||
return
|
||||
return cert_details
|
||||
if self.existing_certificate:
|
||||
serial_number = f"{self.existing_certificate.serial_number:X}"
|
||||
expiry = get_not_valid_after(self.existing_certificate)
|
||||
@@ -203,17 +218,17 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
|
||||
|
||||
class EntrustCertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
pass
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> t.Literal[False]:
|
||||
return False
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
|
||||
return EntrustCertificateBackend(module)
|
||||
|
||||
|
||||
def add_entrust_provider_to_argument_spec(argument_spec):
|
||||
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("entrust")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
@@ -248,7 +263,7 @@ def add_entrust_provider_to_argument_spec(argument_spec):
|
||||
)
|
||||
)
|
||||
argument_spec.required_if.append(
|
||||
[
|
||||
(
|
||||
"provider",
|
||||
"entrust",
|
||||
[
|
||||
@@ -260,5 +275,5 @@ def add_entrust_provider_to_argument_spec(argument_spec):
|
||||
"entrust_api_client_cert_path",
|
||||
"entrust_api_client_cert_key_path",
|
||||
],
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -34,6 +35,19 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -48,93 +62,97 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
|
||||
|
||||
|
||||
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, content):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_der_bytes(self):
|
||||
def _get_der_bytes(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_signature_algorithm(self):
|
||||
def _get_signature_algorithm(self) -> str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_ordered(self):
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_issuer_ordered(self):
|
||||
def _get_issuer_ordered(self) -> list[list[str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_version(self):
|
||||
def _get_version(self) -> int | str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_not_before(self):
|
||||
def get_not_before(self) -> datetime.datetime:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_not_after(self):
|
||||
def get_not_after(self) -> datetime.datetime:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> PublicKeyTypes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_serial_number(self):
|
||||
def _get_serial_number(self) -> int:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_ocsp_uri(self):
|
||||
def _get_ocsp_uri(self) -> str | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_issuer_uri(self):
|
||||
def _get_issuer_uri(self) -> str | None:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False, der_support_enabled=False):
|
||||
result = dict()
|
||||
def get_info(
|
||||
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
|
||||
) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
self.cert = load_certificate(
|
||||
None,
|
||||
content=self.content,
|
||||
@@ -194,16 +212,20 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
|
||||
)
|
||||
|
||||
ski = self._get_subject_key_identifier()
|
||||
if ski is not None:
|
||||
ski = binascii.hexlify(ski).decode("ascii")
|
||||
ski_bytes = self._get_subject_key_identifier()
|
||||
if ski_bytes is not None:
|
||||
ski = binascii.hexlify(ski_bytes).decode("ascii")
|
||||
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
|
||||
else:
|
||||
ski = None
|
||||
result["subject_key_identifier"] = ski
|
||||
|
||||
aki, aci, acsn = self._get_authority_key_identifier()
|
||||
if aki is not None:
|
||||
aki = binascii.hexlify(aki).decode("ascii")
|
||||
aki_bytes, aci, acsn = self._get_authority_key_identifier()
|
||||
if aki_bytes is not None:
|
||||
aki = binascii.hexlify(aki_bytes).decode("ascii")
|
||||
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
|
||||
else:
|
||||
aki = None
|
||||
result["authority_key_identifier"] = aki
|
||||
result["authority_cert_issuer"] = aci
|
||||
result["authority_cert_serial_number"] = acsn
|
||||
@@ -219,36 +241,40 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
"""Validate the supplied cert, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
|
||||
def _get_der_bytes(self):
|
||||
def _get_der_bytes(self) -> bytes:
|
||||
return self.cert.public_bytes(serialization.Encoding.DER)
|
||||
|
||||
def _get_signature_algorithm(self):
|
||||
def _get_signature_algorithm(self) -> str:
|
||||
return cryptography_oid_to_name(self.cert.signature_algorithm_oid)
|
||||
|
||||
def _get_subject_ordered(self):
|
||||
result = []
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
result: list[list[str]] = []
|
||||
for attribute in self.cert.subject:
|
||||
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
|
||||
result.append(
|
||||
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_issuer_ordered(self):
|
||||
def _get_issuer_ordered(self) -> list[list[str]]:
|
||||
result = []
|
||||
for attribute in self.cert.issuer:
|
||||
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
|
||||
result.append(
|
||||
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_version(self):
|
||||
def _get_version(self) -> int | str:
|
||||
if self.cert.version == x509.Version.v1:
|
||||
return 1
|
||||
if self.cert.version == x509.Version.v3:
|
||||
return 3
|
||||
return "unknown"
|
||||
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
current_key_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.KeyUsage
|
||||
@@ -297,7 +323,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.ExtendedKeyUsage
|
||||
@@ -311,7 +337,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.BasicConstraints
|
||||
@@ -324,7 +350,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
try:
|
||||
tlsfeature_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.TLSFeature
|
||||
@@ -336,7 +362,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
san_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.SubjectAlternativeName
|
||||
@@ -349,22 +375,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def get_not_before(self):
|
||||
def get_not_before(self) -> datetime.datetime:
|
||||
return get_not_valid_before(self.cert)
|
||||
|
||||
def get_not_after(self):
|
||||
def get_not_after(self) -> datetime.datetime:
|
||||
return get_not_valid_after(self.cert)
|
||||
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
return self.cert.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> PublicKeyTypes:
|
||||
return self.cert.public_key()
|
||||
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.SubjectKeyIdentifier
|
||||
@@ -373,7 +399,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None
|
||||
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.AuthorityKeyIdentifier
|
||||
@@ -392,13 +420,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, None, None
|
||||
|
||||
def _get_serial_number(self):
|
||||
def _get_serial_number(self) -> int:
|
||||
return self.cert.serial_number
|
||||
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
return cryptography_get_extensions_from_cert(self.cert)
|
||||
|
||||
def _get_ocsp_uri(self):
|
||||
def _get_ocsp_uri(self) -> str | None:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.AuthorityInformationAccess
|
||||
@@ -411,7 +439,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
pass
|
||||
return None
|
||||
|
||||
def _get_issuer_uri(self):
|
||||
def _get_issuer_uri(self) -> str | None:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.AuthorityInformationAccess
|
||||
@@ -428,12 +456,16 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
return None
|
||||
|
||||
|
||||
def get_certificate_info(module, content, prefer_one_fingerprint=False):
|
||||
def get_certificate_info(
|
||||
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
|
||||
) -> dict[str, t.Any]:
|
||||
info = CertificateInfoRetrievalCryptography(module, content)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(module, content):
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule, content: bytes
|
||||
) -> CertificateInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
from random import randrange
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -18,6 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
cryptography_verify_certificate_signature,
|
||||
get_not_valid_after,
|
||||
get_not_valid_before,
|
||||
is_potential_certificate_issuer_public_key,
|
||||
set_not_valid_after,
|
||||
set_not_valid_before,
|
||||
)
|
||||
@@ -28,7 +30,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
load_certificate,
|
||||
load_privatekey,
|
||||
load_certificate_issuer_privatekey,
|
||||
select_message_digest,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
@@ -36,6 +38,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography import x509
|
||||
@@ -45,13 +58,13 @@ except ImportError:
|
||||
|
||||
|
||||
class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(OwnCACertificateBackendCryptography, self).__init__(module)
|
||||
|
||||
self.create_subject_key_identifier = module.params[
|
||||
"ownca_create_subject_key_identifier"
|
||||
]
|
||||
self.create_authority_key_identifier = module.params[
|
||||
self.create_subject_key_identifier: t.Literal[
|
||||
"create_if_not_provided", "always_create", "never_create"
|
||||
] = module.params["ownca_create_subject_key_identifier"]
|
||||
self.create_authority_key_identifier: bool = module.params[
|
||||
"ownca_create_authority_key_identifier"
|
||||
]
|
||||
self.notBefore = get_relative_time_option(
|
||||
@@ -65,31 +78,40 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.digest = select_message_digest(module.params["ownca_digest"])
|
||||
self.version = module.params["ownca_version"]
|
||||
self.version: int = module.params["ownca_version"]
|
||||
self.serial_number = x509.random_serial_number()
|
||||
self.ca_cert_path = module.params["ownca_path"]
|
||||
self.ca_cert_content = module.params["ownca_content"]
|
||||
if self.ca_cert_content is not None:
|
||||
self.ca_cert_content = self.ca_cert_content.encode("utf-8")
|
||||
self.ca_privatekey_path = module.params["ownca_privatekey_path"]
|
||||
self.ca_privatekey_content = module.params["ownca_privatekey_content"]
|
||||
if self.ca_privatekey_content is not None:
|
||||
self.ca_privatekey_content = self.ca_privatekey_content.encode("utf-8")
|
||||
self.ca_privatekey_passphrase = module.params["ownca_privatekey_passphrase"]
|
||||
self.ca_cert_path: str | None = module.params["ownca_path"]
|
||||
ca_cert_content: str | None = module.params["ownca_content"]
|
||||
if ca_cert_content is not None:
|
||||
self.ca_cert_content: bytes | None = ca_cert_content.encode("utf-8")
|
||||
else:
|
||||
self.ca_cert_content = None
|
||||
self.ca_privatekey_path: str | None = module.params["ownca_privatekey_path"]
|
||||
ca_privatekey_content: str | None = module.params["ownca_privatekey_content"]
|
||||
if ca_privatekey_content is not None:
|
||||
self.ca_privatekey_content: bytes | None = ca_privatekey_content.encode(
|
||||
"utf-8"
|
||||
)
|
||||
else:
|
||||
self.ca_privatekey_content = None
|
||||
self.ca_privatekey_passphrase: str | None = module.params[
|
||||
"ownca_privatekey_passphrase"
|
||||
]
|
||||
|
||||
if self.csr_content is None and self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if self.csr_content is None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.ca_cert_content is None and not os.path.exists(self.ca_cert_path):
|
||||
if self.csr_content is None:
|
||||
if self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.ca_cert_path is not None and not os.path.exists(self.ca_cert_path):
|
||||
raise CertificateError(
|
||||
f"The CA certificate file {self.ca_cert_path} does not exist"
|
||||
)
|
||||
if self.ca_privatekey_content is None and not os.path.exists(
|
||||
if self.ca_privatekey_path is not None and not os.path.exists(
|
||||
self.ca_privatekey_path
|
||||
):
|
||||
raise CertificateError(
|
||||
@@ -101,8 +123,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
path=self.ca_cert_path,
|
||||
content=self.ca_cert_content,
|
||||
)
|
||||
if not is_potential_certificate_issuer_public_key(self.ca_cert.public_key()):
|
||||
raise CertificateError(
|
||||
"CA certificate's public key cannot be used to sign certificates"
|
||||
)
|
||||
try:
|
||||
self.ca_private_key = load_privatekey(
|
||||
self.ca_private_key = load_certificate_issuer_privatekey(
|
||||
path=self.ca_privatekey_path,
|
||||
content=self.ca_privatekey_content,
|
||||
passphrase=self.ca_privatekey_passphrase,
|
||||
@@ -125,8 +151,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
self.digest = None
|
||||
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
cert_builder = x509.CertificateBuilder()
|
||||
cert_builder = cert_builder.subject_name(self.csr.subject)
|
||||
cert_builder = cert_builder.issuer_name(self.ca_cert.subject)
|
||||
@@ -166,10 +194,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
critical=False,
|
||||
)
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
public_key = self.ca_cert.public_key()
|
||||
assert is_potential_certificate_issuer_public_key(public_key)
|
||||
cert_builder = cert_builder.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(
|
||||
self.ca_cert.public_key()
|
||||
),
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
|
||||
@@ -180,17 +208,24 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
self.cert = certificate
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
if self.cert is None:
|
||||
raise AssertionError("Contract violation: cert has not been populated")
|
||||
return self.cert.public_bytes(Encoding.PEM)
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
if super(OwnCACertificateBackendCryptography, self).needs_regeneration(
|
||||
not_before=self.notBefore, not_after=self.notAfter
|
||||
):
|
||||
return True
|
||||
|
||||
self._ensure_existing_certificate_loaded()
|
||||
assert self.existing_certificate is not None
|
||||
|
||||
# Check whether certificate is signed by CA certificate
|
||||
if not cryptography_verify_certificate_signature(
|
||||
@@ -205,31 +240,33 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
# Check AuthorityKeyIdentifier
|
||||
if self.create_authority_key_identifier:
|
||||
try:
|
||||
ext = self.ca_cert.extensions.get_extension_for_class(
|
||||
ext_ski = self.ca_cert.extensions.get_extension_for_class(
|
||||
x509.SubjectKeyIdentifier
|
||||
)
|
||||
expected_ext = (
|
||||
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
|
||||
ext.value
|
||||
ext_ski.value
|
||||
)
|
||||
)
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
public_key = self.ca_cert.public_key()
|
||||
assert is_potential_certificate_issuer_public_key(public_key)
|
||||
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
|
||||
self.ca_cert.public_key()
|
||||
public_key
|
||||
)
|
||||
|
||||
try:
|
||||
ext = self.existing_certificate.extensions.get_extension_for_class(
|
||||
ext_aki = self.existing_certificate.extensions.get_extension_for_class(
|
||||
x509.AuthorityKeyIdentifier
|
||||
)
|
||||
if ext.value != expected_ext:
|
||||
if ext_aki.value != expected_ext:
|
||||
return True
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(OwnCACertificateBackendCryptography, self).dump(
|
||||
include_certificate
|
||||
)
|
||||
@@ -251,6 +288,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
if self.cert is None:
|
||||
self.cert = self.existing_certificate
|
||||
assert self.cert is not None
|
||||
result.update(
|
||||
{
|
||||
"notBefore": get_not_valid_before(self.cert).strftime(
|
||||
@@ -266,7 +304,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
return result
|
||||
|
||||
|
||||
def generate_serial_number():
|
||||
def generate_serial_number() -> int:
|
||||
"""Generate a serial number for a certificate"""
|
||||
while True:
|
||||
result = randrange(0, 1 << 160)
|
||||
@@ -275,7 +313,7 @@ def generate_serial_number():
|
||||
|
||||
|
||||
class OwnCACertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
if (
|
||||
module.params["ownca_path"] is None
|
||||
and module.params["ownca_content"] is None
|
||||
@@ -291,14 +329,16 @@ class OwnCACertificateProvider(CertificateProvider):
|
||||
msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider."
|
||||
)
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
return module.params["ownca_version"] == 2
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(
|
||||
self, module: AnsibleModule
|
||||
) -> OwnCACertificateBackendCryptography:
|
||||
return OwnCACertificateBackendCryptography(module)
|
||||
|
||||
|
||||
def add_ownca_provider_to_argument_spec(argument_spec):
|
||||
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("ownca")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
from random import randrange
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -14,6 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
cryptography_verify_certificate_signature,
|
||||
get_not_valid_after,
|
||||
get_not_valid_before,
|
||||
is_potential_certificate_issuer_private_key,
|
||||
set_not_valid_after,
|
||||
set_not_valid_before,
|
||||
)
|
||||
@@ -30,6 +32,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography import x509
|
||||
@@ -39,12 +52,14 @@ except ImportError:
|
||||
|
||||
|
||||
class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
privatekey: CertificateIssuerPrivateKeyTypes
|
||||
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
|
||||
|
||||
self.create_subject_key_identifier = module.params[
|
||||
"selfsigned_create_subject_key_identifier"
|
||||
]
|
||||
self.create_subject_key_identifier: t.Literal[
|
||||
"create_if_not_provided", "always_create", "never_create"
|
||||
] = module.params["selfsigned_create_subject_key_identifier"]
|
||||
self.notBefore = get_relative_time_option(
|
||||
module.params["selfsigned_not_before"],
|
||||
"selfsigned_not_before",
|
||||
@@ -56,14 +71,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.digest = select_message_digest(module.params["selfsigned_digest"])
|
||||
self.version = module.params["selfsigned_version"]
|
||||
self.version: int = module.params["selfsigned_version"]
|
||||
self.serial_number = x509.random_serial_number()
|
||||
|
||||
if self.csr_path is not None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
|
||||
if self.privatekey_path is not None and not os.path.exists(
|
||||
self.privatekey_path
|
||||
):
|
||||
raise CertificateError(
|
||||
f"The private key file {self.privatekey_path} does not exist"
|
||||
)
|
||||
@@ -71,20 +88,10 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
self._module = module
|
||||
|
||||
self._ensure_private_key_loaded()
|
||||
|
||||
self._ensure_csr_loaded()
|
||||
if self.csr is None:
|
||||
# Create empty CSR on the fly
|
||||
csr = cryptography.x509.CertificateSigningRequestBuilder()
|
||||
csr = csr.subject_name(cryptography.x509.Name([]))
|
||||
digest = None
|
||||
if cryptography_key_needs_digest_for_signing(self.privatekey):
|
||||
digest = self.digest
|
||||
if digest is None:
|
||||
self.module.fail_json(
|
||||
msg=f'Unsupported digest "{module.params["selfsigned_digest"]}"'
|
||||
)
|
||||
self.csr = csr.sign(self.privatekey, digest)
|
||||
if self.privatekey is None:
|
||||
raise CertificateError("Private key has not been provided")
|
||||
if not is_potential_certificate_issuer_private_key(self.privatekey):
|
||||
raise CertificateError("Private key cannot be used to sign certificates")
|
||||
|
||||
if cryptography_key_needs_digest_for_signing(self.privatekey):
|
||||
if self.digest is None:
|
||||
@@ -94,8 +101,21 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
self.digest = None
|
||||
|
||||
def generate_certificate(self):
|
||||
self._ensure_csr_loaded()
|
||||
if self.csr is None:
|
||||
# Create empty CSR on the fly
|
||||
csr = cryptography.x509.CertificateSigningRequestBuilder()
|
||||
csr = csr.subject_name(cryptography.x509.Name([]))
|
||||
self.csr = csr.sign(self.privatekey, self.digest)
|
||||
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
if self.privatekey is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: privatekey has not been populated"
|
||||
)
|
||||
try:
|
||||
cert_builder = x509.CertificateBuilder()
|
||||
cert_builder = cert_builder.subject_name(self.csr.subject)
|
||||
@@ -130,17 +150,26 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
self.cert = certificate
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
if self.cert is None:
|
||||
raise AssertionError("Contract violation: cert has not been populated")
|
||||
return self.cert.public_bytes(Encoding.PEM)
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
assert self.privatekey is not None
|
||||
|
||||
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(
|
||||
not_before=self.notBefore, not_after=self.notAfter
|
||||
):
|
||||
return True
|
||||
|
||||
self._ensure_existing_certificate_loaded()
|
||||
assert self.existing_certificate is not None
|
||||
|
||||
# Check whether certificate is signed by private key
|
||||
if not cryptography_verify_certificate_signature(
|
||||
@@ -150,7 +179,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(SelfSignedCertificateBackendCryptography, self).dump(
|
||||
include_certificate
|
||||
)
|
||||
@@ -166,6 +195,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
if self.cert is None:
|
||||
self.cert = self.existing_certificate
|
||||
assert self.cert is not None
|
||||
result.update(
|
||||
{
|
||||
"notBefore": get_not_valid_before(self.cert).strftime(
|
||||
@@ -181,7 +211,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
return result
|
||||
|
||||
|
||||
def generate_serial_number():
|
||||
def generate_serial_number() -> int:
|
||||
"""Generate a serial number for a certificate"""
|
||||
while True:
|
||||
result = randrange(0, 1 << 160)
|
||||
@@ -190,7 +220,7 @@ def generate_serial_number():
|
||||
|
||||
|
||||
class SelfSignedCertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
if (
|
||||
module.params["privatekey_path"] is None
|
||||
and module.params["privatekey_content"] is None
|
||||
@@ -199,14 +229,16 @@ class SelfSignedCertificateProvider(CertificateProvider):
|
||||
msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider."
|
||||
)
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
return module.params["selfsigned_version"] == 2
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(
|
||||
self, module: AnsibleModule
|
||||
) -> SelfSignedCertificateBackendCryptography:
|
||||
return SelfSignedCertificateBackendCryptography(module)
|
||||
|
||||
|
||||
def add_selfsigned_provider_to_argument_spec(argument_spec):
|
||||
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("selfsigned")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import (
|
||||
TIMESTAMP_FORMAT,
|
||||
cryptography_decode_revoked_certificate,
|
||||
@@ -22,6 +24,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
# crypto_utils
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
@@ -33,14 +47,19 @@ except ImportError:
|
||||
|
||||
|
||||
class CRLInfoRetrieval:
|
||||
def __init__(self, module, content, list_revoked_certificates=True):
|
||||
def __init__(
|
||||
self,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
list_revoked_certificates: bool = True,
|
||||
) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.list_revoked_certificates = list_revoked_certificates
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
|
||||
def get_info(self):
|
||||
def get_info(self) -> dict[str, t.Any]:
|
||||
self.crl_pem = identify_pem_format(self.content)
|
||||
try:
|
||||
if self.crl_pem:
|
||||
@@ -50,7 +69,7 @@ class CRLInfoRetrieval:
|
||||
except ValueError as e:
|
||||
self.module.fail_json(msg=f"Error while decoding CRL: {e}")
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": False,
|
||||
"format": "pem" if self.crl_pem else "der",
|
||||
"last_update": None,
|
||||
@@ -61,7 +80,11 @@ class CRLInfoRetrieval:
|
||||
}
|
||||
|
||||
result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = (
|
||||
self.crl.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
if self.crl.next_update
|
||||
else None
|
||||
)
|
||||
result["digest"] = cryptography_oid_to_name(
|
||||
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
|
||||
)
|
||||
@@ -83,7 +106,9 @@ class CRLInfoRetrieval:
|
||||
return result
|
||||
|
||||
|
||||
def get_crl_info(module, content, list_revoked_certificates=True):
|
||||
def get_crl_info(
|
||||
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
|
||||
) -> dict[str, t.Any]:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
@@ -26,13 +27,14 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
cryptography_name_to_oid,
|
||||
cryptography_parse_key_usage_params,
|
||||
cryptography_parse_relative_distinguished_name,
|
||||
is_potential_certificate_issuer_public_key,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import (
|
||||
get_csr_info,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
load_certificate_issuer_privatekey,
|
||||
load_certificate_request,
|
||||
load_privatekey,
|
||||
parse_name_field,
|
||||
parse_ordered_name_field,
|
||||
select_message_digest,
|
||||
@@ -43,6 +45,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ..cryptography_support import CertificatePrivateKeyTypes
|
||||
|
||||
_ET = t.TypeVar("_ET", bound="cryptography.x509.ExtensionType")
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -69,49 +83,58 @@ class CertificateSigningRequestError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.digest = module.params["digest"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.version = module.params["version"]
|
||||
self.subjectAltName = module.params["subject_alt_name"]
|
||||
self.subjectAltName_critical = module.params["subject_alt_name_critical"]
|
||||
self.keyUsage = module.params["key_usage"]
|
||||
self.keyUsage_critical = module.params["key_usage_critical"]
|
||||
self.extendedKeyUsage = module.params["extended_key_usage"]
|
||||
self.extendedKeyUsage_critical = module.params["extended_key_usage_critical"]
|
||||
self.basicConstraints = module.params["basic_constraints"]
|
||||
self.basicConstraints_critical = module.params["basic_constraints_critical"]
|
||||
self.ocspMustStaple = module.params["ocsp_must_staple"]
|
||||
self.ocspMustStaple_critical = module.params["ocsp_must_staple_critical"]
|
||||
self.name_constraints_permitted = (
|
||||
self.digest: str = module.params["digest"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.version: t.Literal[1] = module.params["version"]
|
||||
self.subjectAltName: list[str] | None = module.params["subject_alt_name"]
|
||||
self.subjectAltName_critical: bool = module.params["subject_alt_name_critical"]
|
||||
self.keyUsage: list[str] | None = module.params["key_usage"]
|
||||
self.keyUsage_critical: bool = module.params["key_usage_critical"]
|
||||
self.extendedKeyUsage: list[str] | None = module.params["extended_key_usage"]
|
||||
self.extendedKeyUsage_critical: bool = module.params[
|
||||
"extended_key_usage_critical"
|
||||
]
|
||||
self.basicConstraints: list[str] | None = module.params["basic_constraints"]
|
||||
self.basicConstraints_critical: bool = module.params[
|
||||
"basic_constraints_critical"
|
||||
]
|
||||
self.ocspMustStaple: bool = module.params["ocsp_must_staple"]
|
||||
self.ocspMustStaple_critical: bool = module.params["ocsp_must_staple_critical"]
|
||||
self.name_constraints_permitted: list[str] = (
|
||||
module.params["name_constraints_permitted"] or []
|
||||
)
|
||||
self.name_constraints_excluded = (
|
||||
self.name_constraints_excluded: list[str] = (
|
||||
module.params["name_constraints_excluded"] or []
|
||||
)
|
||||
self.name_constraints_critical = module.params["name_constraints_critical"]
|
||||
self.create_subject_key_identifier = module.params[
|
||||
self.name_constraints_critical: bool = module.params[
|
||||
"name_constraints_critical"
|
||||
]
|
||||
self.create_subject_key_identifier: bool = module.params[
|
||||
"create_subject_key_identifier"
|
||||
]
|
||||
self.subject_key_identifier = module.params["subject_key_identifier"]
|
||||
self.authority_key_identifier = module.params["authority_key_identifier"]
|
||||
self.authority_cert_issuer = module.params["authority_cert_issuer"]
|
||||
self.authority_cert_serial_number = module.params[
|
||||
subject_key_identifier: str | None = module.params["subject_key_identifier"]
|
||||
authority_key_identifier: str | None = module.params["authority_key_identifier"]
|
||||
self.authority_cert_issuer: list[str] | None = module.params[
|
||||
"authority_cert_issuer"
|
||||
]
|
||||
self.authority_cert_serial_number: int = module.params[
|
||||
"authority_cert_serial_number"
|
||||
]
|
||||
self.crl_distribution_points = module.params["crl_distribution_points"]
|
||||
self.csr = None
|
||||
self.privatekey = None
|
||||
self.crl_distribution_points: (
|
||||
list[cryptography.x509.DistributionPoint] | None
|
||||
) = None
|
||||
self.csr: cryptography.x509.CertificateSigningRequest | None = None
|
||||
self.privatekey: CertificateIssuerPrivateKeyTypes | None = None
|
||||
|
||||
if (
|
||||
self.create_subject_key_identifier
|
||||
and self.subject_key_identifier is not None
|
||||
):
|
||||
if self.create_subject_key_identifier and subject_key_identifier is not None:
|
||||
module.fail_json(
|
||||
msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true"
|
||||
)
|
||||
@@ -153,35 +176,37 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
self.using_common_name_for_san = True
|
||||
break
|
||||
|
||||
if self.subject_key_identifier is not None:
|
||||
self.subject_key_identifier: bytes | None = None
|
||||
if subject_key_identifier is not None:
|
||||
try:
|
||||
self.subject_key_identifier = binascii.unhexlify(
|
||||
self.subject_key_identifier.replace(":", "")
|
||||
subject_key_identifier.replace(":", "")
|
||||
)
|
||||
except Exception as e:
|
||||
raise CertificateSigningRequestError(
|
||||
f"Cannot parse subject_key_identifier: {e}"
|
||||
)
|
||||
|
||||
if self.authority_key_identifier is not None:
|
||||
self.authority_key_identifier: bytes | None = None
|
||||
if authority_key_identifier is not None:
|
||||
try:
|
||||
self.authority_key_identifier = binascii.unhexlify(
|
||||
self.authority_key_identifier.replace(":", "")
|
||||
authority_key_identifier.replace(":", "")
|
||||
)
|
||||
except Exception as e:
|
||||
raise CertificateSigningRequestError(
|
||||
f"Cannot parse authority_key_identifier: {e}"
|
||||
)
|
||||
|
||||
self.existing_csr = None
|
||||
self.existing_csr_bytes = None
|
||||
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
|
||||
self.existing_csr_bytes: bytes | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
return {}
|
||||
try:
|
||||
result = get_csr_info(
|
||||
self.module,
|
||||
@@ -195,30 +220,28 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
return dict(can_parse_csr=False)
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_csr(self):
|
||||
def generate_csr(self) -> None:
|
||||
"""(Re-)Generate CSR."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_csr_data(self):
|
||||
def get_csr_data(self) -> bytes:
|
||||
"""Return bytes for self.csr."""
|
||||
pass
|
||||
|
||||
def set_existing(self, csr_bytes):
|
||||
def set_existing(self, csr_bytes: bytes | None) -> None:
|
||||
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
|
||||
self.existing_csr_bytes = csr_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
|
||||
|
||||
def has_existing(self):
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing CSR is/has been there."""
|
||||
return self.existing_csr_bytes is not None
|
||||
|
||||
def _ensure_private_key_loaded(self):
|
||||
def _ensure_private_key_loaded(self) -> None:
|
||||
"""Load the provided private key into self.privatekey."""
|
||||
if self.privatekey is not None:
|
||||
return
|
||||
try:
|
||||
self.privatekey = load_privatekey(
|
||||
self.privatekey = load_certificate_issuer_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
@@ -227,11 +250,10 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
raise CertificateSigningRequestError(exc)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_csr(self):
|
||||
def _check_csr(self) -> bool:
|
||||
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
|
||||
pass
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(self) -> bool:
|
||||
"""Check whether a regeneration is necessary."""
|
||||
if self.existing_csr_bytes is None:
|
||||
return True
|
||||
@@ -245,9 +267,9 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
self._ensure_private_key_loaded()
|
||||
return not self._check_csr()
|
||||
|
||||
def dump(self, include_csr):
|
||||
def dump(self, include_csr: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
"subject": self.subject,
|
||||
"subjectAltName": self.subjectAltName,
|
||||
@@ -274,44 +296,49 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
|
||||
def parse_crl_distribution_points(module, crl_distribution_points):
|
||||
def parse_crl_distribution_points(
|
||||
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
|
||||
) -> list[cryptography.x509.DistributionPoint]:
|
||||
result = []
|
||||
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
|
||||
try:
|
||||
params = dict(
|
||||
full_name=None,
|
||||
relative_name=None,
|
||||
crl_issuer=None,
|
||||
reasons=None,
|
||||
)
|
||||
full_name = None
|
||||
relative_name = None
|
||||
crl_issuer = None
|
||||
reasons = None
|
||||
if parse_crl_distribution_point["full_name"] is not None:
|
||||
if not parse_crl_distribution_point["full_name"]:
|
||||
raise OpenSSLObjectError("full_name must not be empty")
|
||||
params["full_name"] = [
|
||||
full_name = [
|
||||
cryptography_get_name(name, "full name")
|
||||
for name in parse_crl_distribution_point["full_name"]
|
||||
]
|
||||
if parse_crl_distribution_point["relative_name"] is not None:
|
||||
if not parse_crl_distribution_point["relative_name"]:
|
||||
raise OpenSSLObjectError("relative_name must not be empty")
|
||||
params["relative_name"] = (
|
||||
cryptography_parse_relative_distinguished_name(
|
||||
parse_crl_distribution_point["relative_name"]
|
||||
)
|
||||
relative_name = cryptography_parse_relative_distinguished_name(
|
||||
parse_crl_distribution_point["relative_name"]
|
||||
)
|
||||
if parse_crl_distribution_point["crl_issuer"] is not None:
|
||||
if not parse_crl_distribution_point["crl_issuer"]:
|
||||
raise OpenSSLObjectError("crl_issuer must not be empty")
|
||||
params["crl_issuer"] = [
|
||||
crl_issuer = [
|
||||
cryptography_get_name(name, "CRL issuer")
|
||||
for name in parse_crl_distribution_point["crl_issuer"]
|
||||
]
|
||||
if parse_crl_distribution_point["reasons"] is not None:
|
||||
reasons = []
|
||||
reasons_list = []
|
||||
for reason in parse_crl_distribution_point["reasons"]:
|
||||
reasons.append(REVOCATION_REASON_MAP[reason])
|
||||
params["reasons"] = frozenset(reasons)
|
||||
result.append(cryptography.x509.DistributionPoint(**params))
|
||||
reasons_list.append(REVOCATION_REASON_MAP[reason])
|
||||
reasons = frozenset(reasons_list)
|
||||
result.append(
|
||||
cryptography.x509.DistributionPoint(
|
||||
full_name=full_name,
|
||||
relative_name=relative_name,
|
||||
crl_issuer=crl_issuer,
|
||||
reasons=reasons,
|
||||
)
|
||||
)
|
||||
except (OpenSSLObjectError, ValueError) as e:
|
||||
raise OpenSSLObjectError(
|
||||
f"Error while parsing CRL distribution point #{index}: {e}"
|
||||
@@ -321,21 +348,25 @@ def parse_crl_distribution_points(module, crl_distribution_points):
|
||||
|
||||
# Implementation with using cryptography
|
||||
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
|
||||
if self.version != 1:
|
||||
module.warn(
|
||||
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
|
||||
)
|
||||
|
||||
if self.crl_distribution_points:
|
||||
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
|
||||
"crl_distribution_points"
|
||||
]
|
||||
if crl_distribution_points:
|
||||
self.crl_distribution_points = parse_crl_distribution_points(
|
||||
module, self.crl_distribution_points
|
||||
module, crl_distribution_points
|
||||
)
|
||||
|
||||
def generate_csr(self):
|
||||
def generate_csr(self) -> None:
|
||||
"""(Re-)Generate CSR."""
|
||||
self._ensure_private_key_loaded()
|
||||
assert self.privatekey is not None
|
||||
|
||||
csr = cryptography.x509.CertificateSigningRequestBuilder()
|
||||
try:
|
||||
@@ -412,6 +443,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
raise OpenSSLObjectError(f"Error while parsing name constraint: {e}")
|
||||
|
||||
if self.create_subject_key_identifier:
|
||||
if not is_potential_certificate_issuer_public_key(
|
||||
self.privatekey.public_key()
|
||||
):
|
||||
raise OpenSSLObjectError(
|
||||
"Private key can not be used to create subject key identifier"
|
||||
)
|
||||
csr = csr.add_extension(
|
||||
cryptography.x509.SubjectKeyIdentifier.from_public_key(
|
||||
self.privatekey.public_key()
|
||||
@@ -450,7 +487,10 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
critical=False,
|
||||
)
|
||||
|
||||
digest = None
|
||||
# csr.sign() does not accept some digests we theoretically could have in digest.
|
||||
# For that reason we use type t.Any here. csr.sign() will complain if
|
||||
# the digest is not acceptable.
|
||||
digest: t.Any | None = None
|
||||
if cryptography_key_needs_digest_for_signing(self.privatekey):
|
||||
digest = select_message_digest(self.digest)
|
||||
if digest is None:
|
||||
@@ -482,16 +522,22 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
+ "This is probably caused by an invalid Subject Alternative DNS Name."
|
||||
)
|
||||
|
||||
def get_csr_data(self):
|
||||
def get_csr_data(self) -> bytes:
|
||||
"""Return bytes for self.csr."""
|
||||
if self.csr is None:
|
||||
raise AssertionError("Violated contract: csr is not populated")
|
||||
return self.csr.public_bytes(
|
||||
cryptography.hazmat.primitives.serialization.Encoding.PEM
|
||||
)
|
||||
|
||||
def _check_csr(self):
|
||||
def _check_csr(self) -> bool:
|
||||
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
|
||||
if self.existing_csr is None:
|
||||
raise AssertionError("Violated contract: existing_csr is not populated")
|
||||
if self.privatekey is None:
|
||||
raise AssertionError("Violated contract: privatekey is not populated")
|
||||
|
||||
def _check_subject(csr):
|
||||
def _check_subject(csr: cryptography.x509.CertificateSigningRequest) -> bool:
|
||||
subject = [
|
||||
(cryptography_name_to_oid(entry[0]), to_text(entry[1]))
|
||||
for entry in self.subject
|
||||
@@ -502,12 +548,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return set(subject) == set(current_subject)
|
||||
|
||||
def _find_extension(extensions, exttype):
|
||||
def _find_extension(
|
||||
extensions: cryptography.x509.Extensions, exttype: type[_ET]
|
||||
) -> cryptography.x509.Extension[_ET] | None:
|
||||
return next(
|
||||
(ext for ext in extensions if isinstance(ext.value, exttype)), None
|
||||
)
|
||||
|
||||
def _check_subjectAltName(extensions):
|
||||
def _check_subjectAltName(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_altnames_ext = _find_extension(
|
||||
extensions, cryptography.x509.SubjectAlternativeName
|
||||
)
|
||||
@@ -526,12 +574,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
)
|
||||
if set(altnames) != set(current_altnames):
|
||||
return False
|
||||
if altnames:
|
||||
if altnames and current_altnames_ext:
|
||||
if current_altnames_ext.critical != self.subjectAltName_critical:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_keyUsage(extensions):
|
||||
def _check_keyUsage(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_keyusage_ext = _find_extension(
|
||||
extensions, cryptography.x509.KeyUsage
|
||||
)
|
||||
@@ -547,7 +595,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_extenededKeyUsage(extensions):
|
||||
def _check_extenededKeyUsage(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_usages_ext = _find_extension(
|
||||
extensions, cryptography.x509.ExtendedKeyUsage
|
||||
)
|
||||
@@ -566,12 +614,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
)
|
||||
if set(current_usages) != set(usages):
|
||||
return False
|
||||
if usages:
|
||||
if usages and current_usages_ext:
|
||||
if current_usages_ext.critical != self.extendedKeyUsage_critical:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_basicConstraints(extensions):
|
||||
def _check_basicConstraints(extensions: cryptography.x509.Extensions) -> bool:
|
||||
bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints)
|
||||
current_ca = bc_ext.value.ca if bc_ext else False
|
||||
current_path_length = bc_ext.value.path_length if bc_ext else None
|
||||
@@ -591,7 +639,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return bc_ext is None
|
||||
|
||||
def _check_ocspMustStaple(extensions):
|
||||
def _check_ocspMustStaple(extensions: cryptography.x509.Extensions) -> bool:
|
||||
tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
|
||||
if self.ocspMustStaple:
|
||||
if (
|
||||
@@ -606,7 +654,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return tlsfeature_ext is None
|
||||
|
||||
def _check_nameConstraints(extensions):
|
||||
def _check_nameConstraints(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_nc_ext = _find_extension(
|
||||
extensions, cryptography.x509.NameConstraints
|
||||
)
|
||||
@@ -638,12 +686,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
current_nc_excl
|
||||
):
|
||||
return False
|
||||
if nc_perm or nc_excl:
|
||||
if (nc_perm or nc_excl) and current_nc_ext:
|
||||
if current_nc_ext.critical != self.name_constraints_critical:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_subject_key_identifier(extensions):
|
||||
def _check_subject_key_identifier(
|
||||
extensions: cryptography.x509.Extensions,
|
||||
) -> bool:
|
||||
ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
|
||||
if (
|
||||
self.create_subject_key_identifier
|
||||
@@ -652,6 +702,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
if not ext or ext.critical:
|
||||
return False
|
||||
if self.create_subject_key_identifier:
|
||||
assert self.privatekey is not None
|
||||
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(
|
||||
self.privatekey.public_key()
|
||||
).digest
|
||||
@@ -661,7 +712,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return ext is None
|
||||
|
||||
def _check_authority_key_identifier(extensions):
|
||||
def _check_authority_key_identifier(
|
||||
extensions: cryptography.x509.Extensions,
|
||||
) -> bool:
|
||||
ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
|
||||
if (
|
||||
self.authority_key_identifier is not None
|
||||
@@ -688,7 +741,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return ext is None
|
||||
|
||||
def _check_crl_distribution_points(extensions):
|
||||
def _check_crl_distribution_points(
|
||||
extensions: cryptography.x509.Extensions,
|
||||
) -> bool:
|
||||
ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints)
|
||||
if self.crl_distribution_points is None:
|
||||
return ext is None
|
||||
@@ -696,7 +751,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
return False
|
||||
return list(ext.value) == self.crl_distribution_points
|
||||
|
||||
def _check_extensions(csr):
|
||||
def _check_extensions(csr: cryptography.x509.CertificateSigningRequest) -> bool:
|
||||
extensions = csr.extensions
|
||||
return (
|
||||
_check_subjectAltName(extensions)
|
||||
@@ -710,7 +765,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
and _check_crl_distribution_points(extensions)
|
||||
)
|
||||
|
||||
def _check_signature(csr):
|
||||
def _check_signature(csr: cryptography.x509.CertificateSigningRequest) -> bool:
|
||||
if not csr.is_signature_valid:
|
||||
return False
|
||||
# To check whether public key of CSR belongs to private key,
|
||||
@@ -719,6 +774,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
cryptography.hazmat.primitives.serialization.Encoding.PEM,
|
||||
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
assert self.privatekey is not None
|
||||
key_b = self.privatekey.public_key().public_bytes(
|
||||
cryptography.hazmat.primitives.serialization.Encoding.PEM,
|
||||
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
@@ -732,14 +788,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
)
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(
|
||||
module: AnsibleModule,
|
||||
) -> CertificateSigningRequestCryptographyBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return CertificateSigningRequestCryptographyBackend(module)
|
||||
|
||||
|
||||
def get_csr_argument_spec():
|
||||
def get_csr_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
digest=dict(type="str", default="sha256"),
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -27,6 +28,19 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificatePublicKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -41,66 +55,69 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
|
||||
|
||||
|
||||
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, content, validate_signature):
|
||||
# content must be a bytes string
|
||||
def __init__(
|
||||
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
) -> None:
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.validate_signature = validate_signature
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_ordered(self):
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_name_constraints(self):
|
||||
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_signature_valid(self):
|
||||
def _is_signature_valid(self) -> bool:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False):
|
||||
result = dict()
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
self.csr = load_certificate_request(
|
||||
None,
|
||||
content=self.content,
|
||||
@@ -145,15 +162,17 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
}
|
||||
)
|
||||
|
||||
ski = self._get_subject_key_identifier()
|
||||
if ski is not None:
|
||||
ski = binascii.hexlify(ski).decode("ascii")
|
||||
ski_bytes = self._get_subject_key_identifier()
|
||||
ski = None
|
||||
if ski_bytes is not None:
|
||||
ski = binascii.hexlify(ski_bytes).decode("ascii")
|
||||
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
|
||||
result["subject_key_identifier"] = ski
|
||||
|
||||
aki, aci, acsn = self._get_authority_key_identifier()
|
||||
if aki is not None:
|
||||
aki = binascii.hexlify(aki).decode("ascii")
|
||||
aki_bytes, aci, acsn = self._get_authority_key_identifier()
|
||||
aki = None
|
||||
if aki_bytes is not None:
|
||||
aki = binascii.hexlify(aki_bytes).decode("ascii")
|
||||
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
|
||||
result["authority_key_identifier"] = aki
|
||||
result["authority_cert_issuer"] = aci
|
||||
@@ -170,19 +189,25 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
"""Validate the supplied CSR, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content, validate_signature):
|
||||
def __init__(
|
||||
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
) -> None:
|
||||
super(CSRInfoRetrievalCryptography, self).__init__(
|
||||
module, content, validate_signature
|
||||
)
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
|
||||
"name_encoding", "ignore"
|
||||
)
|
||||
|
||||
def _get_subject_ordered(self):
|
||||
result = []
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
result: list[list[str]] = []
|
||||
for attribute in self.csr.subject:
|
||||
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
|
||||
result.append(
|
||||
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage)
|
||||
current_key_usage = current_key_ext.value
|
||||
@@ -229,7 +254,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.ExtendedKeyUsage
|
||||
@@ -243,7 +268,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.BasicConstraints
|
||||
@@ -255,7 +280,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
try:
|
||||
# This only works with cryptography >= 2.1
|
||||
tlsfeature_ext = self.csr.extensions.get_extension_for_class(
|
||||
@@ -268,7 +293,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
san_ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.SubjectAlternativeName
|
||||
@@ -281,7 +306,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_name_constraints(self):
|
||||
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
|
||||
try:
|
||||
nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints)
|
||||
permitted = [
|
||||
@@ -296,23 +321,25 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, None, False
|
||||
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
return self.csr.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
|
||||
return self.csr.public_key()
|
||||
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
try:
|
||||
ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
|
||||
return ext.value.digest
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None
|
||||
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
try:
|
||||
ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.AuthorityKeyIdentifier
|
||||
@@ -331,23 +358,28 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, None, None
|
||||
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
return cryptography_get_extensions_from_csr(self.csr)
|
||||
|
||||
def _is_signature_valid(self):
|
||||
def _is_signature_valid(self) -> bool:
|
||||
return self.csr.is_signature_valid
|
||||
|
||||
|
||||
def get_csr_info(
|
||||
module, content, validate_signature=True, prefer_one_fingerprint=False
|
||||
):
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
validate_signature: bool = True,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = CSRInfoRetrievalCryptography(
|
||||
module, content, validate_signature=validate_signature
|
||||
)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(module, content, validate_signature=True):
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
|
||||
) -> CSRInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import base64
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -64,29 +76,37 @@ class PrivateKeyError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: GeneralAnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.type = module.params["type"]
|
||||
self.size = module.params["size"]
|
||||
self.curve = module.params["curve"]
|
||||
self.passphrase = module.params["passphrase"]
|
||||
self.cipher = module.params["cipher"]
|
||||
self.format = module.params["format"]
|
||||
self.format_mismatch = module.params.get("format_mismatch", "regenerate")
|
||||
self.regenerate = module.params.get("regenerate", "full_idempotence")
|
||||
self.type: t.Literal[
|
||||
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
|
||||
] = module.params["type"]
|
||||
self.size: int = module.params["size"]
|
||||
self.curve: str | None = module.params["curve"]
|
||||
self.passphrase: str | None = module.params["passphrase"]
|
||||
self.cipher: str = module.params["cipher"]
|
||||
self.format: t.Literal["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"] = (
|
||||
module.params["format"]
|
||||
)
|
||||
self.format_mismatch: t.Literal["regenerate", "convert"] = module.params.get(
|
||||
"format_mismatch", "regenerate"
|
||||
)
|
||||
self.regenerate: t.Literal[
|
||||
"never", "fail", "partial_idempotence", "full_idempotence", "always"
|
||||
] = module.params.get("regenerate", "full_idempotence")
|
||||
|
||||
self.private_key = None
|
||||
self.private_key: PrivateKeyTypes | None = None
|
||||
|
||||
self.existing_private_key = None
|
||||
self.existing_private_key_bytes = None
|
||||
self.existing_private_key: PrivateKeyTypes | None = None
|
||||
self.existing_private_key_bytes: bytes | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
result = dict(can_parse_key=False)
|
||||
return {}
|
||||
result: dict[str, t.Any] = {"can_parse_key": False}
|
||||
try:
|
||||
result.update(
|
||||
get_privatekey_info(
|
||||
@@ -106,11 +126,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_private_key(self):
|
||||
def generate_private_key(self) -> None:
|
||||
"""(Re-)Generate private key."""
|
||||
pass
|
||||
|
||||
def convert_private_key(self):
|
||||
def convert_private_key(self) -> None:
|
||||
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key).
|
||||
|
||||
This is effectively a copy without active conversion. The conversion is done
|
||||
@@ -121,42 +141,37 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
self.private_key = self.existing_private_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.private_key."""
|
||||
pass
|
||||
|
||||
def set_existing(self, privatekey_bytes):
|
||||
def set_existing(self, privatekey_bytes: bytes | None) -> None:
|
||||
"""Set existing private key bytes. None indicates that the key does not exist."""
|
||||
self.existing_private_key_bytes = privatekey_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(
|
||||
self.existing_private_key_bytes
|
||||
)
|
||||
|
||||
def has_existing(self):
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing private key is/has been there."""
|
||||
return self.existing_private_key_bytes is not None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_passphrase(self):
|
||||
def _check_passphrase(self) -> bool:
|
||||
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _ensure_existing_private_key_loaded(self):
|
||||
def _ensure_existing_private_key_loaded(self) -> None:
|
||||
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_size_and_type(self):
|
||||
def _check_size_and_type(self) -> bool:
|
||||
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_format(self):
|
||||
def _check_format(self) -> bool:
|
||||
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
|
||||
pass
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(self) -> bool:
|
||||
"""Check whether a regeneration is necessary."""
|
||||
if self.regenerate == "always":
|
||||
return True
|
||||
@@ -194,7 +209,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
)
|
||||
return False
|
||||
|
||||
def needs_conversion(self):
|
||||
def needs_conversion(self) -> bool:
|
||||
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
|
||||
# During conversion step, convert if format does not match and format_mismatch == 'convert'
|
||||
self._ensure_existing_private_key_loaded()
|
||||
@@ -204,7 +219,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
and not self._check_format()
|
||||
)
|
||||
|
||||
def _get_fingerprint(self):
|
||||
def _get_fingerprint(self) -> dict[str, str] | None:
|
||||
if self.private_key:
|
||||
return get_fingerprint_of_privatekey(self.private_key)
|
||||
try:
|
||||
@@ -214,8 +229,9 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
if self.existing_private_key:
|
||||
return get_fingerprint_of_privatekey(self.existing_private_key)
|
||||
return None
|
||||
|
||||
def dump(self, include_key):
|
||||
def dump(self, include_key: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
if not self.private_key:
|
||||
@@ -224,7 +240,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
except Exception:
|
||||
# Ignore errors
|
||||
pass
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"type": self.type,
|
||||
"size": self.size,
|
||||
"fingerprint": self._get_fingerprint(),
|
||||
@@ -253,38 +269,57 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
class _Curve:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
ectype: str,
|
||||
deprecated: bool,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.ectype = ectype
|
||||
self.deprecated = deprecated
|
||||
|
||||
def _get_ec_class(self, ectype):
|
||||
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype)
|
||||
def _get_ec_class(
|
||||
self, module: GeneralAnsibleModule
|
||||
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
|
||||
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
|
||||
if ecclass is None:
|
||||
self.module.fail_json(
|
||||
msg=f"Your cryptography version does not support {ectype}"
|
||||
module.fail_json(
|
||||
msg=f"Your cryptography version does not support {self.ectype}"
|
||||
)
|
||||
return ecclass
|
||||
|
||||
def _add_curve(self, name, ectype, deprecated=False):
|
||||
def create(size):
|
||||
ecclass = self._get_ec_class(ectype)
|
||||
return ecclass()
|
||||
def create(
|
||||
self, size: int, module: GeneralAnsibleModule
|
||||
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
|
||||
ecclass = self._get_ec_class(module)
|
||||
return ecclass()
|
||||
|
||||
def verify(privatekey):
|
||||
ecclass = self._get_ec_class(ectype)
|
||||
return isinstance(
|
||||
privatekey.private_numbers().public_numbers.curve, ecclass
|
||||
)
|
||||
def verify(
|
||||
self,
|
||||
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
|
||||
module: GeneralAnsibleModule,
|
||||
) -> bool:
|
||||
ecclass = self._get_ec_class(module)
|
||||
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
|
||||
|
||||
self.curves[name] = {
|
||||
"create": create,
|
||||
"verify": verify,
|
||||
"deprecated": deprecated,
|
||||
}
|
||||
|
||||
def __init__(self, module):
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
|
||||
def _add_curve(
|
||||
self,
|
||||
name: str,
|
||||
ectype: str,
|
||||
deprecated: bool = False,
|
||||
) -> None:
|
||||
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
|
||||
|
||||
def __init__(self, module: GeneralAnsibleModule) -> None:
|
||||
super(PrivateKeyCryptographyBackend, self).__init__(module=module)
|
||||
|
||||
self.curves = dict()
|
||||
self.curves: dict[str, _Curve] = {}
|
||||
self._add_curve("secp224r1", "SECP224R1")
|
||||
self._add_curve("secp256k1", "SECP256K1")
|
||||
self._add_curve("secp256r1", "SECP256R1")
|
||||
@@ -305,15 +340,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
|
||||
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
|
||||
|
||||
def _get_wanted_format(self):
|
||||
def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
|
||||
if self.format not in ("auto", "auto_ignore"):
|
||||
return self.format
|
||||
return self.format # type: ignore
|
||||
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
|
||||
return "pkcs8"
|
||||
else:
|
||||
return "pkcs1"
|
||||
|
||||
def generate_private_key(self):
|
||||
def generate_private_key(self) -> None:
|
||||
"""(Re-)Generate private key."""
|
||||
try:
|
||||
if self.type == "RSA":
|
||||
@@ -346,13 +381,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
|
||||
)
|
||||
if self.type == "ECC" and self.curve in self.curves:
|
||||
if self.curves[self.curve]["deprecated"]:
|
||||
if self.curves[self.curve].deprecated:
|
||||
self.module.warn(
|
||||
f"Elliptic curves of type {self.curve} should not be used for new keys!"
|
||||
)
|
||||
self.private_key = (
|
||||
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
|
||||
curve=self.curves[self.curve]["create"](self.size),
|
||||
curve=self.curves[self.curve].create(
|
||||
size=self.size, module=self.module
|
||||
),
|
||||
)
|
||||
)
|
||||
except cryptography.exceptions.UnsupportedAlgorithm:
|
||||
@@ -360,22 +397,24 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
msg=f"Cryptography backend does not support the algorithm required for {self.type}"
|
||||
)
|
||||
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.private_key"""
|
||||
if self.private_key is None:
|
||||
raise AssertionError("private_key not set")
|
||||
# Select export format and encoding
|
||||
try:
|
||||
export_format = self._get_wanted_format()
|
||||
export_format_txt = self._get_wanted_format()
|
||||
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
|
||||
if export_format == "pkcs1":
|
||||
if export_format_txt == "pkcs1":
|
||||
# "TraditionalOpenSSL" format is PKCS1
|
||||
export_format = (
|
||||
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
|
||||
)
|
||||
elif export_format == "pkcs8":
|
||||
elif export_format_txt == "pkcs8":
|
||||
export_format = (
|
||||
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
|
||||
)
|
||||
elif export_format == "raw":
|
||||
elif export_format_txt == "raw":
|
||||
export_format = (
|
||||
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
|
||||
)
|
||||
@@ -388,9 +427,9 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
)
|
||||
|
||||
# Select key encryption
|
||||
encryption_algorithm = (
|
||||
cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
)
|
||||
encryption_algorithm: (
|
||||
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
|
||||
) = cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
if self.cipher and self.passphrase:
|
||||
if self.cipher == "auto":
|
||||
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
|
||||
@@ -418,8 +457,10 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
exception=traceback.format_exc(),
|
||||
)
|
||||
|
||||
def _load_privatekey(self):
|
||||
def _load_privatekey(self) -> PrivateKeyTypes:
|
||||
data = self.existing_private_key_bytes
|
||||
if data is None:
|
||||
raise AssertionError("existing_private_key_bytes not set")
|
||||
try:
|
||||
# Interpret bytes depending on format.
|
||||
format = identify_private_key_format(data)
|
||||
@@ -460,11 +501,13 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
except Exception as e:
|
||||
raise PrivateKeyError(e)
|
||||
|
||||
def _ensure_existing_private_key_loaded(self):
|
||||
def _ensure_existing_private_key_loaded(self) -> None:
|
||||
if self.existing_private_key is None and self.has_existing():
|
||||
self.existing_private_key = self._load_privatekey()
|
||||
|
||||
def _check_passphrase(self):
|
||||
def _check_passphrase(self) -> bool:
|
||||
if self.existing_private_key_bytes is None:
|
||||
raise AssertionError("existing_private_key_bytes not set")
|
||||
try:
|
||||
format = identify_private_key_format(self.existing_private_key_bytes)
|
||||
if format == "raw":
|
||||
@@ -475,7 +518,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
# provided.
|
||||
return self.passphrase is None
|
||||
else:
|
||||
return (
|
||||
return bool(
|
||||
cryptography.hazmat.primitives.serialization.load_pem_private_key(
|
||||
self.existing_private_key_bytes,
|
||||
None if self.passphrase is None else to_bytes(self.passphrase),
|
||||
@@ -484,7 +527,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_size_and_type(self):
|
||||
def _check_size_and_type(self) -> bool:
|
||||
if isinstance(
|
||||
self.existing_private_key,
|
||||
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
|
||||
@@ -527,11 +570,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
return False
|
||||
if self.curve not in self.curves:
|
||||
return False
|
||||
return self.curves[self.curve]["verify"](self.existing_private_key)
|
||||
return self.curves[self.curve].verify(
|
||||
self.existing_private_key, module=self.module
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def _check_format(self):
|
||||
def _check_format(self) -> bool:
|
||||
if self.existing_private_key_bytes is None:
|
||||
raise AssertionError("existing_private_key_bytes not set")
|
||||
if self.format == "auto_ignore":
|
||||
return True
|
||||
try:
|
||||
@@ -541,14 +588,14 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
return False
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyCryptographyBackend(module)
|
||||
|
||||
|
||||
def get_privatekey_argument_spec():
|
||||
def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
size=dict(type="int", default=4096),
|
||||
@@ -607,6 +654,6 @@ def get_privatekey_argument_spec():
|
||||
),
|
||||
),
|
||||
required_if=[
|
||||
["type", "ECC", ["curve"]],
|
||||
("type", "ECC", ["curve"]),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
@@ -27,6 +28,13 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
from ansible_collections.community.crypto.plugins.module_utils.io import load_file
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -58,42 +66,48 @@ class PrivateKeyError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.src_path = module.params["src_path"]
|
||||
self.src_content = module.params["src_content"]
|
||||
self.src_passphrase = module.params["src_passphrase"]
|
||||
self.format = module.params["format"]
|
||||
self.dest_passphrase = module.params["dest_passphrase"]
|
||||
self.src_path: str | None = module.params["src_path"]
|
||||
self.src_content: str | None = module.params["src_content"]
|
||||
self.src_passphrase: str | None = module.params["src_passphrase"]
|
||||
self.format: t.Literal["pkcs1", "pkcs8", "raw"] = module.params["format"]
|
||||
self.dest_passphrase: str | None = module.params["dest_passphrase"]
|
||||
|
||||
self.src_private_key = None
|
||||
self.src_private_key: PrivateKeyTypes | None = None
|
||||
if self.src_path is not None:
|
||||
self.src_private_key_bytes = load_file(self.src_path, module)
|
||||
else:
|
||||
if self.src_content is None:
|
||||
raise AssertionError("src_content is None")
|
||||
self.src_private_key_bytes = self.src_content.encode("utf-8")
|
||||
|
||||
self.dest_private_key = None
|
||||
self.dest_private_key_bytes = None
|
||||
self.dest_private_key: PrivateKeyTypes | None = None
|
||||
self.dest_private_key_bytes: bytes | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.src_private_key in output format."""
|
||||
pass
|
||||
|
||||
def set_existing_destination(self, privatekey_bytes):
|
||||
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
|
||||
"""Set existing private key bytes. None indicates that the key does not exist."""
|
||||
self.dest_private_key_bytes = privatekey_bytes
|
||||
|
||||
def has_existing_destination(self):
|
||||
def has_existing_destination(self) -> bool:
|
||||
"""Query whether an existing private key is/has been there."""
|
||||
return self.dest_private_key_bytes is not None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_private_key(self, data, passphrase, current_hint=None):
|
||||
def _load_private_key(
|
||||
self,
|
||||
data: bytes,
|
||||
passphrase: str | None,
|
||||
current_hint: PrivateKeyTypes | None = None,
|
||||
) -> tuple[str, PrivateKeyTypes]:
|
||||
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
|
||||
pass
|
||||
|
||||
def needs_conversion(self):
|
||||
def needs_conversion(self) -> bool:
|
||||
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
|
||||
dummy, self.src_private_key = self._load_private_key(
|
||||
self.src_private_key_bytes, self.src_passphrase
|
||||
@@ -101,6 +115,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
if not self.has_existing_destination():
|
||||
return True
|
||||
assert self.dest_private_key_bytes is not None
|
||||
|
||||
try:
|
||||
format, self.dest_private_key = self._load_private_key(
|
||||
@@ -115,18 +130,20 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
self.dest_private_key, self.src_private_key
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
return {}
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
|
||||
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.src_private_key in output format"""
|
||||
if self.src_private_key is None:
|
||||
raise AssertionError("src_private_key not set")
|
||||
# Select export format and encoding
|
||||
try:
|
||||
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
|
||||
@@ -152,9 +169,9 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
)
|
||||
|
||||
# Select key encryption
|
||||
encryption_algorithm = (
|
||||
cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
)
|
||||
encryption_algorithm: (
|
||||
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
|
||||
) = cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
if self.dest_passphrase:
|
||||
encryption_algorithm = (
|
||||
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
|
||||
@@ -179,7 +196,12 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
exception=traceback.format_exc(),
|
||||
)
|
||||
|
||||
def _load_private_key(self, data, passphrase, current_hint=None):
|
||||
def _load_private_key(
|
||||
self,
|
||||
data: bytes,
|
||||
passphrase: str | None,
|
||||
current_hint: PrivateKeyTypes | None = None,
|
||||
) -> tuple[str, PrivateKeyTypes]:
|
||||
try:
|
||||
# Interpret bytes depending on format.
|
||||
format = identify_private_key_format(data)
|
||||
@@ -247,14 +269,14 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
raise PrivateKeyError(e)
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyConvertCryptographyBackend(module)
|
||||
|
||||
|
||||
def get_privatekey_argument_spec():
|
||||
def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
src_path=dict(type="path"),
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -29,6 +30,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -40,38 +53,49 @@ except ImportError:
|
||||
SIGNATURE_TEST_DATA = b"1234"
|
||||
|
||||
|
||||
def _get_cryptography_private_key_info(key, need_private_key_data=False):
|
||||
def _get_cryptography_private_key_info(
|
||||
key: PrivateKeyTypes, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
|
||||
key_private_data = dict()
|
||||
key_private_data: dict[str, t.Any] = {}
|
||||
if need_private_key_data:
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
|
||||
private_numbers = key.private_numbers()
|
||||
key_private_data["p"] = private_numbers.p
|
||||
key_private_data["q"] = private_numbers.q
|
||||
key_private_data["exponent"] = private_numbers.d
|
||||
rsa_private_numbers = key.private_numbers()
|
||||
key_private_data["p"] = rsa_private_numbers.p
|
||||
key_private_data["q"] = rsa_private_numbers.q
|
||||
key_private_data["exponent"] = rsa_private_numbers.d
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey
|
||||
):
|
||||
private_numbers = key.private_numbers()
|
||||
key_private_data["x"] = private_numbers.x
|
||||
dsa_private_numbers = key.private_numbers()
|
||||
key_private_data["x"] = dsa_private_numbers.x
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
|
||||
):
|
||||
private_numbers = key.private_numbers()
|
||||
key_private_data["multiplier"] = private_numbers.private_value
|
||||
ecc_private_numbers = key.private_numbers()
|
||||
key_private_data["multiplier"] = ecc_private_numbers.private_value
|
||||
return key_type, key_public_data, key_private_data
|
||||
|
||||
|
||||
def _check_dsa_consistency(key_public_data, key_private_data):
|
||||
def _check_dsa_consistency(
|
||||
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
# Get parameters
|
||||
p = key_public_data.get("p")
|
||||
q = key_public_data.get("q")
|
||||
g = key_public_data.get("g")
|
||||
y = key_public_data.get("y")
|
||||
x = key_private_data.get("x")
|
||||
for v in (p, q, g, y, x):
|
||||
if v is None:
|
||||
return None
|
||||
p: int | None = key_public_data.get("p")
|
||||
if p is None:
|
||||
return None
|
||||
q: int | None = key_public_data.get("q")
|
||||
if q is None:
|
||||
return None
|
||||
g: int | None = key_public_data.get("g")
|
||||
if g is None:
|
||||
return None
|
||||
y: int | None = key_public_data.get("y")
|
||||
if y is None:
|
||||
return None
|
||||
x: int | None = key_private_data.get("x")
|
||||
if x is None:
|
||||
return None
|
||||
# Make sure that g is not 0, 1 or -1 in Z/pZ
|
||||
if g < 2 or g >= p - 1:
|
||||
return False
|
||||
@@ -94,13 +118,16 @@ def _check_dsa_consistency(key_public_data, key_private_data):
|
||||
|
||||
|
||||
def _is_cryptography_key_consistent(
|
||||
key, key_public_data, key_private_data, warn_func=None
|
||||
):
|
||||
key: PrivateKeyTypes,
|
||||
key_public_data: dict[str, t.Any],
|
||||
key_private_data: dict[str, t.Any],
|
||||
warn_func: t.Callable[[str], None] | None = None,
|
||||
) -> bool | None:
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
|
||||
# key._backend was removed in cryptography 42.0.0
|
||||
backend = getattr(key, "_backend", None)
|
||||
if backend is not None:
|
||||
return bool(backend._lib.RSA_check_key(key._rsa_cdata))
|
||||
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
|
||||
result = _check_dsa_consistency(key_public_data, key_private_data)
|
||||
if result is not None:
|
||||
@@ -145,9 +172,9 @@ def _is_cryptography_key_consistent(
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
|
||||
has_simple_sign_function = True
|
||||
if has_simple_sign_function:
|
||||
signature = key.sign(SIGNATURE_TEST_DATA)
|
||||
signature = key.sign(SIGNATURE_TEST_DATA) # type: ignore
|
||||
try:
|
||||
key.public_key().verify(signature, SIGNATURE_TEST_DATA)
|
||||
key.public_key().verify(signature, SIGNATURE_TEST_DATA) # type: ignore
|
||||
return True
|
||||
except cryptography.exceptions.InvalidSignature:
|
||||
return False
|
||||
@@ -158,14 +185,14 @@ def _is_cryptography_key_consistent(
|
||||
|
||||
|
||||
class PrivateKeyConsistencyError(OpenSSLObjectError):
|
||||
def __init__(self, msg, result):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
super(PrivateKeyConsistencyError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
|
||||
|
||||
class PrivateKeyParseError(OpenSSLObjectError):
|
||||
def __init__(self, msg, result):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
super(PrivateKeyParseError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
@@ -174,13 +201,12 @@ class PrivateKeyParseError(OpenSSLObjectError):
|
||||
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
content,
|
||||
passphrase=None,
|
||||
return_private_key_data=False,
|
||||
check_consistency=False,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
check_consistency: bool = False,
|
||||
):
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.passphrase = passphrase
|
||||
@@ -188,22 +214,26 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
self.check_consistency = check_consistency
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_info(self, need_private_key_data=False):
|
||||
def _get_key_info(
|
||||
self, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_key_consistent(self, key_public_data, key_private_data):
|
||||
def _is_key_consistent(
|
||||
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False):
|
||||
result = dict(
|
||||
can_parse_key=False,
|
||||
key_is_consistent=None,
|
||||
)
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {
|
||||
"can_parse_key": False,
|
||||
"key_is_consistent": None,
|
||||
}
|
||||
priv_key_detail = self.content
|
||||
try:
|
||||
self.key = load_privatekey(
|
||||
@@ -252,35 +282,39 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
|
||||
"""Validate the supplied private key, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content, **kwargs):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
|
||||
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
|
||||
module, content, **kwargs
|
||||
)
|
||||
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
return self.key.public_key().public_bytes(
|
||||
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_key_info(self, need_private_key_data=False):
|
||||
def _get_key_info(
|
||||
self, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return _get_cryptography_private_key_info(
|
||||
self.key, need_private_key_data=need_private_key_data
|
||||
)
|
||||
|
||||
def _is_key_consistent(self, key_public_data, key_private_data):
|
||||
def _is_key_consistent(
|
||||
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
return _is_cryptography_key_consistent(
|
||||
self.key, key_public_data, key_private_data, warn_func=self.module.warn
|
||||
)
|
||||
|
||||
|
||||
def get_privatekey_info(
|
||||
module,
|
||||
content,
|
||||
passphrase=None,
|
||||
return_private_key_data=False,
|
||||
prefer_one_fingerprint=False,
|
||||
):
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = PrivateKeyInfoRetrievalCryptography(
|
||||
module,
|
||||
content,
|
||||
@@ -291,12 +325,12 @@ def get_privatekey_info(
|
||||
|
||||
|
||||
def select_backend(
|
||||
module,
|
||||
content,
|
||||
passphrase=None,
|
||||
return_private_key_data=False,
|
||||
check_consistency=False,
|
||||
):
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
check_consistency: bool = False,
|
||||
) -> PrivateKeyInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -19,6 +20,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PublicKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -32,23 +45,25 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _get_cryptography_public_key_info(key):
|
||||
key_public_data = dict()
|
||||
def _get_cryptography_public_key_info(
|
||||
key: PublicKeyTypes,
|
||||
) -> tuple[str, dict[str, t.Any]]:
|
||||
key_public_data: dict[str, t.Any] = {}
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
|
||||
key_type = "RSA"
|
||||
public_numbers = key.public_numbers()
|
||||
rsa_public_numbers = key.public_numbers()
|
||||
key_public_data["size"] = key.key_size
|
||||
key_public_data["modulus"] = public_numbers.n
|
||||
key_public_data["exponent"] = public_numbers.e
|
||||
key_public_data["modulus"] = rsa_public_numbers.n
|
||||
key_public_data["exponent"] = rsa_public_numbers.e
|
||||
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
|
||||
key_type = "DSA"
|
||||
parameter_numbers = key.parameters().parameter_numbers()
|
||||
public_numbers = key.public_numbers()
|
||||
dsa_parameter_numbers = key.parameters().parameter_numbers()
|
||||
dsa_public_numbers = key.public_numbers()
|
||||
key_public_data["size"] = key.key_size
|
||||
key_public_data["p"] = parameter_numbers.p
|
||||
key_public_data["q"] = parameter_numbers.q
|
||||
key_public_data["g"] = parameter_numbers.g
|
||||
key_public_data["y"] = public_numbers.y
|
||||
key_public_data["p"] = dsa_parameter_numbers.p
|
||||
key_public_data["q"] = dsa_parameter_numbers.q
|
||||
key_public_data["g"] = dsa_parameter_numbers.g
|
||||
key_public_data["y"] = dsa_public_numbers.y
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey
|
||||
):
|
||||
@@ -67,10 +82,10 @@ def _get_cryptography_public_key_info(key):
|
||||
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey
|
||||
):
|
||||
key_type = "ECC"
|
||||
public_numbers = key.public_numbers()
|
||||
ecc_public_numbers = key.public_numbers()
|
||||
key_public_data["curve"] = key.curve.name
|
||||
key_public_data["x"] = public_numbers.x
|
||||
key_public_data["y"] = public_numbers.y
|
||||
key_public_data["x"] = ecc_public_numbers.x
|
||||
key_public_data["y"] = ecc_public_numbers.y
|
||||
key_public_data["exponent_size"] = key.curve.key_size
|
||||
else:
|
||||
key_type = f"unknown ({type(key)})"
|
||||
@@ -78,29 +93,34 @@ def _get_cryptography_public_key_info(key):
|
||||
|
||||
|
||||
class PublicKeyParseError(OpenSSLObjectError):
|
||||
def __init__(self, msg, result):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
super(PublicKeyParseError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
|
||||
|
||||
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, content=None, key=None):
|
||||
def __init__(
|
||||
self,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.key = key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_info(self):
|
||||
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False):
|
||||
result = dict()
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
if self.key is None:
|
||||
try:
|
||||
self.key = load_publickey(content=self.content)
|
||||
@@ -123,27 +143,45 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
|
||||
"""Validate the supplied public key, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content=None, key=None):
|
||||
def __init__(
|
||||
self,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> None:
|
||||
super(PublicKeyInfoRetrievalCryptography, self).__init__(
|
||||
module, content=content, key=key
|
||||
)
|
||||
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
if self.key is None:
|
||||
raise AssertionError("key must be set")
|
||||
return self.key.public_bytes(
|
||||
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_key_info(self):
|
||||
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
|
||||
if self.key is None:
|
||||
raise AssertionError("key must be set")
|
||||
return _get_cryptography_public_key_info(self.key)
|
||||
|
||||
|
||||
def get_publickey_info(module, content=None, key=None, prefer_one_fingerprint=False):
|
||||
def get_publickey_info(
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(module, content=None, key=None):
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> PublicKeyInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -8,3 +8,6 @@ from __future__ import annotations
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( # noqa: F401, pylint: disable=unused-import
|
||||
parse_openssh_version,
|
||||
)
|
||||
|
||||
|
||||
# TODO: delete!
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
|
||||
PEM_START = "-----BEGIN "
|
||||
PEM_END_START = "-----END "
|
||||
@@ -12,7 +14,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
|
||||
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
|
||||
|
||||
|
||||
def identify_pem_format(content, encoding="utf-8"):
|
||||
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
|
||||
"""Given the contents of a binary file, tests whether this could be a PEM file."""
|
||||
try:
|
||||
first_pem = extract_first_pem(content.decode(encoding))
|
||||
@@ -30,7 +32,9 @@ def identify_pem_format(content, encoding="utf-8"):
|
||||
return False
|
||||
|
||||
|
||||
def identify_private_key_format(content, encoding="utf-8"):
|
||||
def identify_private_key_format(
|
||||
content: bytes, encoding: str = "utf-8"
|
||||
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
|
||||
"""Given the contents of a private key file, identifies its format."""
|
||||
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
|
||||
# (PEM_read_bio_PrivateKey)
|
||||
@@ -59,12 +63,12 @@ def identify_private_key_format(content, encoding="utf-8"):
|
||||
return "raw"
|
||||
|
||||
|
||||
def split_pem_list(text, keep_inbetween=False):
|
||||
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
|
||||
"""
|
||||
Split concatenated PEM objects into a list of strings, where each is one PEM object.
|
||||
"""
|
||||
result = []
|
||||
current = [] if keep_inbetween else None
|
||||
current: list[str] | None = [] if keep_inbetween else None
|
||||
for line in text.splitlines(True):
|
||||
if line.strip():
|
||||
if not keep_inbetween and line.startswith("-----BEGIN "):
|
||||
@@ -77,7 +81,7 @@ def split_pem_list(text, keep_inbetween=False):
|
||||
return result
|
||||
|
||||
|
||||
def extract_first_pem(text):
|
||||
def extract_first_pem(text: str) -> str | None:
|
||||
"""
|
||||
Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none.
|
||||
"""
|
||||
@@ -87,7 +91,7 @@ def extract_first_pem(text):
|
||||
return all_pems[0]
|
||||
|
||||
|
||||
def _extract_type(line, start=PEM_START):
|
||||
def _extract_type(line: str, start: str = PEM_START) -> str | None:
|
||||
if not line.startswith(start):
|
||||
return None
|
||||
if not line.endswith(PEM_END):
|
||||
@@ -95,7 +99,7 @@ def _extract_type(line, start=PEM_START):
|
||||
return line[len(start) : -len(PEM_END)]
|
||||
|
||||
|
||||
def extract_pem(content, strict=False):
|
||||
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
|
||||
lines = content.splitlines()
|
||||
if len(lines) < 3:
|
||||
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
|
||||
@@ -117,5 +121,4 @@ def extract_pem(content, strict=False):
|
||||
raise ValueError(
|
||||
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
|
||||
)
|
||||
content = lines[1:-1]
|
||||
return header_type, "".join(content)
|
||||
return header_type, "".join(lines[1:-1])
|
||||
|
||||
@@ -8,8 +8,13 @@ import abc
|
||||
import errno
|
||||
import hashlib
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
is_potential_certificate_issuer_private_key,
|
||||
is_potential_certificate_private_key,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
|
||||
identify_pem_format,
|
||||
)
|
||||
@@ -34,6 +39,17 @@ except ImportError:
|
||||
from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
PublicKeyTypes,
|
||||
)
|
||||
|
||||
from .cryptography_support import CertificatePrivateKeyTypes
|
||||
|
||||
|
||||
# This list of preferred fingerprints is used when prefer_one=True is supplied to the
|
||||
# fingerprinting methods.
|
||||
PREFERRED_FINGERPRINTS = (
|
||||
@@ -48,18 +64,12 @@ PREFERRED_FINGERPRINTS = (
|
||||
)
|
||||
|
||||
|
||||
def get_fingerprint_of_bytes(source, prefer_one=False):
|
||||
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the given bytes."""
|
||||
|
||||
fingerprint = {}
|
||||
|
||||
try:
|
||||
algorithms = hashlib.algorithms
|
||||
except AttributeError:
|
||||
try:
|
||||
algorithms = hashlib.algorithms_guaranteed
|
||||
except AttributeError:
|
||||
return None
|
||||
algorithms: t.Iterable[str] = hashlib.algorithms_guaranteed
|
||||
|
||||
if prefer_one:
|
||||
# Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning
|
||||
@@ -97,7 +107,9 @@ def get_fingerprint_of_bytes(source, prefer_one=False):
|
||||
return fingerprint
|
||||
|
||||
|
||||
def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
|
||||
def get_fingerprint_of_privatekey(
|
||||
privatekey: PrivateKeyTypes, prefer_one: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the public key."""
|
||||
|
||||
publickey = privatekey.public_key().public_bytes(
|
||||
@@ -107,11 +119,16 @@ def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
|
||||
return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
|
||||
|
||||
|
||||
def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
|
||||
def get_fingerprint(
|
||||
path: os.PathLike | str | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
content: bytes | None = None,
|
||||
prefer_one: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the public key."""
|
||||
|
||||
privatekey = load_privatekey(
|
||||
path,
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
content=content,
|
||||
check_passphrase=False,
|
||||
@@ -121,11 +138,11 @@ def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
|
||||
|
||||
|
||||
def load_privatekey(
|
||||
path,
|
||||
passphrase=None,
|
||||
check_passphrase=True,
|
||||
content=None,
|
||||
):
|
||||
path: os.PathLike | str | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
content: bytes | None = None,
|
||||
) -> PrivateKeyTypes:
|
||||
"""Load the specified OpenSSL private key.
|
||||
|
||||
The content can also be specified via content; in that case,
|
||||
@@ -134,6 +151,8 @@ def load_privatekey(
|
||||
|
||||
try:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
with open(path, "rb") as b_priv_key_fh:
|
||||
priv_key_detail = b_priv_key_fh.read()
|
||||
else:
|
||||
@@ -154,7 +173,55 @@ def load_privatekey(
|
||||
raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key")
|
||||
|
||||
|
||||
def load_publickey(path=None, content=None):
|
||||
def load_certificate_privatekey(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
) -> CertificatePrivateKeyTypes:
|
||||
"""
|
||||
Load the specified OpenSSL private key that can be used as a private key for certificates.
|
||||
"""
|
||||
private_key = load_privatekey(
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
check_passphrase=check_passphrase,
|
||||
content=content,
|
||||
)
|
||||
if not is_potential_certificate_private_key(private_key):
|
||||
raise OpenSSLObjectError(
|
||||
f"Key of type {type(private_key)} not supported for certificates"
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def load_certificate_issuer_privatekey(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
) -> CertificateIssuerPrivateKeyTypes:
|
||||
"""
|
||||
Load the specified OpenSSL private key that can be used for issuing certificates.
|
||||
"""
|
||||
private_key = load_privatekey(
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
check_passphrase=check_passphrase,
|
||||
content=content,
|
||||
)
|
||||
if not is_potential_certificate_issuer_private_key(private_key):
|
||||
raise OpenSSLObjectError(
|
||||
f"Key of type {type(private_key)} not supported for issuing certificates"
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def load_publickey(
|
||||
path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
) -> PublicKeyTypes:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
@@ -170,11 +237,17 @@ def load_publickey(path=None, content=None):
|
||||
raise OpenSSLObjectError(f"Error while deserializing key: {e}")
|
||||
|
||||
|
||||
def load_certificate(path, content=None, der_support_enabled=False):
|
||||
def load_certificate(
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
der_support_enabled: bool = False,
|
||||
) -> x509.Certificate:
|
||||
"""Load the specified certificate."""
|
||||
|
||||
try:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
with open(path, "rb") as cert_fh:
|
||||
cert_content = cert_fh.read()
|
||||
else:
|
||||
@@ -193,10 +266,14 @@ def load_certificate(path, content=None, der_support_enabled=False):
|
||||
raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}")
|
||||
|
||||
|
||||
def load_certificate_request(path, content=None):
|
||||
def load_certificate_request(
|
||||
path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
) -> x509.CertificateSigningRequest:
|
||||
"""Load the specified certificate signing request."""
|
||||
try:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
with open(path, "rb") as csr_fh:
|
||||
csr_content = csr_fh.read()
|
||||
else:
|
||||
@@ -209,45 +286,44 @@ def load_certificate_request(path, content=None):
|
||||
raise OpenSSLObjectError(exc)
|
||||
|
||||
|
||||
def parse_name_field(input_dict, name_field_name=None):
|
||||
def parse_name_field(
|
||||
input_dict: dict[str, list[str | bytes] | str | bytes],
|
||||
name_field_name: str | None = None,
|
||||
) -> list[tuple[str, str | bytes]]:
|
||||
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
|
||||
error_str = "{key}" if name_field_name is None else "{key} in {name}"
|
||||
|
||||
def error_str(key: str) -> str:
|
||||
if name_field_name is None:
|
||||
return f"{key}"
|
||||
return f"{key} in {name_field_name}"
|
||||
|
||||
result = []
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, list):
|
||||
for entry in value:
|
||||
if not isinstance(entry, (str, bytes)):
|
||||
raise TypeError(
|
||||
f"Values {error_str} must be strings".format(
|
||||
key=key, name=name_field_name
|
||||
)
|
||||
)
|
||||
raise TypeError(f"Values {error_str(key)} must be strings")
|
||||
if not entry:
|
||||
raise ValueError(
|
||||
f"Values for {error_str} must not be empty strings".format(
|
||||
key=key, name=name_field_name
|
||||
)
|
||||
f"Values for {error_str(key)} must not be empty strings"
|
||||
)
|
||||
result.append((key, entry))
|
||||
elif isinstance(value, (str, bytes)):
|
||||
if not value:
|
||||
raise ValueError(
|
||||
f"Value for {error_str} must not be an empty string".format(
|
||||
key=key, name=name_field_name
|
||||
)
|
||||
f"Value for {error_str(key)} must not be an empty string"
|
||||
)
|
||||
result.append((key, value))
|
||||
else:
|
||||
raise TypeError(
|
||||
(
|
||||
f"Value for {error_str} must be either a string or a list of strings"
|
||||
).format(key=key, name=name_field_name)
|
||||
f"Value for {error_str(key)} must be either a string or a list of strings"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def parse_ordered_name_field(input_list, name_field_name):
|
||||
def parse_ordered_name_field(
|
||||
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
|
||||
) -> list[tuple[str, str | bytes]]:
|
||||
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
|
||||
|
||||
result = []
|
||||
@@ -265,24 +341,39 @@ def parse_ordered_name_field(input_list, name_field_name):
|
||||
return result
|
||||
|
||||
|
||||
def select_message_digest(digest_string):
|
||||
digest = None
|
||||
@t.overload
|
||||
def select_message_digest(
|
||||
digest_string: t.Literal["sha256", "sha384", "sha512", "sha1", "md5"],
|
||||
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def select_message_digest(
|
||||
digest_string: str,
|
||||
) -> (
|
||||
hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None
|
||||
): ...
|
||||
|
||||
|
||||
def select_message_digest(
|
||||
digest_string: str,
|
||||
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None:
|
||||
if digest_string == "sha256":
|
||||
digest = hashes.SHA256()
|
||||
elif digest_string == "sha384":
|
||||
digest = hashes.SHA384()
|
||||
elif digest_string == "sha512":
|
||||
digest = hashes.SHA512()
|
||||
elif digest_string == "sha1":
|
||||
digest = hashes.SHA1()
|
||||
elif digest_string == "md5":
|
||||
digest = hashes.MD5()
|
||||
return digest
|
||||
return hashes.SHA256()
|
||||
if digest_string == "sha384":
|
||||
return hashes.SHA384()
|
||||
if digest_string == "sha512":
|
||||
return hashes.SHA512()
|
||||
if digest_string == "sha1":
|
||||
return hashes.SHA1()
|
||||
if digest_string == "md5":
|
||||
return hashes.MD5()
|
||||
return None
|
||||
|
||||
|
||||
class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, path, state, force, check_mode):
|
||||
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
|
||||
self.path = path
|
||||
self.state = state
|
||||
self.force = force
|
||||
@@ -290,13 +381,13 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
self.changed = False
|
||||
self.check_mode = check_mode
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
def _check_state():
|
||||
def _check_state() -> bool:
|
||||
return os.path.exists(self.path)
|
||||
|
||||
def _check_perms(module):
|
||||
def _check_perms(module: AnsibleModule) -> bool:
|
||||
file_args = module.load_file_common_arguments(module.params)
|
||||
if module.check_file_absent_if_check_mode(file_args["path"]):
|
||||
return False
|
||||
@@ -308,18 +399,14 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
return _check_state() and _check_perms(module)
|
||||
|
||||
@abc.abstractmethod
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the resource."""
|
||||
|
||||
pass
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
"""Remove the resource from the filesystem."""
|
||||
if self.check_mode:
|
||||
if os.path.exists(self.path):
|
||||
|
||||
@@ -11,6 +11,7 @@ Must be kept in sync with plugins/doc_fragments/cryptography_dep.py.
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
@@ -18,19 +19,29 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
)
|
||||
|
||||
|
||||
_CRYPTOGRAPHY_IMP_ERR = None
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from ..plugin_utils.action_module import AnsibleActionModule
|
||||
from ..plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
_CRYPTOGRAPHY_IMP_ERR: str | None = None
|
||||
_CRYPTOGRAPHY_FILE: str | None = None
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography import x509 # noqa: F401, pylint: disable=unused-import
|
||||
|
||||
_CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
|
||||
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
|
||||
_CRYPTOGRAPHY_FILE = cryptography.__file__
|
||||
except ImportError:
|
||||
_CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
|
||||
_CRYPTOGRAPHY_FOUND = False
|
||||
_CRYPTOGRAPHY_FILE = None
|
||||
CRYPTOGRAPHY_FOUND = False
|
||||
CRYPTOGRAPHY_VERSION = LooseVersion("0.0")
|
||||
else:
|
||||
_CRYPTOGRAPHY_FOUND = True
|
||||
CRYPTOGRAPHY_FOUND = True
|
||||
|
||||
|
||||
# Corresponds to the community.crypto.cryptography_dep.minimum doc fragment
|
||||
@@ -38,25 +49,27 @@ COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION = "3.3"
|
||||
|
||||
|
||||
def assert_required_cryptography_version(
|
||||
module,
|
||||
module: GeneralAnsibleModule,
|
||||
*,
|
||||
minimum_cryptography_version: str = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
) -> None:
|
||||
if not _CRYPTOGRAPHY_FOUND:
|
||||
if not CRYPTOGRAPHY_FOUND:
|
||||
module.fail_json(
|
||||
msg=missing_required_lib(f"cryptography >= {minimum_cryptography_version}"),
|
||||
exception=_CRYPTOGRAPHY_IMP_ERR,
|
||||
)
|
||||
if _CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
|
||||
if CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
|
||||
module.fail_json(
|
||||
msg=(
|
||||
f"Cannot detect the required Python library cryptography (>= {minimum_cryptography_version})."
|
||||
f" Only found a too old version ({_CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
|
||||
f" Only found a too old version ({CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION",
|
||||
"CRYPTOGRAPHY_FOUND",
|
||||
"CRYPTOGRAPHY_VERSION",
|
||||
"assert_required_cryptography_version",
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import typing as t
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import urlencode
|
||||
|
||||
@@ -34,7 +35,7 @@ else:
|
||||
valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$")
|
||||
|
||||
|
||||
def ecs_client_argument_spec():
|
||||
def ecs_client_argument_spec() -> dict[str, t.Any]:
|
||||
return dict(
|
||||
entrust_api_user=dict(type="str", required=True),
|
||||
entrust_api_key=dict(type="str", required=True, no_log=True),
|
||||
@@ -50,19 +51,17 @@ def ecs_client_argument_spec():
|
||||
class SessionConfigurationException(Exception):
|
||||
"""Raised if we cannot configure a session with the API"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RestOperationException(Exception):
|
||||
"""Encapsulate a REST API error"""
|
||||
|
||||
def __init__(self, error):
|
||||
def __init__(self, error: dict[str, t.Any]) -> None:
|
||||
self.status = to_native(error.get("status", None))
|
||||
self.errors = [to_native(err.get("message")) for err in error.get("errors", {})]
|
||||
self.message = " ".join(self.errors)
|
||||
|
||||
|
||||
def generate_docstring(operation_spec):
|
||||
def generate_docstring(operation_spec: dict[str, t.Any]) -> str:
|
||||
"""Generate a docstring for an operation defined in operation_spec (swagger)"""
|
||||
# Description of the operation
|
||||
docs = operation_spec.get("description", "No Description")
|
||||
|
||||
@@ -14,7 +14,9 @@ class GPGError(Exception):
|
||||
|
||||
class GPGRunner(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def run_command(self, command, check_rc=True, data=None):
|
||||
def run_command(
|
||||
self, command: list[str], check_rc: bool = True, data: bytes | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
"""
|
||||
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
|
||||
|
||||
@@ -29,7 +31,7 @@ class GPGRunner(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
|
||||
def get_fingerprint_from_stdout(stdout):
|
||||
def get_fingerprint_from_stdout(stdout: str) -> str:
|
||||
lines = stdout.splitlines(False)
|
||||
for line in lines:
|
||||
if line.startswith("fpr:"):
|
||||
@@ -42,7 +44,7 @@ def get_fingerprint_from_stdout(stdout):
|
||||
raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"')
|
||||
|
||||
|
||||
def get_fingerprint_from_file(gpg_runner, path):
|
||||
def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise GPGError(f"{path} does not exist")
|
||||
stdout = gpg_runner.run_command(
|
||||
@@ -59,7 +61,7 @@ def get_fingerprint_from_file(gpg_runner, path):
|
||||
return get_fingerprint_from_stdout(stdout)
|
||||
|
||||
|
||||
def get_fingerprint_from_bytes(gpg_runner, content):
|
||||
def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
|
||||
stdout = gpg_runner.run_command(
|
||||
[
|
||||
"--no-keyring",
|
||||
|
||||
@@ -7,9 +7,14 @@ from __future__ import annotations
|
||||
import errno
|
||||
import os
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
|
||||
def load_file(path, module=None):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
|
||||
"""
|
||||
Load the file as a bytes string.
|
||||
"""
|
||||
@@ -22,7 +27,11 @@ def load_file(path, module=None):
|
||||
module.fail_json(f"Error while loading {path} - {exc}")
|
||||
|
||||
|
||||
def load_file_if_exists(path, module=None, ignore_errors=False):
|
||||
def load_file_if_exists(
|
||||
path: str | os.PathLike,
|
||||
module: AnsibleModule | None = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Load the file as a bytes string. If the file does not exist, ``None`` is returned.
|
||||
|
||||
@@ -49,7 +58,12 @@ def load_file_if_exists(path, module=None, ignore_errors=False):
|
||||
module.fail_json(f"Error while loading {path} - {exc}")
|
||||
|
||||
|
||||
def write_file(module, content, default_mode=None, path=None):
|
||||
def write_file(
|
||||
module: AnsibleModule,
|
||||
content: bytes,
|
||||
default_mode: str | int | None = None,
|
||||
path: str | os.PathLike | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Writes content into destination file as securely as possible.
|
||||
Uses file arguments from module.
|
||||
|
||||
@@ -8,14 +8,31 @@ import abc
|
||||
import os
|
||||
import stat
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
|
||||
parse_openssh_version,
|
||||
)
|
||||
|
||||
|
||||
def restore_on_failure(f):
|
||||
def backup_and_restore(module, path, *args, **kwargs):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ..certificate import OpensshCertificateTimeParameters
|
||||
|
||||
Param = t.ParamSpec("Param")
|
||||
|
||||
|
||||
def restore_on_failure(
|
||||
f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None],
|
||||
) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]:
|
||||
def backup_and_restore(
|
||||
module: AnsibleModule, path: str | os.PathLike, *args, **kwargs
|
||||
) -> None:
|
||||
backup_file = module.backup_local(path) if os.path.exists(path) else None
|
||||
|
||||
try:
|
||||
@@ -31,12 +48,31 @@ def restore_on_failure(f):
|
||||
|
||||
|
||||
@restore_on_failure
|
||||
def safe_atomic_move(module, path, destination):
|
||||
def safe_atomic_move(
|
||||
module: AnsibleModule, path: str | os.PathLike, destination: str | os.PathLike
|
||||
) -> None:
|
||||
module.atomic_move(os.path.abspath(path), os.path.abspath(destination))
|
||||
|
||||
|
||||
def _restore_all_on_failure(f):
|
||||
def backup_and_restore(self, sources_and_destinations, *args, **kwargs):
|
||||
def _restore_all_on_failure(
|
||||
f: t.Callable[
|
||||
t.Concatenate[
|
||||
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
|
||||
],
|
||||
None,
|
||||
],
|
||||
) -> t.Callable[
|
||||
t.Concatenate[
|
||||
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
|
||||
],
|
||||
None,
|
||||
]:
|
||||
def backup_and_restore(
|
||||
self: OpensshModule,
|
||||
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
backups = [
|
||||
(d, self.module.backup_local(d))
|
||||
for s, d in sources_and_destinations
|
||||
@@ -59,13 +95,13 @@ def _restore_all_on_failure(f):
|
||||
|
||||
|
||||
class OpensshModule(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.changed = False
|
||||
self.check_mode = self.module.check_mode
|
||||
self.changed: bool = False
|
||||
self.check_mode: bool = self.module.check_mode
|
||||
|
||||
def execute(self):
|
||||
def execute(self) -> t.NoReturn:
|
||||
try:
|
||||
self._execute()
|
||||
except Exception as e:
|
||||
@@ -77,11 +113,11 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
self.module.exit_json(**self.result)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
def result(self) -> dict[str, t.Any]:
|
||||
result = self._result
|
||||
|
||||
result["changed"] = self.changed
|
||||
@@ -93,31 +129,31 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def skip_if_check_mode(f):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
|
||||
def wrapper(self, *args, **kwargs) -> None:
|
||||
if not self.check_mode:
|
||||
f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return wrapper # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def trigger_change(f):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
|
||||
def wrapper(self, *args, **kwargs) -> None:
|
||||
f(self, *args, **kwargs)
|
||||
self.changed = True
|
||||
|
||||
return wrapper
|
||||
return wrapper # type: ignore
|
||||
|
||||
def _check_if_base_dir(self, path):
|
||||
def _check_if_base_dir(self, path: str | os.PathLike) -> None:
|
||||
base_dir = os.path.dirname(path) or "."
|
||||
if not os.path.isdir(base_dir):
|
||||
self.module.fail_json(
|
||||
@@ -125,16 +161,19 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
msg=f"The directory {base_dir} does not exist or the file is not a directory",
|
||||
)
|
||||
|
||||
def _get_ssh_version(self):
|
||||
def _get_ssh_version(self) -> str | None:
|
||||
ssh_bin = self.module.get_bin_path("ssh")
|
||||
if not ssh_bin:
|
||||
return ""
|
||||
return None
|
||||
return parse_openssh_version(
|
||||
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
|
||||
)
|
||||
|
||||
@_restore_all_on_failure
|
||||
def _safe_secure_move(self, sources_and_destinations):
|
||||
def _safe_secure_move(
|
||||
self,
|
||||
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
|
||||
) -> None:
|
||||
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
|
||||
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
|
||||
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
|
||||
@@ -148,7 +187,7 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
else:
|
||||
self.module.preserved_copy(source, destination)
|
||||
|
||||
def _update_permissions(self, path):
|
||||
def _update_permissions(self, path: str | os.PathLike) -> None:
|
||||
file_args = self.module.load_file_common_arguments(self.module.params)
|
||||
file_args["path"] = path
|
||||
|
||||
@@ -161,25 +200,25 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class KeygenCommand:
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self._bin_path = module.get_bin_path("ssh-keygen", True)
|
||||
self._run_command = module.run_command
|
||||
|
||||
def generate_certificate(
|
||||
self,
|
||||
certificate_path,
|
||||
identifier,
|
||||
options,
|
||||
pkcs11_provider,
|
||||
principals,
|
||||
serial_number,
|
||||
signature_algorithm,
|
||||
signing_key_path,
|
||||
type,
|
||||
time_parameters,
|
||||
use_agent,
|
||||
certificate_path: str,
|
||||
identifier: str,
|
||||
options: list[str] | None,
|
||||
pkcs11_provider: str | None,
|
||||
principals: list[str] | None,
|
||||
serial_number: int | None,
|
||||
signature_algorithm: str | None,
|
||||
signing_key_path: str,
|
||||
type: t.Literal["host", "user"] | None,
|
||||
time_parameters: OpensshCertificateTimeParameters,
|
||||
use_agent: bool,
|
||||
**kwargs,
|
||||
):
|
||||
) -> tuple[int, str, str]:
|
||||
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
|
||||
|
||||
if options:
|
||||
@@ -203,7 +242,9 @@ class KeygenCommand:
|
||||
|
||||
return self._run_command(args, **kwargs)
|
||||
|
||||
def generate_keypair(self, private_key_path, size, type, comment, **kwargs):
|
||||
def generate_keypair(
|
||||
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
args = [
|
||||
self._bin_path,
|
||||
"-q",
|
||||
@@ -224,32 +265,40 @@ class KeygenCommand:
|
||||
|
||||
return self._run_command(args, data=data, **kwargs)
|
||||
|
||||
def get_certificate_info(self, certificate_path, **kwargs):
|
||||
def get_certificate_info(
|
||||
self, certificate_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-L", "-f", certificate_path], **kwargs
|
||||
)
|
||||
|
||||
def get_matching_public_key(self, private_key_path, **kwargs):
|
||||
def get_matching_public_key(
|
||||
self, private_key_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def get_private_key(self, private_key_path, **kwargs):
|
||||
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-l", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def update_comment(
|
||||
self, private_key_path, comment, force_new_format=True, **kwargs
|
||||
):
|
||||
self,
|
||||
private_key_path: str,
|
||||
comment: str,
|
||||
force_new_format: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[int, str, str]:
|
||||
if os.path.exists(private_key_path) and not os.access(
|
||||
private_key_path, os.W_OK
|
||||
):
|
||||
try:
|
||||
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
|
||||
except (IOError, OSError) as e:
|
||||
raise e(
|
||||
f"The private key at {private_key_path} is not writeable preventing a comment update"
|
||||
raise ValueError(
|
||||
f"The private key at {private_key_path} is not writeable preventing a comment update ({e})"
|
||||
)
|
||||
|
||||
command = [self._bin_path, "-q"]
|
||||
@@ -259,31 +308,36 @@ class KeygenCommand:
|
||||
return self._run_command(command, **kwargs)
|
||||
|
||||
|
||||
_PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
|
||||
|
||||
|
||||
class PrivateKey:
|
||||
def __init__(self, size, key_type, fingerprint, format=""):
|
||||
def __init__(
|
||||
self, size: int, key_type: str, fingerprint: str, format: str = ""
|
||||
) -> None:
|
||||
self._size = size
|
||||
self._type = key_type
|
||||
self._fingerprint = fingerprint
|
||||
self._format = format
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
return self._size
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> str:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def fingerprint(self):
|
||||
def fingerprint(self) -> str:
|
||||
return self._fingerprint
|
||||
|
||||
@property
|
||||
def format(self):
|
||||
def format(self) -> str:
|
||||
return self._format
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
def from_string(cls: t.Type[_PrivateKey], string: str) -> _PrivateKey:
|
||||
properties = string.split()
|
||||
|
||||
return cls(
|
||||
@@ -292,7 +346,7 @@ class PrivateKey:
|
||||
fingerprint=properties[1],
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"size": self._size,
|
||||
"type": self._type,
|
||||
@@ -301,13 +355,16 @@ class PrivateKey:
|
||||
}
|
||||
|
||||
|
||||
_PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
|
||||
|
||||
|
||||
class PublicKey:
|
||||
def __init__(self, type_string, data, comment):
|
||||
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
|
||||
self._type_string = type_string
|
||||
self._data = data
|
||||
self._comment = comment
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
|
||||
@@ -323,30 +380,30 @@ class PublicKey:
|
||||
]
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self._type_string} {self._data}"
|
||||
|
||||
@property
|
||||
def comment(self):
|
||||
def comment(self) -> str | None:
|
||||
return self._comment
|
||||
|
||||
@comment.setter
|
||||
def comment(self, value):
|
||||
def comment(self, value: str | None) -> None:
|
||||
self._comment = value
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> str:
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def type_string(self):
|
||||
def type_string(self) -> str:
|
||||
return self._type_string
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
def from_string(cls: t.Type[_PublicKey], string: str) -> _PublicKey:
|
||||
properties = string.strip("\n").split(" ", 2)
|
||||
|
||||
return cls(
|
||||
@@ -356,7 +413,7 @@ class PublicKey:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
def load(cls: t.Type[_PublicKey], path: str | os.PathLike) -> _PublicKey | None:
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
properties = f.read().strip(" \n").split(" ", 2)
|
||||
@@ -372,14 +429,16 @@ class PublicKey:
|
||||
comment="" if len(properties) <= 2 else properties[2],
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"comment": self._comment,
|
||||
"public_key": self._data,
|
||||
}
|
||||
|
||||
|
||||
def parse_private_key_format(path):
|
||||
def parse_private_key_format(
|
||||
path: str | os.PathLike,
|
||||
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
|
||||
with open(path, "r") as file:
|
||||
header = file.readline().strip()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
@@ -39,31 +40,43 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackend, self).__init__(module)
|
||||
|
||||
self.comment = self.module.params["comment"]
|
||||
self.private_key_path = self.module.params["path"]
|
||||
self.comment: str | None = self.module.params["comment"]
|
||||
self.private_key_path: str = self.module.params["path"]
|
||||
self.public_key_path = self.private_key_path + ".pub"
|
||||
self.regenerate = (
|
||||
self.regenerate: t.Literal[
|
||||
"never", "fail", "partial_idempotence", "full_idempotence", "always"
|
||||
] = (
|
||||
self.module.params["regenerate"]
|
||||
if not self.module.params["force"]
|
||||
else "always"
|
||||
)
|
||||
self.state = self.module.params["state"]
|
||||
self.type = self.module.params["type"]
|
||||
self.state: t.Literal["present", "absent"] = self.module.params["state"]
|
||||
self.type: t.Literal["rsa", "dsa", "rsa1", "ecdsa", "ed25519"] = (
|
||||
self.module.params["type"]
|
||||
)
|
||||
|
||||
self.size = self._get_size(self.module.params["size"])
|
||||
self.size: int = self._get_size(self.module.params["size"])
|
||||
self._validate_path()
|
||||
|
||||
self.original_private_key = None
|
||||
self.original_public_key = None
|
||||
self.private_key = None
|
||||
self.public_key = None
|
||||
self.original_private_key: PrivateKey | None = None
|
||||
self.original_public_key: PublicKey | None = None
|
||||
self.private_key: PrivateKey | None = None
|
||||
self.public_key: PublicKey | None = None
|
||||
|
||||
def _get_size(self, size):
|
||||
def _get_size(self, size: int | None) -> int:
|
||||
if self.type in ("rsa", "rsa1"):
|
||||
result = 4096 if size is None else size
|
||||
if result < 1024:
|
||||
@@ -96,7 +109,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
return result
|
||||
|
||||
def _validate_path(self):
|
||||
def _validate_path(self) -> None:
|
||||
self._check_if_base_dir(self.private_key_path)
|
||||
|
||||
if os.path.isdir(self.private_key_path):
|
||||
@@ -104,7 +117,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
msg=f"{self.private_key_path} is a directory. Please specify a path to a file."
|
||||
)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
self.original_private_key = self._load_private_key()
|
||||
self.original_public_key = self._load_public_key()
|
||||
|
||||
@@ -125,7 +138,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
if self._should_remove():
|
||||
self._remove()
|
||||
|
||||
def _load_private_key(self):
|
||||
def _load_private_key(self) -> PrivateKey | None:
|
||||
result = None
|
||||
if self._private_key_exists():
|
||||
try:
|
||||
@@ -135,14 +148,14 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
return result
|
||||
|
||||
def _private_key_exists(self):
|
||||
def _private_key_exists(self) -> bool:
|
||||
return os.path.exists(self.private_key_path)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_private_key(self):
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
pass
|
||||
|
||||
def _load_public_key(self):
|
||||
def _load_public_key(self) -> PublicKey | None:
|
||||
result = None
|
||||
if self._public_key_exists():
|
||||
try:
|
||||
@@ -151,10 +164,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
pass
|
||||
return result
|
||||
|
||||
def _public_key_exists(self):
|
||||
def _public_key_exists(self) -> bool:
|
||||
return os.path.exists(self.public_key_path)
|
||||
|
||||
def _validate_key_load(self):
|
||||
def _validate_key_load(self) -> None:
|
||||
if (
|
||||
self._private_key_exists()
|
||||
and self.regenerate in ("never", "fail", "partial_idempotence")
|
||||
@@ -167,10 +180,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _private_key_readable(self):
|
||||
def _private_key_readable(self) -> bool:
|
||||
pass
|
||||
|
||||
def _should_generate(self):
|
||||
def _should_generate(self) -> bool:
|
||||
if self.original_private_key is None:
|
||||
return True
|
||||
elif self.regenerate == "never":
|
||||
@@ -188,7 +201,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _private_key_valid(self):
|
||||
def _private_key_valid(self) -> bool:
|
||||
if self.original_private_key is None:
|
||||
return False
|
||||
|
||||
@@ -196,17 +209,17 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
[
|
||||
self.size == self.original_private_key.size,
|
||||
self.type == self.original_private_key.type,
|
||||
self._private_key_valid_backend(),
|
||||
self._private_key_valid_backend(self.original_private_key),
|
||||
]
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _private_key_valid_backend(self):
|
||||
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
|
||||
pass
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _generate(self):
|
||||
def _generate(self) -> None:
|
||||
temp_private_key, temp_public_key = self._generate_temp_keypair()
|
||||
|
||||
try:
|
||||
@@ -219,7 +232,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
except OSError as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _generate_temp_keypair(self):
|
||||
def _generate_temp_keypair(self) -> tuple[str, str]:
|
||||
temp_private_key = os.path.join(
|
||||
self.module.tmpdir, os.path.basename(self.private_key_path)
|
||||
)
|
||||
@@ -236,25 +249,26 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
return temp_private_key, temp_public_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _generate_keypair(self, private_key_path):
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
pass
|
||||
|
||||
def _public_key_valid(self):
|
||||
def _public_key_valid(self) -> bool:
|
||||
if self.original_public_key is None:
|
||||
return False
|
||||
|
||||
valid_public_key = self._get_public_key()
|
||||
valid_public_key.comment = self.comment
|
||||
if valid_public_key:
|
||||
valid_public_key.comment = self.comment
|
||||
|
||||
return self.original_public_key == valid_public_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self):
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
pass
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _restore_public_key(self):
|
||||
def _restore_public_key(self) -> None:
|
||||
try:
|
||||
temp_public_key = self._create_temp_public_key(
|
||||
str(self._get_public_key()) + "\n"
|
||||
@@ -269,7 +283,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
if self.comment:
|
||||
self._update_comment()
|
||||
|
||||
def _create_temp_public_key(self, content):
|
||||
def _create_temp_public_key(self, content: str | bytes) -> str:
|
||||
temp_public_key = os.path.join(
|
||||
self.module.tmpdir, os.path.basename(self.public_key_path)
|
||||
)
|
||||
@@ -290,15 +304,15 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
return temp_public_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _update_comment(self):
|
||||
def _update_comment(self) -> None:
|
||||
pass
|
||||
|
||||
def _should_remove(self):
|
||||
def _should_remove(self) -> bool:
|
||||
return self._private_key_exists() or self._public_key_exists()
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _remove(self):
|
||||
def _remove(self) -> None:
|
||||
try:
|
||||
if self._private_key_exists():
|
||||
os.remove(self.private_key_path)
|
||||
@@ -308,7 +322,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
@property
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
private_key = self.private_key or self.original_private_key
|
||||
public_key = self.public_key or self.original_public_key
|
||||
|
||||
@@ -322,7 +336,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
}
|
||||
|
||||
@property
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
before = (
|
||||
self.original_private_key.to_dict() if self.original_private_key else {}
|
||||
)
|
||||
@@ -340,7 +354,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class KeypairBackendOpensshBin(KeypairBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendOpensshBin, self).__init__(module)
|
||||
|
||||
if self.module.params["private_key_format"] != "auto":
|
||||
@@ -350,12 +364,12 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
self.ssh_keygen = KeygenCommand(self.module)
|
||||
|
||||
def _generate_keypair(self, private_key_path):
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
self.ssh_keygen.generate_keypair(
|
||||
private_key_path, self.size, self.type, self.comment, check_rc=True
|
||||
)
|
||||
|
||||
def _get_private_key(self):
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
rc, private_key_content, err = self.ssh_keygen.get_private_key(
|
||||
self.private_key_path, check_rc=False
|
||||
)
|
||||
@@ -363,13 +377,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
raise ValueError(err)
|
||||
return PrivateKey.from_string(private_key_content)
|
||||
|
||||
def _get_public_key(self):
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
public_key_content = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=True
|
||||
)[1]
|
||||
return PublicKey.from_string(public_key_content)
|
||||
|
||||
def _private_key_readable(self):
|
||||
def _private_key_readable(self) -> bool:
|
||||
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=False
|
||||
)
|
||||
@@ -383,7 +397,7 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def _update_comment(self):
|
||||
def _update_comment(self) -> None:
|
||||
try:
|
||||
ssh_version = self._get_ssh_version() or "7.8"
|
||||
force_new_format = (
|
||||
@@ -391,19 +405,19 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
)
|
||||
self.ssh_keygen.update_comment(
|
||||
self.private_key_path,
|
||||
self.comment,
|
||||
self.comment or "",
|
||||
force_new_format=force_new_format,
|
||||
check_rc=True,
|
||||
)
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _private_key_valid_backend(self):
|
||||
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class KeypairBackendCryptography(KeypairBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendCryptography, self).__init__(module)
|
||||
|
||||
if self.type == "rsa1":
|
||||
@@ -416,12 +430,15 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
if module.params["passphrase"]
|
||||
else None
|
||||
)
|
||||
self.private_key_format = self._get_key_format(
|
||||
module.params["private_key_format"]
|
||||
)
|
||||
key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] = module.params[
|
||||
"private_key_format"
|
||||
]
|
||||
self.private_key_format = self._get_key_format(key_format)
|
||||
|
||||
def _get_key_format(self, key_format):
|
||||
result = "SSH"
|
||||
def _get_key_format(
|
||||
self, key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"]
|
||||
) -> t.Literal["SSH", "PKCS1", "PKCS8"]:
|
||||
result: t.Literal["SSH", "PKCS1", "PKCS8"] = "SSH"
|
||||
|
||||
if key_format == "auto":
|
||||
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
|
||||
@@ -435,11 +452,12 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
# but still defaulted to PKCS1 format with the exception of ed25519 keys
|
||||
result = "PKCS1"
|
||||
else:
|
||||
result = key_format.upper()
|
||||
result = key_format.upper() # type: ignore
|
||||
|
||||
return result
|
||||
|
||||
def _generate_keypair(self, private_key_path):
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
assert self.type != "rsa1"
|
||||
keypair = OpensshKeypair.generate(
|
||||
keytype=self.type,
|
||||
size=self.size,
|
||||
@@ -455,7 +473,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
public_key_path = private_key_path + ".pub"
|
||||
secure_write(public_key_path, 0o644, keypair.public_key)
|
||||
|
||||
def _get_private_key(self):
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
|
||||
)
|
||||
@@ -467,7 +485,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
format=parse_private_key_format(self.private_key_path),
|
||||
)
|
||||
|
||||
def _get_public_key(self):
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
try:
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path,
|
||||
@@ -480,7 +498,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
return PublicKey.from_string(to_text(keypair.public_key))
|
||||
|
||||
def _private_key_readable(self):
|
||||
def _private_key_readable(self) -> bool:
|
||||
try:
|
||||
OpensshKeypair.load(
|
||||
path=self.private_key_path,
|
||||
@@ -504,7 +522,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
return True
|
||||
|
||||
def _update_comment(self):
|
||||
def _update_comment(self) -> None:
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
|
||||
)
|
||||
@@ -519,16 +537,18 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _private_key_valid_backend(self):
|
||||
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
|
||||
# avoids breaking behavior and prevents
|
||||
# automatic conversions with OpenSSH upgrades
|
||||
if self.module.params["private_key_format"] == "auto":
|
||||
return True
|
||||
|
||||
return self.private_key_format == self.original_private_key.format
|
||||
return self.private_key_format == original_private_key.format
|
||||
|
||||
|
||||
def select_backend(module, backend):
|
||||
def select_backend(
|
||||
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
|
||||
) -> KeypairBackend:
|
||||
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
|
||||
CRYPTOGRAPHY_VERSION
|
||||
) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION)
|
||||
|
||||
@@ -8,6 +8,7 @@ import abc
|
||||
import binascii
|
||||
import datetime as _datetime
|
||||
import os
|
||||
import typing as t
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
@@ -26,6 +27,16 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .cryptography import KeyType
|
||||
|
||||
DateFormat = t.Literal["human_readable", "openssh", "timestamp"]
|
||||
DateFormatStr = t.Literal["human_readable", "openssh"]
|
||||
DateFormatInt = t.Literal["timestamp"]
|
||||
else:
|
||||
KeyType = None
|
||||
|
||||
|
||||
# Protocol References
|
||||
# -------------------
|
||||
# https://datatracker.ietf.org/doc/html/rfc4251
|
||||
@@ -44,7 +55,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
_USER_TYPE = 1
|
||||
_HOST_TYPE = 2
|
||||
|
||||
_SSH_TYPE_STRINGS = {
|
||||
_SSH_TYPE_STRINGS: dict[KeyType | str, bytes] = {
|
||||
"rsa": b"ssh-rsa",
|
||||
"dsa": b"ssh-dss",
|
||||
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
|
||||
@@ -94,16 +105,18 @@ _EXTENSIONS = (
|
||||
|
||||
|
||||
class OpensshCertificateTimeParameters:
|
||||
def __init__(self, valid_from, valid_to):
|
||||
def __init__(
|
||||
self, valid_from: str | bytes | int, valid_to: str | bytes | int
|
||||
) -> None:
|
||||
self._valid_from = self.to_datetime(valid_from)
|
||||
self._valid_to = self.to_datetime(valid_to)
|
||||
|
||||
if self._valid_from > self._valid_to:
|
||||
raise ValueError(
|
||||
f"Valid from: {valid_from} must not be greater than Valid to: {valid_to}"
|
||||
f"Valid from: {valid_from!r} must not be greater than Valid to: {valid_to!r}"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
else:
|
||||
@@ -112,55 +125,83 @@ class OpensshCertificateTimeParameters:
|
||||
and self._valid_to == other._valid_to
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
@property
|
||||
def validity_string(self):
|
||||
def validity_string(self) -> str:
|
||||
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
|
||||
return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}"
|
||||
return ""
|
||||
|
||||
def valid_from(self, date_format):
|
||||
@t.overload
|
||||
def valid_from(self, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
def valid_from(self, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
def valid_from(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_from(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_from, date_format)
|
||||
|
||||
def valid_to(self, date_format):
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_to(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_to, date_format)
|
||||
|
||||
def within_range(self, valid_at):
|
||||
def within_range(self, valid_at: str | bytes | int | None) -> bool:
|
||||
if valid_at is not None:
|
||||
valid_at_datetime = self.to_datetime(valid_at)
|
||||
return self._valid_from <= valid_at_datetime <= self._valid_to
|
||||
return True
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt, date_format):
|
||||
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
|
||||
if date_format in ("human_readable", "openssh"):
|
||||
if dt == _ALWAYS:
|
||||
result = "always"
|
||||
elif dt == _FOREVER:
|
||||
result = "forever"
|
||||
return "always"
|
||||
if dt == _FOREVER:
|
||||
return "forever"
|
||||
else:
|
||||
result = (
|
||||
return (
|
||||
dt.isoformat().replace("+00:00", "")
|
||||
if date_format == "human_readable"
|
||||
else dt.strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
elif date_format == "timestamp":
|
||||
if date_format == "timestamp":
|
||||
td = dt - _ALWAYS
|
||||
result = int(
|
||||
return int(
|
||||
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{date_format} is not a valid format")
|
||||
return result
|
||||
raise ValueError(f"{date_format} is not a valid format")
|
||||
|
||||
@staticmethod
|
||||
def to_datetime(time_string_or_timestamp):
|
||||
def to_datetime(time_string_or_timestamp: str | bytes | int) -> datetime:
|
||||
try:
|
||||
if isinstance(time_string_or_timestamp, (str, bytes)):
|
||||
result = OpensshCertificateTimeParameters._time_string_to_datetime(
|
||||
time_string_or_timestamp.strip()
|
||||
to_text(time_string_or_timestamp.strip())
|
||||
)
|
||||
elif isinstance(time_string_or_timestamp, int):
|
||||
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
|
||||
@@ -175,43 +216,53 @@ class OpensshCertificateTimeParameters:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _timestamp_to_datetime(timestamp):
|
||||
def _timestamp_to_datetime(timestamp: int) -> datetime:
|
||||
if timestamp == 0x0:
|
||||
result = _ALWAYS
|
||||
elif timestamp == 0xFFFFFFFFFFFFFFFF:
|
||||
result = _FOREVER
|
||||
else:
|
||||
try:
|
||||
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
|
||||
except OverflowError:
|
||||
raise ValueError
|
||||
return result
|
||||
return _ALWAYS
|
||||
if timestamp == 0xFFFFFFFFFFFFFFFF:
|
||||
return _FOREVER
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
|
||||
except OverflowError:
|
||||
raise ValueError
|
||||
|
||||
@staticmethod
|
||||
def _time_string_to_datetime(time_string):
|
||||
result = None
|
||||
def _time_string_to_datetime(time_string: str) -> datetime:
|
||||
if time_string == "always":
|
||||
result = _ALWAYS
|
||||
elif time_string == "forever":
|
||||
result = _FOREVER
|
||||
elif is_relative_time_string(time_string):
|
||||
return _ALWAYS
|
||||
if time_string == "forever":
|
||||
return _FOREVER
|
||||
if is_relative_time_string(time_string):
|
||||
result = convert_relative_to_datetime(time_string, with_timezone=True)
|
||||
else:
|
||||
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
|
||||
try:
|
||||
result = _add_or_remove_timezone(
|
||||
datetime.strptime(time_string, time_format),
|
||||
with_timezone=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if result is None:
|
||||
raise ValueError
|
||||
return result
|
||||
result = None
|
||||
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
|
||||
try:
|
||||
result = _add_or_remove_timezone(
|
||||
datetime.strptime(time_string, time_format),
|
||||
with_timezone=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if result is None:
|
||||
raise ValueError
|
||||
return result
|
||||
|
||||
|
||||
_OpensshCertificateOption = t.TypeVar(
|
||||
"_OpensshCertificateOption", bound="OpensshCertificateOption"
|
||||
)
|
||||
|
||||
|
||||
class OpensshCertificateOption:
|
||||
def __init__(self, option_type, name, data):
|
||||
def __init__(
|
||||
self,
|
||||
option_type: t.Literal["critical", "extension"],
|
||||
name: str | bytes,
|
||||
data: str | bytes,
|
||||
):
|
||||
if option_type not in ("critical", "extension"):
|
||||
raise ValueError("type must be either 'critical' or 'extension'")
|
||||
|
||||
@@ -225,7 +276,7 @@ class OpensshCertificateOption:
|
||||
self._name = name.lower()
|
||||
self._data = data
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
|
||||
@@ -237,32 +288,34 @@ class OpensshCertificateOption:
|
||||
]
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._option_type, self._name, self._data))
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self._data:
|
||||
return f"{self._name}={self._data}"
|
||||
return self._name
|
||||
return f"{self._name!r}={self._data!r}"
|
||||
return f"{self._name!r}"
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> str | bytes:
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str | bytes:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> t.Literal["critical", "extension"]:
|
||||
return self._option_type
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, option_string):
|
||||
if not isinstance(option_string, (str, bytes)):
|
||||
def from_string(
|
||||
cls: t.Type[_OpensshCertificateOption], option_string: str
|
||||
) -> _OpensshCertificateOption:
|
||||
if not isinstance(option_string, str):
|
||||
raise ValueError(
|
||||
f"option_string must be a string not {type(option_string)}"
|
||||
)
|
||||
@@ -280,7 +333,8 @@ class OpensshCertificateOption:
|
||||
name, data = option_string.strip(), ""
|
||||
|
||||
return cls(
|
||||
option_type=option_type or get_option_type(name.lower()),
|
||||
# We have str, but we're expecting a specific literal:
|
||||
option_type=option_type or get_option_type(name.lower()), # type: ignore
|
||||
name=name,
|
||||
data=data,
|
||||
)
|
||||
@@ -291,21 +345,21 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nonce=None,
|
||||
serial=None,
|
||||
cert_type=None,
|
||||
key_id=None,
|
||||
principals=None,
|
||||
valid_after=None,
|
||||
valid_before=None,
|
||||
critical_options=None,
|
||||
extensions=None,
|
||||
reserved=None,
|
||||
signing_key=None,
|
||||
nonce: bytes | None = None,
|
||||
serial: int | None = None,
|
||||
cert_type: int | None = None,
|
||||
key_id: bytes | None = None,
|
||||
principals: list[bytes] | None = None,
|
||||
valid_after: int | None = None,
|
||||
valid_before: int | None = None,
|
||||
critical_options: list[tuple[bytes, bytes]] | None = None,
|
||||
extensions: list[tuple[bytes, bytes]] | None = None,
|
||||
reserved: bytes | None = None,
|
||||
signing_key: bytes | None = None,
|
||||
):
|
||||
self.nonce = nonce
|
||||
self.serial = serial
|
||||
self._cert_type = cert_type
|
||||
self._cert_type: int | None = cert_type
|
||||
self.key_id = key_id
|
||||
self.principals = principals
|
||||
self.valid_after = valid_after
|
||||
@@ -315,10 +369,10 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
self.reserved = reserved
|
||||
self.signing_key = signing_key
|
||||
|
||||
self.type_string = None
|
||||
self.type_string: bytes | None = None
|
||||
|
||||
@property
|
||||
def cert_type(self):
|
||||
def cert_type(self) -> t.Literal["user", "host", ""]:
|
||||
if self._cert_type == _USER_TYPE:
|
||||
return "user"
|
||||
elif self._cert_type == _HOST_TYPE:
|
||||
@@ -327,7 +381,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
return ""
|
||||
|
||||
@cert_type.setter
|
||||
def cert_type(self, cert_type):
|
||||
def cert_type(self, cert_type: t.Literal["user", "host"] | int) -> None:
|
||||
if cert_type == "user" or cert_type == _USER_TYPE:
|
||||
self._cert_type = _USER_TYPE
|
||||
elif cert_type == "host" or cert_type == _HOST_TYPE:
|
||||
@@ -335,28 +389,30 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
else:
|
||||
raise ValueError(f"{cert_type} is not a valid certificate type")
|
||||
|
||||
def signing_key_fingerprint(self):
|
||||
def signing_key_fingerprint(self) -> bytes:
|
||||
if self.signing_key is None:
|
||||
raise ValueError("signing_key not present")
|
||||
return fingerprint(self.signing_key)
|
||||
|
||||
@abc.abstractmethod
|
||||
def public_key_fingerprint(self):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, e=None, n=None, **kwargs):
|
||||
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
|
||||
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
|
||||
self.e = e
|
||||
self.n = n
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.e is None, self.n is None]):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.e is None or self.n is None:
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
@@ -366,13 +422,20 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.e = parser.mpint()
|
||||
self.n = parser.mpint()
|
||||
|
||||
|
||||
class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
p: int | None = None,
|
||||
q: int | None = None,
|
||||
g: int | None = None,
|
||||
y: int | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
|
||||
self.p = p
|
||||
@@ -381,8 +444,8 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
self.y = y
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.p is None or self.q is None or self.g is None or self.y is None:
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
@@ -394,7 +457,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.p = parser.mpint()
|
||||
self.q = parser.mpint()
|
||||
self.g = parser.mpint()
|
||||
@@ -402,7 +465,9 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
|
||||
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, curve=None, public_key=None, **kwargs):
|
||||
def __init__(
|
||||
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
|
||||
):
|
||||
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
|
||||
self._curve = None
|
||||
if curve is not None:
|
||||
@@ -411,11 +476,11 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
self.public_key = public_key
|
||||
|
||||
@property
|
||||
def curve(self):
|
||||
def curve(self) -> bytes | None:
|
||||
return self._curve
|
||||
|
||||
@curve.setter
|
||||
def curve(self, curve):
|
||||
def curve(self, curve: bytes) -> None:
|
||||
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
|
||||
self._curve = curve
|
||||
self.type_string = (
|
||||
@@ -428,8 +493,8 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
)
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.curve is None, self.public_key is None]):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.curve is None or self.public_key is None:
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
@@ -439,18 +504,18 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.curve = parser.string()
|
||||
self.public_key = parser.string()
|
||||
|
||||
|
||||
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, pk=None, **kwargs):
|
||||
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
|
||||
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
|
||||
self.pk = pk
|
||||
|
||||
def public_key_fingerprint(self):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.pk is None:
|
||||
return b""
|
||||
|
||||
@@ -460,21 +525,26 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.pk = parser.string()
|
||||
|
||||
|
||||
_OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate")
|
||||
|
||||
|
||||
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
|
||||
class OpensshCertificate:
|
||||
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
|
||||
|
||||
def __init__(self, cert_info, signature):
|
||||
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
|
||||
|
||||
self._cert_info = cert_info
|
||||
self.signature = signature
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
def load(
|
||||
cls: t.Type[_OpensshCertificate], path: str | os.PathLike
|
||||
) -> _OpensshCertificate:
|
||||
if not os.path.exists(path):
|
||||
raise ValueError(f"{path} is not a valid path.")
|
||||
|
||||
@@ -492,11 +562,11 @@ class OpensshCertificate:
|
||||
|
||||
for key_type, string in _SSH_TYPE_STRINGS.items():
|
||||
if format_identifier == string + _CERT_SUFFIX_V01:
|
||||
pub_key_type = key_type
|
||||
pub_key_type = t.cast(KeyType, key_type)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid certificate format identifier: {format_identifier}"
|
||||
f"Invalid certificate format identifier: {format_identifier!r}"
|
||||
)
|
||||
|
||||
parser = OpensshParser(cert)
|
||||
@@ -521,75 +591,97 @@ class OpensshCertificate:
|
||||
)
|
||||
|
||||
@property
|
||||
def type_string(self):
|
||||
def type_string(self) -> str:
|
||||
return to_text(self._cert_info.type_string)
|
||||
|
||||
@property
|
||||
def nonce(self):
|
||||
def nonce(self) -> bytes:
|
||||
if self._cert_info.nonce is None:
|
||||
raise ValueError
|
||||
return self._cert_info.nonce
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> str:
|
||||
return to_text(self._cert_info.public_key_fingerprint())
|
||||
|
||||
@property
|
||||
def serial(self):
|
||||
def serial(self) -> int:
|
||||
if self._cert_info.serial is None:
|
||||
raise ValueError
|
||||
return self._cert_info.serial
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self._cert_info.cert_type
|
||||
def type(self) -> t.Literal["user", "host"]:
|
||||
result = self._cert_info.cert_type
|
||||
if result == "":
|
||||
raise ValueError
|
||||
return result
|
||||
|
||||
@property
|
||||
def key_id(self):
|
||||
def key_id(self) -> str:
|
||||
return to_text(self._cert_info.key_id)
|
||||
|
||||
@property
|
||||
def principals(self):
|
||||
def principals(self) -> list[str]:
|
||||
if self._cert_info.principals is None:
|
||||
raise ValueError
|
||||
return [to_text(p) for p in self._cert_info.principals]
|
||||
|
||||
@property
|
||||
def valid_after(self):
|
||||
def valid_after(self) -> int:
|
||||
if self._cert_info.valid_after is None:
|
||||
raise ValueError
|
||||
return self._cert_info.valid_after
|
||||
|
||||
@property
|
||||
def valid_before(self):
|
||||
def valid_before(self) -> int:
|
||||
if self._cert_info.valid_before is None:
|
||||
raise ValueError
|
||||
return self._cert_info.valid_before
|
||||
|
||||
@property
|
||||
def critical_options(self):
|
||||
def critical_options(self) -> list[OpensshCertificateOption]:
|
||||
if self._cert_info.critical_options is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("critical", to_text(n), to_text(d))
|
||||
for n, d in self._cert_info.critical_options
|
||||
]
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
def extensions(self) -> list[OpensshCertificateOption]:
|
||||
if self._cert_info.extensions is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("extension", to_text(n), to_text(d))
|
||||
for n, d in self._cert_info.extensions
|
||||
]
|
||||
|
||||
@property
|
||||
def reserved(self):
|
||||
def reserved(self) -> bytes:
|
||||
if self._cert_info.reserved is None:
|
||||
raise ValueError
|
||||
return self._cert_info.reserved
|
||||
|
||||
@property
|
||||
def signing_key(self):
|
||||
def signing_key(self) -> str:
|
||||
return to_text(self._cert_info.signing_key_fingerprint())
|
||||
|
||||
@property
|
||||
def signature_type(self):
|
||||
def signature_type(self) -> str:
|
||||
signature_data = OpensshParser.signature_data(self.signature)
|
||||
return to_text(signature_data["signature_type"])
|
||||
|
||||
@staticmethod
|
||||
def _parse_cert_info(pub_key_type, parser):
|
||||
def _parse_cert_info(
|
||||
pub_key_type: KeyType, parser: OpensshParser
|
||||
) -> OpensshCertificateInfo:
|
||||
cert_info = get_cert_info_object(pub_key_type)
|
||||
cert_info.nonce = parser.string()
|
||||
cert_info.parse_public_numbers(parser)
|
||||
cert_info.serial = parser.uint64()
|
||||
cert_info.cert_type = parser.uint32()
|
||||
# mypy doesn't understand that the setter accepts other types than the getter:
|
||||
cert_info.cert_type = parser.uint32() # type: ignore
|
||||
cert_info.key_id = parser.string()
|
||||
cert_info.principals = parser.string_list()
|
||||
cert_info.valid_after = parser.uint64()
|
||||
@@ -601,7 +693,7 @@ class OpensshCertificate:
|
||||
|
||||
return cert_info
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, t.Any]:
|
||||
time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.valid_after, valid_to=self.valid_before
|
||||
)
|
||||
@@ -624,7 +716,7 @@ class OpensshCertificate:
|
||||
}
|
||||
|
||||
|
||||
def apply_directives(directives):
|
||||
def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOption]:
|
||||
if any(d not in _DIRECTIVES for d in directives):
|
||||
raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}")
|
||||
|
||||
@@ -650,50 +742,47 @@ def apply_directives(directives):
|
||||
)
|
||||
|
||||
|
||||
def default_options():
|
||||
def default_options() -> list[OpensshCertificateOption]:
|
||||
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
|
||||
|
||||
|
||||
def fingerprint(public_key):
|
||||
def fingerprint(public_key: bytes) -> bytes:
|
||||
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
|
||||
h = sha256()
|
||||
h.update(public_key)
|
||||
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
|
||||
|
||||
|
||||
def get_cert_info_object(key_type):
|
||||
def get_cert_info_object(key_type: KeyType) -> OpensshCertificateInfo:
|
||||
if key_type == "rsa":
|
||||
cert_info = OpensshRSACertificateInfo()
|
||||
elif key_type == "dsa":
|
||||
cert_info = OpensshDSACertificateInfo()
|
||||
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
|
||||
cert_info = OpensshECDSACertificateInfo()
|
||||
elif key_type == "ed25519":
|
||||
cert_info = OpensshED25519CertificateInfo()
|
||||
else:
|
||||
raise ValueError(f"{key_type} is not a valid key type")
|
||||
|
||||
return cert_info
|
||||
return OpensshRSACertificateInfo()
|
||||
if key_type == "dsa":
|
||||
return OpensshDSACertificateInfo()
|
||||
if key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
|
||||
return OpensshECDSACertificateInfo()
|
||||
if key_type == "ed25519":
|
||||
return OpensshED25519CertificateInfo()
|
||||
raise ValueError(f"{key_type} is not a valid key type")
|
||||
|
||||
|
||||
def get_option_type(name):
|
||||
def get_option_type(name: str) -> t.Literal["critical", "extension"]:
|
||||
if name in _CRITICAL_OPTIONS:
|
||||
result = "critical"
|
||||
elif name in _EXTENSIONS:
|
||||
result = "extension"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{name} is not a valid option. "
|
||||
"Custom options must start with 'critical:' or 'extension:' to indicate type"
|
||||
)
|
||||
return result
|
||||
return "critical"
|
||||
if name in _EXTENSIONS:
|
||||
return "extension"
|
||||
raise ValueError(
|
||||
f"{name} is not a valid option. "
|
||||
"Custom options must start with 'critical:' or 'extension:' to indicate type"
|
||||
)
|
||||
|
||||
|
||||
def is_relative_time_string(time_string):
|
||||
def is_relative_time_string(time_string: str) -> bool:
|
||||
return time_string.startswith("+") or time_string.startswith("-")
|
||||
|
||||
|
||||
def parse_option_list(option_list):
|
||||
def parse_option_list(
|
||||
option_list: t.Iterable[str],
|
||||
) -> tuple[list[OpensshCertificateOption], list[OpensshCertificateOption]]:
|
||||
critical_options = []
|
||||
directives = []
|
||||
extensions = []
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
from base64 import b64decode, b64encode
|
||||
from getpass import getuser
|
||||
from socket import gethostname
|
||||
@@ -64,6 +65,27 @@ except ImportError:
|
||||
CRYPTOGRAPHY_VERSION = "0.0"
|
||||
_ALGORITHM_PARAMETERS = {}
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"]
|
||||
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"]
|
||||
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"]
|
||||
|
||||
PrivateKeyTypes = t.Union[
|
||||
rsa.RSAPrivateKey,
|
||||
dsa.DSAPrivateKey,
|
||||
ec.EllipticCurvePrivateKey,
|
||||
Ed25519PrivateKey,
|
||||
]
|
||||
PublicKeyTypes = t.Union[
|
||||
rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey
|
||||
]
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PublicKeyTypes as AllPublicKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
_TEXT_ENCODING = "UTF-8"
|
||||
|
||||
|
||||
@@ -111,11 +133,19 @@ class InvalidSignatureError(OpenSSHError):
|
||||
pass
|
||||
|
||||
|
||||
_AsymmetricKeypair = t.TypeVar("_AsymmetricKeypair", bound="AsymmetricKeypair")
|
||||
|
||||
|
||||
class AsymmetricKeypair:
|
||||
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
|
||||
|
||||
@classmethod
|
||||
def generate(cls, keytype="rsa", size=None, passphrase=None):
|
||||
def generate(
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
) -> _AsymmetricKeypair:
|
||||
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
|
||||
or defaults to an unencrypted RSA-2048 key
|
||||
|
||||
@@ -124,19 +154,21 @@ class AsymmetricKeypair:
|
||||
:passphrase: Secret of type Bytes used to encrypt the private key being generated
|
||||
"""
|
||||
|
||||
if keytype not in _ALGORITHM_PARAMETERS.keys():
|
||||
if keytype not in _ALGORITHM_PARAMETERS:
|
||||
raise InvalidKeyTypeError(
|
||||
f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}"
|
||||
)
|
||||
|
||||
if not size:
|
||||
size = _ALGORITHM_PARAMETERS[keytype]["default_size"]
|
||||
size = _ALGORITHM_PARAMETERS[keytype]["default_size"] # type: ignore
|
||||
else:
|
||||
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]:
|
||||
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: # type: ignore
|
||||
raise InvalidKeySizeError(
|
||||
f"{size} is not a valid key size for {keytype} keys"
|
||||
)
|
||||
size = t.cast(int, size)
|
||||
|
||||
privatekey: PrivateKeyTypes
|
||||
if passphrase:
|
||||
encryption_algorithm = get_encryption_algorithm(passphrase)
|
||||
else:
|
||||
@@ -157,7 +189,7 @@ class AsymmetricKeypair:
|
||||
privatekey = Ed25519PrivateKey.generate()
|
||||
elif keytype == "ecdsa":
|
||||
privatekey = ec.generate_private_key(
|
||||
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size],
|
||||
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], # type: ignore
|
||||
)
|
||||
|
||||
publickey = privatekey.public_key()
|
||||
@@ -172,13 +204,13 @@ class AsymmetricKeypair:
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
path,
|
||||
passphrase=None,
|
||||
private_key_format="PEM",
|
||||
public_key_format="PEM",
|
||||
no_public_key=False,
|
||||
):
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
private_key_format: KeySerializationFormat = "PEM",
|
||||
public_key_format: KeySerializationFormat = "PEM",
|
||||
no_public_key: bool = False,
|
||||
) -> _AsymmetricKeypair:
|
||||
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
|
||||
|
||||
:path: A path to an existing private key to be loaded
|
||||
@@ -197,14 +229,17 @@ class AsymmetricKeypair:
|
||||
if no_public_key:
|
||||
publickey = privatekey.public_key()
|
||||
else:
|
||||
publickey = load_publickey(path + ".pub", public_key_format)
|
||||
# TODO: BUG: load_publickey() can return unsupported key types
|
||||
# (Also we should check whether the public key fits the private key...)
|
||||
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
|
||||
|
||||
# Ed25519 keys are always of size 256 and do not have a key_size attribute
|
||||
if isinstance(privatekey, Ed25519PrivateKey):
|
||||
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"]
|
||||
size: int = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] # type: ignore
|
||||
else:
|
||||
size = privatekey.key_size
|
||||
|
||||
keytype: KeyType
|
||||
if isinstance(privatekey, rsa.RSAPrivateKey):
|
||||
keytype = "rsa"
|
||||
elif isinstance(privatekey, dsa.DSAPrivateKey):
|
||||
@@ -224,7 +259,14 @@ class AsymmetricKeypair:
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
|
||||
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm):
|
||||
def __init__(
|
||||
self,
|
||||
keytype: KeyType,
|
||||
size: int,
|
||||
privatekey: PrivateKeyTypes,
|
||||
publickey: PublicKeyTypes,
|
||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
||||
) -> None:
|
||||
"""
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for the private key of this key pair
|
||||
@@ -246,7 +288,7 @@ class AsymmetricKeypair:
|
||||
"The private key and public key of this keypair do not match"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, AsymmetricKeypair):
|
||||
return NotImplemented
|
||||
|
||||
@@ -256,55 +298,53 @@ class AsymmetricKeypair:
|
||||
self.encryption_algorithm, other.encryption_algorithm
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
@property
|
||||
def private_key(self):
|
||||
def private_key(self) -> PrivateKeyTypes:
|
||||
"""Returns the private key of this key pair"""
|
||||
|
||||
return self.__privatekey
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> PublicKeyTypes:
|
||||
"""Returns the public key of this key pair"""
|
||||
|
||||
return self.__publickey
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
"""Returns the size of the private key of this key pair"""
|
||||
|
||||
return self.__size
|
||||
|
||||
@property
|
||||
def key_type(self):
|
||||
def key_type(self) -> KeyType:
|
||||
"""Returns the key type of this key pair"""
|
||||
|
||||
return self.__keytype
|
||||
|
||||
@property
|
||||
def encryption_algorithm(self):
|
||||
def encryption_algorithm(self) -> serialization.KeySerializationEncryption:
|
||||
"""Returns the key encryption algorithm of this key pair"""
|
||||
|
||||
return self.__encryption_algorithm
|
||||
|
||||
def sign(self, data):
|
||||
def sign(self, data: bytes) -> bytes:
|
||||
"""Returns signature of data signed with the private key of this key pair
|
||||
|
||||
:data: byteslike data to sign
|
||||
"""
|
||||
|
||||
try:
|
||||
signature = self.__privatekey.sign(
|
||||
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
|
||||
return self.__privatekey.sign(
|
||||
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] # type: ignore
|
||||
)
|
||||
except TypeError as e:
|
||||
raise InvalidDataError(e)
|
||||
|
||||
return signature
|
||||
|
||||
def verify(self, signature, data):
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
"""Verifies that the signature associated with the provided data was signed
|
||||
by the private key of this key pair.
|
||||
|
||||
@@ -312,15 +352,15 @@ class AsymmetricKeypair:
|
||||
:data: byteslike data signed by the provided signature
|
||||
"""
|
||||
try:
|
||||
return self.__publickey.verify(
|
||||
self.__publickey.verify(
|
||||
signature,
|
||||
data,
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"],
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], # type: ignore
|
||||
)
|
||||
except InvalidSignature:
|
||||
raise InvalidSignatureError
|
||||
|
||||
def update_passphrase(self, passphrase=None):
|
||||
def update_passphrase(self, passphrase: bytes | None = None) -> None:
|
||||
"""Updates the encryption algorithm of this key pair
|
||||
|
||||
:passphrase: Byte secret used to encrypt this key pair
|
||||
@@ -332,11 +372,20 @@ class AsymmetricKeypair:
|
||||
self.__encryption_algorithm = serialization.NoEncryption()
|
||||
|
||||
|
||||
_OpensshKeypair = t.TypeVar("_OpensshKeypair", bound="OpensshKeypair")
|
||||
|
||||
|
||||
class OpensshKeypair:
|
||||
"""Container for OpenSSH encoded asymmetric key pairs"""
|
||||
|
||||
@classmethod
|
||||
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None):
|
||||
def generate(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
comment: str | None = None,
|
||||
) -> _OpensshKeypair:
|
||||
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
|
||||
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
@@ -362,7 +411,12 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, passphrase=None, no_public_key=False):
|
||||
def load(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
no_public_key: bool = False,
|
||||
) -> _OpensshKeypair:
|
||||
"""Returns an Openssh_Keypair object loaded from the supplied file path
|
||||
|
||||
:path: A path to an existing private key to be loaded
|
||||
@@ -373,7 +427,7 @@ class OpensshKeypair:
|
||||
if no_public_key:
|
||||
comment = ""
|
||||
else:
|
||||
comment = extract_comment(path + ".pub")
|
||||
comment = extract_comment(str(path) + ".pub")
|
||||
|
||||
asym_keypair = AsymmetricKeypair.load(
|
||||
path, passphrase, "SSH", "SSH", no_public_key
|
||||
@@ -391,7 +445,9 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_privatekey(asym_keypair, key_format):
|
||||
def encode_openssh_privatekey(
|
||||
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded private key for a given keypair
|
||||
|
||||
:asym_keypair: Asymmetric_Keypair from the private key is extracted
|
||||
@@ -422,7 +478,9 @@ class OpensshKeypair:
|
||||
return encoded_privatekey
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_publickey(asym_keypair, comment):
|
||||
def encode_openssh_publickey(
|
||||
asym_keypair: AsymmetricKeypair, comment: str
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded public key for a given keypair
|
||||
|
||||
:asym_keypair: Asymmetric_Keypair from the public key is extracted
|
||||
@@ -436,14 +494,19 @@ class OpensshKeypair:
|
||||
validate_comment(comment)
|
||||
|
||||
encoded_publickey += (
|
||||
f" {comment}".encode(encoding=_TEXT_ENCODING) if comment else b""
|
||||
(b" " + comment.encode(encoding=_TEXT_ENCODING)) if comment else b""
|
||||
)
|
||||
|
||||
return encoded_publickey
|
||||
|
||||
def __init__(
|
||||
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment
|
||||
):
|
||||
self,
|
||||
asym_keypair: AsymmetricKeypair,
|
||||
openssh_privatekey: bytes,
|
||||
openssh_publickey: bytes,
|
||||
fingerprint: str,
|
||||
comment: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
|
||||
:openssh_privatekey: An OpenSSH encoded private key
|
||||
@@ -458,7 +521,7 @@ class OpensshKeypair:
|
||||
self.__fingerprint = fingerprint
|
||||
self.__comment = comment
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, OpensshKeypair):
|
||||
return NotImplemented
|
||||
|
||||
@@ -468,49 +531,49 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
@property
|
||||
def asymmetric_keypair(self):
|
||||
def asymmetric_keypair(self) -> AsymmetricKeypair:
|
||||
"""Returns the underlying asymmetric key pair of this OpenSSH encoded key pair"""
|
||||
|
||||
return self.__asym_keypair
|
||||
|
||||
@property
|
||||
def private_key(self):
|
||||
def private_key(self) -> bytes:
|
||||
"""Returns the OpenSSH formatted private key of this key pair"""
|
||||
|
||||
return self.__openssh_privatekey
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> bytes:
|
||||
"""Returns the OpenSSH formatted public key of this key pair"""
|
||||
|
||||
return self.__openssh_publickey
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
"""Returns the size of the private key of this key pair"""
|
||||
|
||||
return self.__asym_keypair.size
|
||||
|
||||
@property
|
||||
def key_type(self):
|
||||
def key_type(self) -> KeyType:
|
||||
"""Returns the key type of this key pair"""
|
||||
|
||||
return self.__asym_keypair.key_type
|
||||
|
||||
@property
|
||||
def fingerprint(self):
|
||||
def fingerprint(self) -> str:
|
||||
"""Returns the fingerprint (SHA256 Hash) of the public key of this key pair"""
|
||||
|
||||
return self.__fingerprint
|
||||
|
||||
@property
|
||||
def comment(self):
|
||||
def comment(self) -> str | None:
|
||||
"""Returns the comment applied to the OpenSSH formatted public key of this key pair"""
|
||||
|
||||
return self.__comment
|
||||
|
||||
@comment.setter
|
||||
def comment(self, comment):
|
||||
def comment(self, comment: str) -> bytes:
|
||||
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
|
||||
|
||||
:comment: Text to update the OpenSSH public key comment
|
||||
@@ -529,7 +592,7 @@ class OpensshKeypair:
|
||||
)
|
||||
return self.__openssh_publickey
|
||||
|
||||
def update_passphrase(self, passphrase):
|
||||
def update_passphrase(self, passphrase: bytes | None) -> None:
|
||||
"""Updates the passphrase used to encrypt the private key of this keypair
|
||||
|
||||
:passphrase: Text secret used for encryption
|
||||
@@ -541,18 +604,17 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
|
||||
def load_privatekey(path, passphrase, key_format):
|
||||
def load_privatekey(
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None,
|
||||
key_format: KeySerializationFormat,
|
||||
) -> PrivateKeyTypes:
|
||||
privatekey_loaders = {
|
||||
"PEM": serialization.load_pem_private_key,
|
||||
"DER": serialization.load_der_private_key,
|
||||
"SSH": serialization.load_ssh_private_key,
|
||||
}
|
||||
|
||||
# OpenSSH formatted private keys are not available in Cryptography <3.0
|
||||
if hasattr(serialization, "load_ssh_private_key"):
|
||||
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
|
||||
else:
|
||||
privatekey_loaders["SSH"] = serialization.load_pem_private_key
|
||||
|
||||
try:
|
||||
privatekey_loader = privatekey_loaders[key_format]
|
||||
except KeyError:
|
||||
@@ -567,16 +629,16 @@ def load_privatekey(path, passphrase, key_format):
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
privatekey = privatekey_loader(
|
||||
privatekey = privatekey_loader( # type: ignore
|
||||
data=content,
|
||||
password=passphrase,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
except ValueError as exc:
|
||||
# Revert to PEM if key could not be loaded in SSH format
|
||||
if key_format == "SSH":
|
||||
try:
|
||||
privatekey = privatekey_loaders["PEM"](
|
||||
privatekey = privatekey_loaders["PEM"]( # type: ignore
|
||||
data=content,
|
||||
password=passphrase,
|
||||
)
|
||||
@@ -587,7 +649,7 @@ def load_privatekey(path, passphrase, key_format):
|
||||
except UnsupportedAlgorithm as e:
|
||||
raise InvalidAlgorithmError(e)
|
||||
else:
|
||||
raise InvalidPrivateKeyFileError(e)
|
||||
raise InvalidPrivateKeyFileError(exc)
|
||||
except TypeError as e:
|
||||
raise InvalidPassphraseError(e)
|
||||
except UnsupportedAlgorithm as e:
|
||||
@@ -596,7 +658,9 @@ def load_privatekey(path, passphrase, key_format):
|
||||
return privatekey
|
||||
|
||||
|
||||
def load_publickey(path, key_format):
|
||||
def load_publickey(
|
||||
path: str | os.PathLike, key_format: KeySerializationFormat
|
||||
) -> AllPublicKeyTypes:
|
||||
publickey_loaders = {
|
||||
"PEM": serialization.load_pem_public_key,
|
||||
"DER": serialization.load_der_public_key,
|
||||
@@ -628,20 +692,27 @@ def load_publickey(path, key_format):
|
||||
return publickey
|
||||
|
||||
|
||||
def compare_publickeys(pk1, pk2):
|
||||
def compare_publickeys(pk1: PublicKeyTypes, pk2: PublicKeyTypes) -> bool:
|
||||
a = isinstance(pk1, Ed25519PublicKey)
|
||||
b = isinstance(pk2, Ed25519PublicKey)
|
||||
if a or b:
|
||||
if not a or not b:
|
||||
return False
|
||||
a = pk1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
b = pk2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
return a == b
|
||||
a_bytes = pk1.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
b_bytes = pk2.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
return a_bytes == b_bytes
|
||||
else:
|
||||
return pk1.public_numbers() == pk2.public_numbers()
|
||||
return pk1.public_numbers() == pk2.public_numbers() # type: ignore
|
||||
|
||||
|
||||
def compare_encryption_algorithms(ea1, ea2):
|
||||
def compare_encryption_algorithms(
|
||||
ea1: serialization.KeySerializationEncryption,
|
||||
ea2: serialization.KeySerializationEncryption,
|
||||
) -> bool:
|
||||
if isinstance(ea1, serialization.NoEncryption) and isinstance(
|
||||
ea2, serialization.NoEncryption
|
||||
):
|
||||
@@ -654,19 +725,21 @@ def compare_encryption_algorithms(ea1, ea2):
|
||||
return False
|
||||
|
||||
|
||||
def get_encryption_algorithm(passphrase):
|
||||
def get_encryption_algorithm(
|
||||
passphrase: bytes,
|
||||
) -> serialization.KeySerializationEncryption:
|
||||
try:
|
||||
return serialization.BestAvailableEncryption(passphrase)
|
||||
except ValueError as e:
|
||||
raise InvalidPassphraseError(e)
|
||||
|
||||
|
||||
def validate_comment(comment):
|
||||
def validate_comment(comment: str) -> None:
|
||||
if not hasattr(comment, "encode"):
|
||||
raise InvalidCommentError(f"{comment} cannot be encoded to text")
|
||||
|
||||
|
||||
def extract_comment(path):
|
||||
def extract_comment(path: str | os.PathLike) -> str:
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise InvalidPublicKeyFileError(f"No file was found at {path}")
|
||||
@@ -684,7 +757,7 @@ def extract_comment(path):
|
||||
return comment
|
||||
|
||||
|
||||
def calculate_fingerprint(openssh_publickey):
|
||||
def calculate_fingerprint(openssh_publickey: bytes) -> str:
|
||||
digest = hashes.Hash(hashes.SHA256())
|
||||
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
|
||||
digest.update(decoded_pubkey)
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from struct import Struct
|
||||
|
||||
@@ -38,17 +39,20 @@ _UINT64 = Struct(b"!Q")
|
||||
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
|
||||
def any_in(sequence, *elements):
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
def any_in(sequence: t.Iterable[_T], *elements: _T) -> bool:
|
||||
return any(e in sequence for e in elements)
|
||||
|
||||
|
||||
def file_mode(path):
|
||||
def file_mode(path: str | os.PathLike) -> int:
|
||||
if not os.path.exists(path):
|
||||
return 0o000
|
||||
return os.stat(path).st_mode & 0o777
|
||||
|
||||
|
||||
def parse_openssh_version(version_string):
|
||||
def parse_openssh_version(version_string: str) -> str | None:
|
||||
"""Parse the version output of ssh -V and return version numbers that can be compared"""
|
||||
|
||||
parsed_result = re.match(
|
||||
@@ -63,7 +67,7 @@ def parse_openssh_version(version_string):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def secure_open(path, mode):
|
||||
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
|
||||
try:
|
||||
yield fd
|
||||
@@ -71,7 +75,7 @@ def secure_open(path, mode):
|
||||
os.close(fd)
|
||||
|
||||
|
||||
def secure_write(path, mode, content):
|
||||
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
|
||||
with secure_open(path, mode) as fd:
|
||||
os.write(fd, content)
|
||||
|
||||
@@ -84,35 +88,35 @@ class OpensshParser:
|
||||
UINT32_OFFSET = 4
|
||||
UINT64_OFFSET = 8
|
||||
|
||||
def __init__(self, data):
|
||||
def __init__(self, data: bytes | bytearray) -> None:
|
||||
if not isinstance(data, (bytes, bytearray)):
|
||||
raise TypeError(f"Data must be bytes-like not {type(data)}")
|
||||
|
||||
self._data = memoryview(data)
|
||||
self._pos = 0
|
||||
|
||||
def boolean(self):
|
||||
def boolean(self) -> bool:
|
||||
next_pos = self._check_position(self.BOOLEAN_OFFSET)
|
||||
|
||||
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def uint32(self):
|
||||
def uint32(self) -> int:
|
||||
next_pos = self._check_position(self.UINT32_OFFSET)
|
||||
|
||||
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def uint64(self):
|
||||
def uint64(self) -> int:
|
||||
next_pos = self._check_position(self.UINT64_OFFSET)
|
||||
|
||||
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def string(self):
|
||||
def string(self) -> bytes:
|
||||
length = self.uint32()
|
||||
|
||||
next_pos = self._check_position(length)
|
||||
@@ -122,15 +126,15 @@ class OpensshParser:
|
||||
# Cast to bytes is required as a memoryview slice is itself a memoryview
|
||||
return bytes(value)
|
||||
|
||||
def mpint(self):
|
||||
def mpint(self) -> int:
|
||||
return self._big_int(self.string(), "big", signed=True)
|
||||
|
||||
def name_list(self):
|
||||
def name_list(self) -> list[str]:
|
||||
raw_string = self.string()
|
||||
return raw_string.decode("ASCII").split(",")
|
||||
|
||||
# Convenience function, but not an official data type from SSH
|
||||
def string_list(self):
|
||||
def string_list(self) -> list[bytes]:
|
||||
result = []
|
||||
raw_string = self.string()
|
||||
|
||||
@@ -142,7 +146,7 @@ class OpensshParser:
|
||||
return result
|
||||
|
||||
# Convenience function, but not an official data type from SSH
|
||||
def option_list(self):
|
||||
def option_list(self) -> list[tuple[bytes, bytes]]:
|
||||
result = []
|
||||
raw_string = self.string()
|
||||
|
||||
@@ -159,15 +163,15 @@ class OpensshParser:
|
||||
|
||||
return result
|
||||
|
||||
def seek(self, offset):
|
||||
def seek(self, offset: int) -> int:
|
||||
self._pos = self._check_position(offset)
|
||||
|
||||
return self._pos
|
||||
|
||||
def remaining_bytes(self):
|
||||
def remaining_bytes(self) -> int:
|
||||
return len(self._data) - self._pos
|
||||
|
||||
def _check_position(self, offset):
|
||||
def _check_position(self, offset: int) -> int:
|
||||
if self._pos + offset > len(self._data):
|
||||
raise ValueError(f"Insufficient data remaining at position: {self._pos}")
|
||||
elif self._pos + offset < 0:
|
||||
@@ -176,8 +180,8 @@ class OpensshParser:
|
||||
return self._pos + offset
|
||||
|
||||
@classmethod
|
||||
def signature_data(cls, signature_string):
|
||||
signature_data = {}
|
||||
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
|
||||
signature_data: dict[str, bytes | int] = {}
|
||||
|
||||
parser = cls(signature_string)
|
||||
signature_type = parser.string()
|
||||
@@ -205,14 +209,19 @@ class OpensshParser:
|
||||
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
|
||||
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
|
||||
else:
|
||||
raise ValueError(f"{signature_type} is not a valid signature type")
|
||||
raise ValueError(f"{signature_type!r} is not a valid signature type")
|
||||
|
||||
signature_data["signature_type"] = signature_type
|
||||
|
||||
return signature_data
|
||||
|
||||
@classmethod
|
||||
def _big_int(cls, raw_string, byte_order, signed=False):
|
||||
def _big_int(
|
||||
cls,
|
||||
raw_string: bytes,
|
||||
byte_order: t.Literal["big", "little"],
|
||||
signed: bool = False,
|
||||
) -> int:
|
||||
if byte_order not in ("big", "little"):
|
||||
raise ValueError(
|
||||
f"Byte_order must be one of (big, little) not {byte_order}"
|
||||
@@ -230,18 +239,16 @@ class _OpensshWriter:
|
||||
in validating parsed material.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer=None):
|
||||
def __init__(self, buffer: bytearray | None = None):
|
||||
if buffer is not None:
|
||||
if not isinstance(buffer, (bytes, bytearray)):
|
||||
raise TypeError(
|
||||
f"Buffer must be a bytes-like object not {type(buffer)}"
|
||||
)
|
||||
if not isinstance(buffer, bytearray):
|
||||
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
|
||||
else:
|
||||
buffer = bytearray()
|
||||
|
||||
self._buff = buffer
|
||||
self._buff: bytearray = buffer
|
||||
|
||||
def boolean(self, value):
|
||||
def boolean(self, value: bool) -> t.Self:
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError(f"Value must be of type bool not {type(value)}")
|
||||
|
||||
@@ -249,7 +256,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def uint32(self, value):
|
||||
def uint32(self, value: int) -> t.Self:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"Value must be of type int not {type(value)}")
|
||||
if value < 0 or value > _UINT32_MAX:
|
||||
@@ -261,7 +268,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def uint64(self, value):
|
||||
def uint64(self, value: int) -> t.Self:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"Value must be of type int not {type(value)}")
|
||||
if value < 0 or value > _UINT64_MAX:
|
||||
@@ -273,7 +280,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def string(self, value):
|
||||
def string(self, value: bytes | bytearray) -> t.Self:
|
||||
if not isinstance(value, (bytes, bytearray)):
|
||||
raise TypeError(f"Value must be bytes-like not {type(value)}")
|
||||
self.uint32(len(value))
|
||||
@@ -281,7 +288,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def mpint(self, value):
|
||||
def mpint(self, value: int) -> t.Self:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"Value must be of type int not {type(value)}")
|
||||
|
||||
@@ -289,7 +296,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def name_list(self, value):
|
||||
def name_list(self, value: list[str]) -> t.Self:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Value must be a list of byte strings not {type(value)}")
|
||||
|
||||
@@ -300,7 +307,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def string_list(self, value):
|
||||
def string_list(self, value: list[bytes]) -> t.Self:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Value must be a list of byte string not {type(value)}")
|
||||
|
||||
@@ -312,7 +319,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def option_list(self, value):
|
||||
def option_list(self, value: list[tuple[bytes, bytes]]) -> t.Self:
|
||||
if not isinstance(value, list) or (value and not isinstance(value[0], tuple)):
|
||||
raise TypeError("Value must be a list of tuples")
|
||||
|
||||
@@ -327,7 +334,7 @@ class _OpensshWriter:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _int_to_mpint(num):
|
||||
def _int_to_mpint(num: int) -> bytes:
|
||||
byte_length = (num.bit_length() + 7) // 8
|
||||
try:
|
||||
return num.to_bytes(byte_length, "big", signed=True)
|
||||
@@ -335,5 +342,5 @@ class _OpensshWriter:
|
||||
except OverflowError:
|
||||
return num.to_bytes(byte_length + 1, "big", signed=True)
|
||||
|
||||
def bytes(self):
|
||||
def bytes(self) -> bytes:
|
||||
return bytes(self._buff)
|
||||
|
||||
@@ -10,7 +10,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor
|
||||
)
|
||||
|
||||
|
||||
def th(number):
|
||||
def th(number: int) -> str:
|
||||
abs_number = abs(number)
|
||||
mod_10 = abs_number % 10
|
||||
mod_100 = abs_number % 100
|
||||
@@ -24,13 +24,13 @@ def th(number):
|
||||
return "th"
|
||||
|
||||
|
||||
def parse_serial(value):
|
||||
def parse_serial(value: str | bytes) -> int:
|
||||
"""
|
||||
Given a colon-separated string of hexadecimal byte values, converts it to an integer.
|
||||
"""
|
||||
value = to_native(value)
|
||||
value_str = to_native(value)
|
||||
result = 0
|
||||
for i, part in enumerate(value.split(":")):
|
||||
for i, part in enumerate(value_str.split(":")):
|
||||
try:
|
||||
part_value = int(part, 16)
|
||||
if part_value < 0 or part_value > 255:
|
||||
@@ -43,11 +43,11 @@ def parse_serial(value):
|
||||
return result
|
||||
|
||||
|
||||
def to_serial(value):
|
||||
def to_serial(value: int) -> str:
|
||||
"""
|
||||
Given an integer, converts its absolute value to a colon-separated string of hexadecimal byte values.
|
||||
"""
|
||||
value = convert_int_to_hex(value).upper()
|
||||
if len(value) % 2 != 0:
|
||||
value = "0" + value
|
||||
return ":".join(value[i : i + 2] for i in range(0, len(value), 2))
|
||||
value_str = convert_int_to_hex(value).upper()
|
||||
if len(value_str) % 2 != 0:
|
||||
value_str = f"0{value_str}"
|
||||
return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2))
|
||||
|
||||
@@ -13,37 +13,16 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.basic impo
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
UTC = datetime.timezone.utc
|
||||
except AttributeError:
|
||||
_DURATION_ZERO = datetime.timedelta(0)
|
||||
|
||||
class _UTCClass(datetime.tzinfo):
|
||||
def utcoffset(self, dt):
|
||||
return _DURATION_ZERO
|
||||
|
||||
def dst(self, dt):
|
||||
return _DURATION_ZERO
|
||||
|
||||
def tzname(self, dt):
|
||||
return "UTC"
|
||||
|
||||
def fromutc(self, dt):
|
||||
return dt
|
||||
|
||||
def __repr__(self):
|
||||
return "UTC"
|
||||
|
||||
UTC = _UTCClass()
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
def get_now_datetime(with_timezone):
|
||||
def get_now_datetime(with_timezone: bool) -> datetime.datetime:
|
||||
if with_timezone:
|
||||
return datetime.datetime.now(tz=UTC)
|
||||
return datetime.datetime.utcnow()
|
||||
|
||||
|
||||
def ensure_utc_timezone(timestamp):
|
||||
def ensure_utc_timezone(timestamp: datetime.datetime) -> datetime.datetime:
|
||||
if timestamp.tzinfo is UTC:
|
||||
return timestamp
|
||||
if timestamp.tzinfo is None:
|
||||
@@ -52,7 +31,7 @@ def ensure_utc_timezone(timestamp):
|
||||
return timestamp.astimezone(UTC)
|
||||
|
||||
|
||||
def remove_timezone(timestamp):
|
||||
def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime:
|
||||
# Convert to native datetime object
|
||||
if timestamp.tzinfo is None:
|
||||
return timestamp
|
||||
@@ -61,26 +40,34 @@ def remove_timezone(timestamp):
|
||||
return timestamp.replace(tzinfo=None)
|
||||
|
||||
|
||||
def add_or_remove_timezone(timestamp, with_timezone):
|
||||
def add_or_remove_timezone(
|
||||
timestamp: datetime.datetime, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
return (
|
||||
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
|
||||
)
|
||||
|
||||
|
||||
def get_epoch_seconds(timestamp):
|
||||
def get_epoch_seconds(timestamp: datetime.datetime) -> float:
|
||||
if timestamp.tzinfo is None:
|
||||
# timestamp.timestamp() is offset by the local timezone if timestamp has no timezone
|
||||
timestamp = ensure_utc_timezone(timestamp)
|
||||
return timestamp.timestamp()
|
||||
|
||||
|
||||
def from_epoch_seconds(timestamp, with_timezone):
|
||||
def from_epoch_seconds(
|
||||
timestamp: int | float, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
if with_timezone:
|
||||
return datetime.datetime.fromtimestamp(timestamp, UTC)
|
||||
return datetime.datetime.utcfromtimestamp(timestamp)
|
||||
|
||||
|
||||
def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=None):
|
||||
def convert_relative_to_datetime(
|
||||
relative_time_string: str,
|
||||
with_timezone: bool = False,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime | None:
|
||||
"""Get a datetime.datetime or None from a string in the time format described in sshd_config(5)"""
|
||||
|
||||
parsed_result = re.match(
|
||||
@@ -115,7 +102,12 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
|
||||
return now - offset
|
||||
|
||||
|
||||
def get_relative_time_option(input_string, input_name, with_timezone=False, now=None):
|
||||
def get_relative_time_option(
|
||||
input_string: str,
|
||||
input_name: str,
|
||||
with_timezone: bool = False,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime:
|
||||
"""
|
||||
Return an absolute timespec if a relative timespec or an ASN1 formatted
|
||||
string is provided.
|
||||
@@ -129,9 +121,12 @@ def get_relative_time_option(input_string, input_name, with_timezone=False, now=
|
||||
)
|
||||
# Relative time
|
||||
if result.startswith("+") or result.startswith("-"):
|
||||
return convert_relative_to_datetime(
|
||||
result, with_timezone=with_timezone, now=now
|
||||
)
|
||||
res = convert_relative_to_datetime(result, with_timezone=with_timezone, now=now)
|
||||
if res is None:
|
||||
raise OpenSSLObjectError(
|
||||
f'The timespec "{input_string}" for {input_name} is invalid'
|
||||
)
|
||||
return res
|
||||
# Absolute time
|
||||
for date_fmt, length in [
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user