Add type hints and type checking (#885)

* Enable basic type checking.

* Fix first errors.

* Add changelog fragment.

* Add types to module_utils and plugin_utils (without module backends).

* Add typing hints for acme_* modules.

* Add typing to X.509 certificate modules, and add more helpers.

* Add typing to remaining module backends.

* Add typing for action, filter, and lookup plugins.

* Bump ansible-core 2.19 beta requirement for typing.

* Add more typing definitions.

* Add typing to some unit tests.
This commit is contained in:
Felix Fontein
2025-05-11 18:00:11 +02:00
committed by GitHub
parent 82f0176773
commit f758d94fba
124 changed files with 4986 additions and 2662 deletions

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import base64
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -19,25 +20,41 @@ from ansible_collections.community.crypto.plugins.plugin_utils.action_module imp
)
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
PrivateKeyBackend,
)
from ansible_collections.community.crypto.plugins.plugin_utils.action_module import (
AnsibleActionModule,
)
class PrivateKeyModule:
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleActionModule, module_backend: PrivateKeyBackend
) -> None:
self.module = module
self.module_backend = module_backend
self.check_mode = module.check_mode
self.changed = False
self.return_current_key = module.params["return_current_key"]
self.return_current_key: bool = module.params["return_current_key"]
if module.params["content"] is not None:
if module.params["content_base64"]:
content: str | None = module.params["content"]
content_base64: bool = module.params["content_base64"]
if content is not None:
if content_base64:
try:
data = base64.b64decode(module.params["content"])
data = base64.b64decode(content)
except Exception as e:
module.fail_json(msg=f"Cannot decode Base64 encoded data: {e}")
else:
data = to_bytes(module.params["content"])
data = to_bytes(content)
module_backend.set_existing(data)
def generate(self, module):
def generate(self, module: AnsibleActionModule) -> None:
"""Generate a keypair."""
if self.module_backend.needs_regeneration():
@@ -53,7 +70,7 @@ class PrivateKeyModule:
self.privatekey_bytes = privatekey_data
self.changed = True
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(
include_key=self.changed or self.return_current_key
@@ -64,7 +81,7 @@ class PrivateKeyModule:
class ActionModule(ActionModuleBase):
@staticmethod
def setup_module():
def setup_module() -> tuple[ArgumentSpec, dict[str, t.Any]]:
argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update(
dict(
@@ -78,7 +95,7 @@ class ActionModule(ActionModuleBase):
)
@staticmethod
def run_module(module):
def run_module(module: AnsibleActionModule) -> None:
module_backend = select_backend(module=module)
try:

View File

@@ -39,6 +39,8 @@ _value:
type: string
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import (
@@ -50,7 +52,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import (
)
def gpg_fingerprint(input):
def gpg_fingerprint(input: str | bytes) -> str:
if not isinstance(input, (str, bytes)):
raise AnsibleFilterError(
f"The input for the community.crypto.gpg_fingerprint filter must be a string; got {type(input)} instead"
@@ -65,7 +67,7 @@ def gpg_fingerprint(input):
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"gpg_fingerprint": gpg_fingerprint,
}

View File

@@ -274,6 +274,8 @@ _value:
sample: 12345
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -287,7 +289,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
)
def openssl_csr_info_filter(data, name_encoding="ignore"):
def openssl_csr_info_filter(
data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore"
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)):
raise AnsibleFilterError(
@@ -313,7 +317,7 @@ def openssl_csr_info_filter(data, name_encoding="ignore"):
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"openssl_csr_info": openssl_csr_info_filter,
}

View File

@@ -146,8 +146,10 @@ _value:
type: dict
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -161,8 +163,10 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
def openssl_privatekey_info_filter(
data, passphrase=None, return_private_key_data=False
):
data: str | bytes,
passphrase: str | bytes | None = None,
return_private_key_data: bool = False,
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)):
raise AnsibleFilterError(
@@ -182,7 +186,7 @@ def openssl_privatekey_info_filter(
result = get_privatekey_info(
module,
content=to_bytes(data),
passphrase=passphrase,
passphrase=to_text(passphrase) if passphrase is not None else None,
return_private_key_data=return_private_key_data,
)
result.pop("can_parse_key", None)
@@ -197,7 +201,7 @@ def openssl_privatekey_info_filter(
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"openssl_privatekey_info": openssl_privatekey_info_filter,
}

View File

@@ -123,6 +123,8 @@ _value:
returned: When RV(_value.type=DSA) or RV(_value.type=ECC)
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -137,7 +139,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
)
def openssl_publickey_info_filter(data):
def openssl_publickey_info_filter(data: str | bytes) -> dict[str, t.Any]:
"""Extract information from OpenSSL PEM public key."""
if not isinstance(data, (str, bytes)):
raise AnsibleFilterError(
@@ -156,7 +158,7 @@ def openssl_publickey_info_filter(data):
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"openssl_publickey_info": openssl_publickey_info_filter,
}

View File

@@ -39,6 +39,8 @@ _value:
type: int
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils.serial import (
@@ -46,7 +48,7 @@ from ansible_collections.community.crypto.plugins.module_utils.serial import (
)
def parse_serial_filter(input):
def parse_serial_filter(input: str | bytes) -> int:
if not isinstance(input, (str, bytes)):
raise AnsibleFilterError(
f"The input for the community.crypto.parse_serial filter must be a string; got {type(input)} instead"
@@ -60,7 +62,7 @@ def parse_serial_filter(input):
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"parse_serial": parse_serial_filter,
}

View File

@@ -38,6 +38,8 @@ _value:
elements: string
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
@@ -45,21 +47,20 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
)
def split_pem_filter(data):
def split_pem_filter(data: str | bytes) -> list[str]:
"""Split PEM file."""
if not isinstance(data, (str, bytes)):
raise AnsibleFilterError(
f"The community.crypto.split_pem input must be a text type, not {type(data)}"
)
data = to_text(data)
return split_pem_list(data)
return split_pem_list(to_text(data))
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"split_pem": split_pem_filter,
}

View File

@@ -39,11 +39,13 @@ _value:
type: string
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible_collections.community.crypto.plugins.module_utils.serial import to_serial
def to_serial_filter(input):
def to_serial_filter(input: int) -> str:
if not isinstance(input, int):
raise AnsibleFilterError(
f"The input for the community.crypto.to_serial filter must be an integer; got {type(input)} instead"
@@ -61,7 +63,7 @@ def to_serial_filter(input):
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"to_serial": to_serial_filter,
}

View File

@@ -308,6 +308,8 @@ _value:
type: str
"""
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -321,7 +323,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
)
def x509_certificate_info_filter(data, name_encoding="ignore"):
def x509_certificate_info_filter(
data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore"
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)):
raise AnsibleFilterError(
@@ -347,7 +351,7 @@ def x509_certificate_info_filter(data, name_encoding="ignore"):
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"x509_certificate_info": x509_certificate_info_filter,
}

View File

@@ -155,6 +155,7 @@ _value:
import base64
import binascii
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes, to_native
@@ -172,7 +173,11 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
)
def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates=True):
def x509_crl_info_filter(
data: str | bytes,
name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore",
list_revoked_certificates: bool = True,
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)):
raise AnsibleFilterError(
@@ -192,17 +197,19 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates
f'The name_encoding option must be one of the values "ignore", "idna", or "unicode", not "{name_encoding}"'
)
data = to_bytes(data)
if not identify_pem_format(data):
data_bytes = to_bytes(data)
if not identify_pem_format(data_bytes):
try:
data = base64.b64decode(to_native(data))
data_bytes = base64.b64decode(to_native(data_bytes))
except (binascii.Error, TypeError, ValueError, UnicodeEncodeError):
pass
module = FilterModuleMock({"name_encoding": name_encoding})
try:
return get_crl_info(
module, content=data, list_revoked_certificates=list_revoked_certificates
module,
content=data_bytes,
list_revoked_certificates=list_revoked_certificates,
)
except OpenSSLObjectError as exc:
raise AnsibleFilterError(str(exc))
@@ -211,7 +218,7 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates
class FilterModule:
"""Ansible jinja2 filters"""
def filters(self):
def filters(self) -> dict[str, t.Callable]:
return {
"x509_crl_info": x509_crl_info_filter,
}

View File

@@ -42,7 +42,11 @@ _value:
elements: string
"""
import os
import typing as t
from ansible.errors import AnsibleLookupError
from ansible.module_utils.common.text.converters import to_native
from ansible.plugins.lookup import LookupBase
from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import (
GPGError,
@@ -54,14 +58,20 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import (
class LookupModule(LookupBase):
def run(self, terms, variables=None, **kwargs):
def run(self, terms: list[t.Any], variables=None, **kwargs) -> list[str]:
self.set_options(direct=kwargs)
if self._loader is None:
raise AssertionError("Contract violation: self._loader is None")
try:
gpg = PluginGPGRunner(cwd=self._loader.get_basedir())
result = []
for path in terms:
result.append(get_fingerprint_from_file(gpg, path))
for i, path in enumerate(terms):
if not isinstance(path, (str, bytes, os.PathLike)):
raise AnsibleLookupError(
f"Lookup parameter #{i} should be string or a path object, but got {type(path)}"
)
result.append(get_fingerprint_from_file(gpg, to_native(path)))
return result
except GPGError as exc:
raise AnsibleLookupError(str(exc))

View File

@@ -5,6 +5,8 @@
from __future__ import annotations
import typing as t
from ansible.module_utils.common._collections_compat import Mapping
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ACMEProtocolException,
@@ -12,26 +14,29 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
if t.TYPE_CHECKING:
from .acme import ACMEClient
class ACMEAccount:
"""
ACME account object. Allows to create new accounts, check for existence of accounts,
retrieve account data.
"""
def __init__(self, client):
def __init__(self, client: ACMEClient) -> None:
# Set to true to enable logging of all signed requests
self._debug = False
self._debug: bool = False
self.client = client
def _new_reg(
self,
contact=None,
agreement=None,
terms_agreed=False,
allow_creation=True,
external_account_binding=None,
):
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
external_account_binding: dict[str, t.Any] | None = None,
) -> tuple[bool, dict[str, t.Any] | None]:
"""
Registers a new ACME account. Returns a pair ``(created, data)``.
Here, ``created`` is ``True`` if the account was created and
@@ -63,7 +68,7 @@ class ACMEAccount:
return created, data
# An account does not yet exist. Try to create one next.
new_reg = {"contact": contact}
new_reg: dict[str, t.Any] = {"contact": contact}
if not allow_creation:
# https://tools.ietf.org/html/rfc8555#section-7.3.1
new_reg["onlyReturnExisting"] = True
@@ -99,7 +104,7 @@ class ACMEAccount:
self.client.module,
msg="Invalid account creation reply from ACME server",
info=info,
content=result,
content_json=result,
)
if info["status"] == 201:
@@ -152,7 +157,7 @@ class ACMEAccount:
content_json=result,
)
def get_account_data(self):
def get_account_data(self) -> dict[str, t.Any] | None:
"""
Retrieve account information. Can only be called when the account
URI is already known (such as after calling setup_account).
@@ -161,7 +166,7 @@ class ACMEAccount:
if self.client.account_uri is None:
raise ModuleFailException("Account URI unknown")
# try POST-as-GET first (draft-15 or newer)
data = None
data: dict[str, t.Any] | None = None
result, info = self.client.send_signed_request(
self.client.account_uri, data, fail_on_error=False
)
@@ -180,7 +185,7 @@ class ACMEAccount:
self.client.module,
msg="Invalid account data retrieved from ACME server",
info=info,
content=result,
content_json=result,
)
if (
info["status"] in (400, 403)
@@ -203,15 +208,34 @@ class ACMEAccount:
)
return result
@t.overload
def setup_account(
self,
contact=None,
agreement=None,
terms_agreed=False,
allow_creation=True,
remove_account_uri_if_not_exists=False,
external_account_binding=None,
):
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: t.Literal[True] = True,
remove_account_uri_if_not_exists: bool = False,
external_account_binding: dict[str, t.Any] | None = None,
) -> tuple[bool, dict[str, t.Any]]: ...
@t.overload
def setup_account(
self,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
remove_account_uri_if_not_exists: bool = False,
external_account_binding: dict[str, t.Any] | None = None,
) -> tuple[bool, dict[str, t.Any] | None]: ...
def setup_account(
self,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
remove_account_uri_if_not_exists: bool = False,
external_account_binding: dict[str, t.Any] | None = None,
) -> tuple[bool, dict[str, t.Any] | None]:
"""
Detect or create an account on the ACME server. For ACME v1,
as the only way (without knowing an account URI) to test if an
@@ -253,7 +277,6 @@ class ACMEAccount:
else:
created, account_data = self._new_reg(
contact,
agreement=agreement,
terms_agreed=terms_agreed,
allow_creation=allow_creation and not self.client.module.check_mode,
external_account_binding=external_account_binding,
@@ -267,7 +290,9 @@ class ACMEAccount:
account_data = {"contact": contact or []}
return created, account_data
def update_account(self, account_data, contact=None):
def update_account(
self, account_data: dict[str, t.Any], contact: list[str] | None = None
) -> tuple[bool, dict[str, t.Any]]:
"""
Update an account on the ACME server. Check mode is fully respected.
@@ -280,8 +305,11 @@ class ACMEAccount:
https://tools.ietf.org/html/rfc8555#section-7.3.2
"""
if self.client.account_uri is None:
raise ModuleFailException("Cannot update account without account URI")
# Create request
update_request = {}
update_request: dict[str, t.Any] = {}
if contact is not None and account_data.get("contact", []) != contact:
update_request["contact"] = list(contact)
@@ -302,7 +330,7 @@ class ACMEAccount:
self.client.module,
msg="Invalid account updating reply from ACME server",
info=info,
content=account_data,
content_json=account_data,
)
return True, account_data

View File

@@ -10,6 +10,7 @@ import datetime
import json
import locale
import time
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.text.converters import to_bytes
@@ -41,13 +42,24 @@ from ansible_collections.community.crypto.plugins.module_utils.argspec import (
)
if t.TYPE_CHECKING:
import os
from ansible.module_utils.basic import AnsibleModule
from .account import ACMEAccount
from .backends import CertificateInformation, CryptoBackend
# -1 usually means connection problems
RETRY_STATUS_CODES = (-1, 408, 429, 503)
RETRY_COUNT = 10
def _decode_retry(module, response, info, retry_count):
def _decode_retry(
module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
) -> bool:
if info["status"] not in RETRY_STATUS_CODES:
return False
@@ -61,7 +73,8 @@ def _decode_retry(module, response, info, retry_count):
# 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
try:
retry_after = min(max(1, int(info.get("retry-after"))), 60)
# TODO: use utils.parse_retry_after()
retry_after = min(max(1, int(info.get("retry-after", "10"))), 60)
except (TypeError, ValueError):
retry_after = 10
module.log(
@@ -73,13 +86,13 @@ def _decode_retry(module, response, info, retry_count):
def _assert_fetch_url_success(
module,
response,
info,
allow_redirect=False,
allow_client_error=True,
allow_server_error=True,
):
module: AnsibleModule,
response: t.Any,
info: dict[str, t.Any],
allow_redirect: bool = False,
allow_client_error: bool = True,
allow_server_error: bool = True,
) -> None:
if info["status"] < 0:
raise NetworkException(msg=f"Failure downloading {info['url']}, {info['msg']}")
@@ -91,7 +104,9 @@ def _assert_fetch_url_success(
raise ACMEProtocolException(module, info=info, response=response)
def _is_failed(info, expected_status_codes=None):
def _is_failed(
info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
) -> bool:
if info["status"] < 200 or info["status"] >= 400:
return True
if (
@@ -111,12 +126,12 @@ class ACMEDirectory:
https://tools.ietf.org/html/rfc8555#section-7.1.1
"""
def __init__(self, module, account):
def __init__(self, module: AnsibleModule, client: ACMEClient) -> None:
self.module = module
self.directory_root = module.params["acme_directory"]
self.version = module.params["acme_version"]
self.directory, dummy = account.get_request(self.directory_root, get_only=True)
self.directory, dummy = client.get_request(self.directory_root, get_only=True)
self.request_timeout = module.params["request_timeout"]
@@ -131,16 +146,16 @@ class ACMEDirectory:
if "meta" not in self.directory:
self.directory["meta"] = {}
def __getitem__(self, key):
def __getitem__(self, key: str) -> t.Any:
return self.directory[key]
def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self.directory
def get(self, key, default_value=None):
def get(self, key: str, default_value: t.Any = None) -> t.Any:
return self.directory.get(key, default_value)
def get_nonce(self, resource=None):
def get_nonce(self, resource: str | None = None) -> str:
url = self.directory["newNonce"]
if resource is not None:
url = resource
@@ -170,7 +185,7 @@ class ACMEDirectory:
)
retry_count += 1
def has_renewal_info_endpoint(self):
def has_renewal_info_endpoint(self) -> bool:
return "renewalInfo" in self.directory
@@ -180,7 +195,7 @@ class ACMEClient:
ACME server.
"""
def __init__(self, module, backend):
def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None:
# Set to true to enable logging of all signed requests
self._debug = False
@@ -221,16 +236,22 @@ class ACMEClient:
self.directory = ACMEDirectory(module, self)
def set_account_uri(self, uri):
def set_account_uri(self, uri: str) -> None:
"""
Set account URI. For ACME v2, it needs to be used to sending signed
requests.
"""
self.account_uri = uri
self.account_jws_header.pop("jwk")
self.account_jws_header["kid"] = self.account_uri
if self.account_jws_header:
self.account_jws_header.pop("jwk", None)
self.account_jws_header["kid"] = self.account_uri
def parse_key(self, key_file=None, key_content=None, passphrase=None):
def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
In case of an error, raises KeyParsingError.
@@ -239,7 +260,13 @@ class ACMEClient:
raise AssertionError("One of key_file and key_content must be specified!")
return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
def sign_request(self, protected, payload, key_data, encode_payload=True):
def sign_request(
self,
protected: dict[str, t.Any],
payload: str | dict[str, t.Any] | None,
key_data: dict[str, t.Any],
encode_payload: bool = True,
) -> dict[str, t.Any]:
"""
Signs an ACME request.
"""
@@ -260,7 +287,7 @@ class ACMEClient:
return self.backend.sign(payload64, protected64, key_data)
def _log(self, msg, data=None):
def _log(self, msg: str, data: t.Any = None) -> None:
"""
Write arguments to acme.log when logging is enabled.
"""
@@ -275,18 +302,49 @@ class ACMEClient:
)
)
@t.overload
def send_signed_request(
self,
url,
payload,
key_data=None,
jws_header=None,
parse_json_result=True,
encode_payload=True,
fail_on_error=True,
error_msg=None,
expected_status_codes=None,
):
url: str,
payload: str | dict[str, t.Any] | None,
*,
key_data: dict[str, t.Any] | None = None,
jws_header: dict[str, t.Any] | None = None,
parse_json_result: t.Literal[True] = True,
encode_payload: bool = True,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
@t.overload
def send_signed_request(
self,
url: str,
payload: str | dict[str, t.Any] | None,
*,
key_data: dict[str, t.Any] | None = None,
jws_header: dict[str, t.Any] | None = None,
parse_json_result: t.Literal[False],
encode_payload: bool = True,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[bytes, dict[str, t.Any]]: ...
def send_signed_request(
self,
url: str,
payload: str | dict[str, t.Any] | None,
*,
key_data: dict[str, t.Any] | None = None,
jws_header: dict[str, t.Any] | None = None,
parse_json_result: bool = True,
encode_payload: bool = True,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
"""
Sends a JWS signed HTTP POST request to the ACME server and returns
the response as dictionary (if parse_json_result is True) or in raw form
@@ -297,7 +355,11 @@ class ACMEClient:
(https://tools.ietf.org/html/rfc8555#section-6.3)
"""
key_data = key_data or self.account_key_data
if key_data is None:
raise ModuleFailException("Missing key data")
jws_header = jws_header or self.account_jws_header
if jws_header is None:
raise ModuleFailException("Missing JWS header")
failed_tries = 0
while True:
protected = copy.deepcopy(jws_header)
@@ -382,16 +444,43 @@ class ACMEClient:
)
return result, info
@t.overload
def get_request(
self,
uri,
parse_json_result=True,
headers=None,
get_only=False,
fail_on_error=True,
error_msg=None,
expected_status_codes=None,
):
uri: str,
*,
parse_json_result: t.Literal[True] = True,
headers: dict[str, str] | None = None,
get_only: bool = False,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
@t.overload
def get_request(
self,
uri: str,
*,
parse_json_result: t.Literal[False],
headers: dict[str, str] | None = None,
get_only: bool = False,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[bytes, dict[str, t.Any]]: ...
def get_request(
self,
uri: str,
*,
parse_json_result: bool = True,
headers: dict[str, str] | None = None,
get_only: bool = False,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
"""
Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback
to GET if server replies with a status code of 405.
@@ -436,6 +525,7 @@ class ACMEClient:
# Process result
parsed_json_result = False
result: dict[str, t.Any] | bytes
if parse_json_result:
result = {}
if content:
@@ -445,7 +535,7 @@ class ACMEClient:
parsed_json_result = True
except ValueError:
raise NetworkException(
f"Failed to parse the ACME response: {uri} {content}"
f"Failed to parse the ACME response: {uri} {content!r}"
)
else:
result = content
@@ -460,19 +550,21 @@ class ACMEClient:
msg=error_msg,
info=info,
content=content,
content_json=result if parsed_json_result else None,
content_json=(
t.cast(dict[str, t.Any], result) if parsed_json_result else None
),
)
return result, info
def get_renewal_info(
self,
cert_id=None,
cert_info=None,
cert_filename=None,
cert_content=None,
include_retry_after=False,
retry_after_relative_with_timezone=True,
):
cert_id: str | None = None,
cert_info: CertificateInformation | None = None,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
include_retry_after: bool = False,
retry_after_relative_with_timezone: bool = True,
) -> dict[str, t.Any]:
if not self.directory.has_renewal_info_endpoint():
raise ModuleFailException(
"The ACME endpoint does not support ACME Renewal Information retrieval"
@@ -504,10 +596,10 @@ class ACMEClient:
def create_default_argspec(
with_account=True,
require_account_key=True,
with_certificate=False,
):
with_account: bool = True,
require_account_key: bool = True,
with_certificate: bool = False,
) -> ArgumentSpec:
"""
Provides default argument spec for the options documented in the acme doc fragment.
"""
@@ -544,7 +636,7 @@ def create_default_argspec(
return result
def create_backend(module, needs_acme_v2=True):
def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend:
backend = module.params["select_crypto_backend"]
# Backend autodetect
@@ -552,6 +644,7 @@ def create_backend(module, needs_acme_v2=True):
backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl"
# Create backend object
module_backend: CryptoBackend
if backend == "cryptography":
if CRYPTOGRAPHY_ERROR is not None:
# Either we could not import cryptography at all, or there was an unexpected error

View File

@@ -9,6 +9,7 @@ import base64
import binascii
import os
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
@@ -75,10 +76,19 @@ else:
CRYPTOGRAPHY_MINIMAL_VERSION
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from .certificates import CertificateChain, Criterium
class CryptographyChainMatcher(ChainMatcher):
@staticmethod
def _parse_key_identifier(key_identifier, name, criterium_idx, module):
def _parse_key_identifier(
key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule
) -> bytes | None:
if key_identifier:
try:
return binascii.unhexlify(key_identifier.replace(":", ""))
@@ -94,11 +104,11 @@ class CryptographyChainMatcher(ChainMatcher):
)
return None
def __init__(self, criterium, module):
def __init__(self, criterium: Criterium, module: AnsibleModule) -> None:
self.criterium = criterium
self.test_certificates = criterium.test_certificates
self.subject = []
self.issuer = []
self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
self.issuer: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
if criterium.subject:
self.subject = [
(cryptography_name_to_oid(k), to_native(v))
@@ -121,8 +131,13 @@ class CryptographyChainMatcher(ChainMatcher):
criterium.index,
module,
)
self.module = module
def _match_subject(self, x509_subject, match_subject):
def _match_subject(
self,
x509_subject: cryptography.x509.Name,
match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]],
) -> bool:
for oid, value in match_subject:
found = False
for attribute in x509_subject:
@@ -133,7 +148,7 @@ class CryptographyChainMatcher(ChainMatcher):
return False
return True
def match(self, certificate):
def match(self, certificate: CertificateChain) -> bool:
"""
Check whether an alternate chain matches the specified criterium.
"""
@@ -152,19 +167,22 @@ class CryptographyChainMatcher(ChainMatcher):
matches = False
if self.subject_key_identifier:
try:
ext = x509.extensions.get_extension_for_class(
ext_ski = x509.extensions.get_extension_for_class(
cryptography.x509.SubjectKeyIdentifier
)
if self.subject_key_identifier != ext.value.digest:
if self.subject_key_identifier != ext_ski.value.digest:
matches = False
except cryptography.x509.ExtensionNotFound:
matches = False
if self.authority_key_identifier:
try:
ext = x509.extensions.get_extension_for_class(
ext_aki = x509.extensions.get_extension_for_class(
cryptography.x509.AuthorityKeyIdentifier
)
if self.authority_key_identifier != ext.value.key_identifier:
if (
self.authority_key_identifier
!= ext_aki.value.key_identifier
):
matches = False
except cryptography.x509.ExtensionNotFound:
matches = False
@@ -176,59 +194,68 @@ class CryptographyChainMatcher(ChainMatcher):
class CryptographyBackend(CryptoBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CryptographyBackend, self).__init__(
module, with_timezone=CRYPTOGRAPHY_TIMEZONE
)
def parse_key(self, key_file=None, key_content=None, passphrase=None):
def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors.
"""
# If key_content is not given, read key_file
if key_content is None:
key_content = read_file(key_file)
if key_file is None:
raise KeyParsingError(
"one of key_file and key_content must be specified"
)
b_key_content = read_file(key_file)
else:
key_content = to_bytes(key_content)
b_key_content = to_bytes(key_content)
# Parse key
try:
key = cryptography.hazmat.primitives.serialization.load_pem_private_key(
key_content,
b_key_content,
password=to_bytes(passphrase) if passphrase is not None else None,
)
except Exception as e:
raise KeyParsingError(f"error while loading key: {e}")
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
pk = key.public_key().public_numbers()
rsa_pk = key.public_key().public_numbers()
return {
"key_obj": key,
"type": "rsa",
"alg": "RS256",
"jwk": {
"kty": "RSA",
"e": nopad_b64(convert_int_to_bytes(pk.e)),
"n": nopad_b64(convert_int_to_bytes(pk.n)),
"e": nopad_b64(convert_int_to_bytes(rsa_pk.e)),
"n": nopad_b64(convert_int_to_bytes(rsa_pk.n)),
},
"hash": "sha256",
}
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
pk = key.public_key().public_numbers()
if pk.curve.name == "secp256r1":
ec_pk = key.public_key().public_numbers()
if ec_pk.curve.name == "secp256r1":
bits = 256
alg = "ES256"
hashalg = "sha256"
point_size = 32
curve = "P-256"
elif pk.curve.name == "secp384r1":
elif ec_pk.curve.name == "secp384r1":
bits = 384
alg = "ES384"
hashalg = "sha384"
point_size = 48
curve = "P-384"
elif pk.curve.name == "secp521r1":
elif ec_pk.curve.name == "secp521r1":
# Not yet supported on Let's Encrypt side, see
# https://github.com/letsencrypt/boulder/issues/2217
bits = 521
@@ -237,7 +264,7 @@ class CryptographyBackend(CryptoBackend):
point_size = 66
curve = "P-521"
else:
raise KeyParsingError(f"unknown elliptic curve: {pk.curve.name}")
raise KeyParsingError(f"unknown elliptic curve: {ec_pk.curve.name}")
num_bytes = (bits + 7) // 8
return {
"key_obj": key,
@@ -246,8 +273,8 @@ class CryptographyBackend(CryptoBackend):
"jwk": {
"kty": "EC",
"crv": curve,
"x": nopad_b64(convert_int_to_bytes(pk.x, count=num_bytes)),
"y": nopad_b64(convert_int_to_bytes(pk.y, count=num_bytes)),
"x": nopad_b64(convert_int_to_bytes(ec_pk.x, count=num_bytes)),
"y": nopad_b64(convert_int_to_bytes(ec_pk.y, count=num_bytes)),
},
"hash": hashalg,
"point_size": point_size,
@@ -255,8 +282,11 @@ class CryptographyBackend(CryptoBackend):
else:
raise KeyParsingError(f'unknown key type "{type(key)}"')
def sign(self, payload64, protected64, key_data):
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
sign_payload = f"{protected64}.{payload64}".encode("utf8")
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
if "mac_obj" in key_data:
mac = key_data["mac_obj"]()
mac.update(sign_payload)
@@ -292,8 +322,9 @@ class CryptographyBackend(CryptoBackend):
"signature": nopad_b64(signature),
}
def create_mac_key(self, alg, key):
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key."""
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
if alg == "HS256":
hashalg = cryptography.hazmat.primitives.hashes.SHA256
hashbytes = 32
@@ -324,7 +355,11 @@ class CryptographyBackend(CryptoBackend):
},
}
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
"""
Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
@@ -334,15 +369,19 @@ class CryptographyBackend(CryptoBackend):
as the first element in the result.
"""
if csr_content is None:
csr_content = read_file(csr_filename)
if csr_filename is None:
raise BackendException(
"One of csr_content and csr_filename has to be provided"
)
b_csr_content = read_file(csr_filename)
else:
csr_content = to_bytes(csr_content)
csr = cryptography.x509.load_pem_x509_csr(csr_content)
b_csr_content = to_bytes(csr_content)
csr = cryptography.x509.load_pem_x509_csr(b_csr_content)
identifiers = set()
result = []
def add_identifier(identifier):
def add_identifier(identifier: tuple[str, str]) -> None:
if identifier in identifiers:
return
identifiers.add(identifier)
@@ -350,7 +389,7 @@ class CryptographyBackend(CryptoBackend):
for sub in csr.subject:
if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME:
add_identifier(("dns", sub.value))
add_identifier(("dns", t.cast(str, sub.value)))
for extension in csr.extensions:
if (
extension.oid
@@ -367,7 +406,11 @@ class CryptographyBackend(CryptoBackend):
)
return result
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | bytes | None = None,
) -> set[tuple[str, str]]:
"""
Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
@@ -379,7 +422,12 @@ class CryptographyBackend(CryptoBackend):
)
)
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> int:
"""
Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one
@@ -398,10 +446,10 @@ class CryptographyBackend(CryptoBackend):
return -1
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
try:
cert = cryptography.x509.load_pem_x509_certificate(cert_content)
cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
except Exception as e:
if cert_filename is None:
raise BackendException(f"Cannot parse certificate: {e}")
@@ -413,13 +461,17 @@ class CryptographyBackend(CryptoBackend):
now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE)
return (get_not_valid_after(cert) - now).days
def create_chain_matcher(self, criterium):
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
"""
Given a Criterium object, creates a ChainMatcher object.
"""
return CryptographyChainMatcher(criterium, self.module)
def get_cert_information(self, cert_filename=None, cert_content=None):
def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
"""
Return some information on a X.509 certificate as a CertificateInformation object.
"""
@@ -429,10 +481,10 @@ class CryptographyBackend(CryptoBackend):
cert_content = to_bytes(cert_content)
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
try:
cert = cryptography.x509.load_pem_x509_certificate(cert_content)
cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
except Exception as e:
if cert_filename is None:
raise BackendException(f"Cannot parse certificate: {e}")
@@ -440,19 +492,19 @@ class CryptographyBackend(CryptoBackend):
ski = None
try:
ext = cert.extensions.get_extension_for_class(
ext_ski = cert.extensions.get_extension_for_class(
cryptography.x509.SubjectKeyIdentifier
)
ski = ext.value.digest
ski = ext_ski.value.digest
except cryptography.x509.ExtensionNotFound:
pass
aki = None
try:
ext = cert.extensions.get_extension_for_class(
ext_aki = cert.extensions.get_extension_for_class(
cryptography.x509.AuthorityKeyIdentifier
)
aki = ext.value.key_identifier
aki = ext_aki.value.key_identifier
except cryptography.x509.ExtensionNotFound:
pass

View File

@@ -13,6 +13,7 @@ import os
import re
import tempfile
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
@@ -34,12 +35,23 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from .certificates import Criterium
_OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C")
def _extract_date(out_text, name, cert_filename_suffix=""):
def _extract_date(
out_text: str, name: str, cert_filename_suffix: str = ""
) -> datetime.datetime:
matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text)
if matcher is None:
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
date_str = matcher.group(1)
try:
date_str = re.search(rf"\s+{name}\s*:\s+(.*)", out_text).group(1)
# For some reason Python's strptime() does not return any timezone information,
# even though the information is there and a supported timezone for all supported
# Python implementations (GMT). So we have to modify the datetime object by
@@ -47,19 +59,40 @@ def _extract_date(out_text, name, cert_filename_suffix=""):
return ensure_utc_timezone(
datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
)
except AttributeError:
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
except ValueError as exc:
raise BackendException(
f"Failed to parse '{name}' date{cert_filename_suffix}: {exc}"
)
def _decode_octets(octets_text):
def _decode_octets(octets_text: str) -> bytes:
return binascii.unhexlify(re.sub(r"(\s|:)", "", octets_text).encode("utf-8"))
def _extract_octets(out_text, name, required=True, potential_prefixes=None):
@t.overload
def _extract_octets(
out_text: str,
name: str,
required: t.Literal[False],
potential_prefixes: t.Iterable[str] | None = None,
) -> bytes | None: ...
@t.overload
def _extract_octets(
out_text: str,
name: str,
required: t.Literal[True],
potential_prefixes: t.Iterable[str] | None = None,
) -> bytes: ...
def _extract_octets(
out_text: str,
name: str,
required: bool = True,
potential_prefixes: t.Iterable[str] | None = None,
) -> bytes | None:
part = (
f"(?:{'|'.join(re.escape(pp) for pp in potential_prefixes)})"
if potential_prefixes
@@ -75,13 +108,20 @@ def _extract_octets(out_text, name, required=True, potential_prefixes=None):
class OpenSSLCLIBackend(CryptoBackend):
def __init__(self, module, openssl_binary=None):
def __init__(
self, module: AnsibleModule, openssl_binary: str | None = None
) -> None:
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
if openssl_binary is None:
openssl_binary = module.get_bin_path("openssl", True)
self.openssl_binary = openssl_binary
def parse_key(self, key_file=None, key_content=None, passphrase=None):
def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors.
@@ -90,6 +130,10 @@ class OpenSSLCLIBackend(CryptoBackend):
raise KeyParsingError("openssl backend does not support key passphrases")
# If key_file is not given, but key_content, write that to a temporary file
if key_file is None:
if key_content is None:
raise KeyParsingError(
"one of key_file and key_content must be specified"
)
fd, tmpsrc = tempfile.mkstemp()
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
f = os.fdopen(fd, "wb")
@@ -108,8 +152,8 @@ class OpenSSLCLIBackend(CryptoBackend):
f.close()
# Parse key
account_key_type = None
with open(key_file, "rt") as f:
for line in f:
with open(key_file, "rt") as fi:
for line in fi:
m = re.match(
r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line
)
@@ -129,38 +173,44 @@ class OpenSSLCLIBackend(CryptoBackend):
self.openssl_binary,
account_key_type,
"-in",
key_file,
str(key_file),
"-noout",
"-text",
]
rc, out, err = self.module.run_command(
rc, out, stderr = self.module.run_command(
openssl_keydump_cmd,
check_rc=False,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
if rc != 0:
raise BackendException(
f"Error while running {' '.join(openssl_keydump_cmd)}: {err}"
f"Error while running {' '.join(openssl_keydump_cmd)}: {stderr}"
)
out_text = to_text(out, errors="surrogate_or_strict")
if account_key_type == "rsa":
pub_hex = re.search(
matcher = re.search(
r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent",
out_text,
re.MULTILINE | re.DOTALL,
).group(1)
)
if matcher is None:
raise KeyParsingError("cannot parse RSA key: modulus not found")
pub_hex = matcher.group(1)
pub_exp = re.search(
matcher = re.search(
r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL
).group(1)
)
if matcher is None:
raise KeyParsingError("cannot parse RSA key: public exponent not found")
pub_exp = matcher.group(1)
pub_exp = f"{int(pub_exp):x}"
if len(pub_exp) % 2:
pub_exp = f"0{pub_exp}"
return {
"key_file": key_file,
"key_file": str(key_file),
"type": "rsa",
"alg": "RS256",
"jwk": {
@@ -223,8 +273,13 @@ class OpenSSLCLIBackend(CryptoBackend):
"hash": hashalg,
"point_size": point_size,
}
raise KeyParsingError(
f"Internal error: unexpected account_key_type = {account_key_type!r}"
)
def sign(self, payload64, protected64, key_data):
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
sign_payload = f"{protected64}.{payload64}".encode("utf8")
if key_data["type"] == "hmac":
hex_key = (
@@ -284,7 +339,7 @@ class OpenSSLCLIBackend(CryptoBackend):
"signature": nopad_b64(to_bytes(out)),
}
def create_mac_key(self, alg, key):
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key."""
if alg == "HS256":
hashalg = "sha256"
@@ -315,14 +370,18 @@ class OpenSSLCLIBackend(CryptoBackend):
}
@staticmethod
def _normalize_ip(ip):
def _normalize_ip(ip: str) -> str:
try:
return ipaddress.ip_address(to_text(ip)).compressed
return ipaddress.ip_address(ip).compressed
except ValueError:
# We do not want to error out on something IPAddress() cannot parse
return ip
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
"""
Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
@@ -335,13 +394,13 @@ class OpenSSLCLIBackend(CryptoBackend):
data = None
if csr_content is not None:
filename = "/dev/stdin"
data = csr_content.encode("utf-8")
data = to_bytes(csr_content)
openssl_csr_cmd = [
self.openssl_binary,
"req",
"-in",
filename,
str(filename),
"-noout",
"-text",
]
@@ -360,7 +419,7 @@ class OpenSSLCLIBackend(CryptoBackend):
identifiers = set()
result = []
def add_identifier(identifier):
def add_identifier(identifier: tuple[str, str]) -> None:
if identifier in identifiers:
return
identifiers.add(identifier)
@@ -389,7 +448,11 @@ class OpenSSLCLIBackend(CryptoBackend):
raise BackendException(f'Found unsupported SAN identifier "{san}"')
return result
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> set[tuple[str, str]]:
"""
Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
@@ -401,7 +464,12 @@ class OpenSSLCLIBackend(CryptoBackend):
)
)
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> int:
"""
Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one
@@ -413,7 +481,7 @@ class OpenSSLCLIBackend(CryptoBackend):
data = None
if cert_content is not None:
filename = "/dev/stdin"
data = cert_content.encode("utf-8")
data = to_bytes(cert_content)
cert_filename_suffix = ""
elif cert_filename is not None:
if not os.path.exists(cert_filename):
@@ -426,7 +494,7 @@ class OpenSSLCLIBackend(CryptoBackend):
self.openssl_binary,
"x509",
"-in",
filename,
str(filename),
"-noout",
"-text",
]
@@ -452,7 +520,7 @@ class OpenSSLCLIBackend(CryptoBackend):
now = ensure_utc_timezone(now)
return (not_after - now).days
def create_chain_matcher(self, criterium):
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
"""
Given a Criterium object, creates a ChainMatcher object.
"""
@@ -460,7 +528,11 @@ class OpenSSLCLIBackend(CryptoBackend):
'Alternate chain matching can only be used with the "cryptography" backend.'
)
def get_cert_information(self, cert_filename=None, cert_content=None):
def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
"""
Return some information on a X.509 certificate as a CertificateInformation object.
"""
@@ -477,7 +549,7 @@ class OpenSSLCLIBackend(CryptoBackend):
self.openssl_binary,
"x509",
"-in",
filename,
str(filename),
"-noout",
"-text",
]

View File

@@ -8,7 +8,7 @@ from __future__ import annotations
import abc
import datetime
import re
from collections import namedtuple
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
BackendException,
@@ -27,16 +27,20 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
CertificateInformation = namedtuple(
"CertificateInformation",
(
"not_valid_after",
"not_valid_before",
"serial_number",
"subject_key_identifier",
"authority_key_identifier",
),
)
if t.TYPE_CHECKING:
import os
from ansible.module_utils.basic import AnsibleModule
from .certificates import ChainMatcher, Criterium
class CertificateInformation(t.NamedTuple):
not_valid_after: datetime.datetime
not_valid_before: datetime.datetime
serial_number: int
subject_key_identifier: bytes | None
authority_key_identifier: bytes | None
_FRACTIONAL_MATCHER = re.compile(
@@ -44,7 +48,7 @@ _FRACTIONAL_MATCHER = re.compile(
)
def _reduce_fractional_digits(timestamp_str):
def _reduce_fractional_digits(timestamp_str: str) -> str:
"""
Given a RFC 3339 timestamp that includes too many digits for the fractional seconds part, reduces these to at most 6.
"""
@@ -60,7 +64,7 @@ def _reduce_fractional_digits(timestamp_str):
return f"{timestamp}{fractional}{timezone}"
def _parse_acme_timestamp(timestamp_str, with_timezone):
def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime:
"""
Parses a RFC 3339 timestamp.
"""
@@ -86,34 +90,42 @@ def _parse_acme_timestamp(timestamp_str, with_timezone):
class CryptoBackend(metaclass=abc.ABCMeta):
def __init__(self, module, with_timezone=False):
def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None:
self.module = module
self._with_timezone = with_timezone
def get_now(self):
def get_now(self) -> datetime.datetime:
return get_now_datetime(with_timezone=self._with_timezone)
def parse_acme_timestamp(self, timestamp_str):
def parse_acme_timestamp(self, timestamp_str: str) -> datetime.datetime:
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone)
def parse_module_parameter(self, value, name):
def parse_module_parameter(self, value: str, name: str) -> datetime.datetime:
try:
return get_relative_time_option(
result = get_relative_time_option(
value, name, with_timezone=self._with_timezone
)
if result is None:
raise BackendException(f"Invalid value for {name}: {value!r}")
return result
except OpenSSLObjectError as exc:
raise BackendException(str(exc))
def interpolate_timestamp(self, timestamp_start, timestamp_end, percentage):
def interpolate_timestamp(
self,
timestamp_start: datetime.datetime,
timestamp_end: datetime.datetime,
percentage: float,
) -> datetime.datetime:
start = get_epoch_seconds(timestamp_start)
end = get_epoch_seconds(timestamp_end)
return from_epoch_seconds(
start + percentage * (end - start), with_timezone=self._with_timezone
)
def get_utc_datetime(self, *args, **kwargs):
kwargs_ext = dict(kwargs)
def get_utc_datetime(self, *args, **kwargs) -> datetime.datetime:
kwargs_ext: dict[str, t.Any] = dict(kwargs)
if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8):
kwargs_ext["tzinfo"] = UTC
result = datetime.datetime(*args, **kwargs_ext)
@@ -122,22 +134,33 @@ class CryptoBackend(metaclass=abc.ABCMeta):
return result
@abc.abstractmethod
def parse_key(self, key_file=None, key_content=None, passphrase=None):
def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors.
"""
@abc.abstractmethod
def sign(self, payload64, protected64, key_data):
def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
pass
@abc.abstractmethod
def create_mac_key(self, alg, key):
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key."""
@abc.abstractmethod
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
"""
Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
@@ -148,7 +171,11 @@ class CryptoBackend(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> set[tuple[str, str]]:
"""
Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
@@ -156,7 +183,12 @@ class CryptoBackend(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> int:
"""
Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one
@@ -166,13 +198,17 @@ class CryptoBackend(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def create_chain_matcher(self, criterium):
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
"""
Given a Criterium object, creates a ChainMatcher object.
"""
@abc.abstractmethod
def get_cert_information(self, cert_filename=None, cert_content=None):
def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
"""
Return some information on a X.509 certificate as a CertificateInformation object.
"""

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
@@ -30,6 +31,14 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from .backends import CryptoBackend
from .certificates import ChainMatcher
from .challenges import Challenge
class ACMECertificateClient:
"""
ACME v2 client class. Uses an ACME account object and a CSR to
@@ -37,7 +46,13 @@ class ACMECertificateClient:
certificates.
"""
def __init__(self, module, backend, client=None, account=None):
def __init__(
self,
module: AnsibleModule,
backend: CryptoBackend,
client: ACMEClient | None = None,
account: ACMEAccount | None = None,
) -> None:
self.module = module
self.version = module.params["acme_version"]
self.csr = module.params.get("csr")
@@ -66,13 +81,17 @@ class ACMECertificateClient:
# Extract list of identifiers from CSR
if self.csr is not None or self.csr_content is not None:
self.identifiers = self.client.backend.get_ordered_csr_identifiers(
csr_filename=self.csr, csr_content=self.csr_content
self.identifiers: list[tuple[str, str]] | None = (
self.client.backend.get_ordered_csr_identifiers(
csr_filename=self.csr, csr_content=self.csr_content
)
)
else:
self.identifiers = None
def parse_select_chain(self, select_chain):
def parse_select_chain(
self, select_chain: list[dict[str, t.Any]] | None
) -> list[ChainMatcher]:
select_chain_matcher = []
if select_chain:
for criterium_idx, criterium in enumerate(select_chain):
@@ -88,14 +107,16 @@ class ACMECertificateClient:
)
return select_chain_matcher
def load_order(self):
def load_order(self) -> Order:
if not self.order_uri:
raise ModuleFailException("The order URI has not been provided")
order = Order.from_url(self.client, self.order_uri)
order.load_authorizations(self.client)
return order
def create_order(self, replaces_cert_id=None, profile=None):
def create_order(
self, replaces_cert_id: str | None = None, profile: str | None = None
) -> Order:
"""
Create a new order.
"""
@@ -114,31 +135,31 @@ class ACMECertificateClient:
order.load_authorizations(self.client)
return order
def get_challenges_data(self, order):
def get_challenges_data(
self, order: Order
) -> tuple[list[dict[str, t.Any]], dict[str, list[str]]]:
"""
Get challenge details.
Return a tuple of generic challenge details, and specialized DNS challenge details.
"""
# Get general challenge data
data = []
data: list[dict[str, t.Any]] = []
data_dns: dict[str, list[str]] = {}
dns_challenge_type = "dns-01"
for authz in order.authorizations.values():
# Skip valid authentications: their challenges are already valid
# and do not need to be returned
if authz.status == "valid":
continue
challenge_data = authz.get_challenge_data(self.client)
data.append(
dict(
identifier=authz.identifier,
identifier_type=authz.identifier_type,
challenges=authz.get_challenge_data(self.client),
challenges=challenge_data,
)
)
# Get DNS challenge data
data_dns = {}
dns_challenge_type = "dns-01"
for entry in data:
dns_challenge = entry["challenges"].get(dns_challenge_type)
dns_challenge = challenge_data.get(dns_challenge_type)
if dns_challenge:
values = data_dns.get(dns_challenge["record"])
if values is None:
@@ -147,7 +168,7 @@ class ACMECertificateClient:
values.append(dns_challenge["resource_value"])
return data, data_dns
def check_that_authorizations_can_be_used(self, order):
def check_that_authorizations_can_be_used(self, order: Order) -> None:
bad_authzs = []
for authz in order.authorizations.values():
if authz.status not in ("valid", "pending"):
@@ -155,27 +176,32 @@ class ACMECertificateClient:
f"{authz.combined_identifier} (status={authz.status!r})"
)
if bad_authzs:
bad_authzs = ", ".join(sorted(bad_authzs))
bad_authzs_str = ", ".join(sorted(bad_authzs))
raise ModuleFailException(
"Some of the authorizations for the order are in a bad state, so the order"
f" can no longer be satisfied: {bad_authzs}",
f" can no longer be satisfied: {bad_authzs_str}",
)
def collect_invalid_authzs(self, order):
def collect_invalid_authzs(self, order: Order) -> list[Authorization]:
return [
authz
for authz in order.authorizations.values()
if authz.status == "invalid"
]
def collect_pending_authzs(self, order):
def collect_pending_authzs(self, order: Order) -> list[Authorization]:
return [
authz
for authz in order.authorizations.values()
if authz.status == "pending"
]
def call_validate(self, pending_authzs, get_challenge, wait=True):
def call_validate(
self,
pending_authzs: list[Authorization],
get_challenge: t.Callable[[Authorization], str],
wait: bool = True,
) -> list[tuple[Authorization, str, Challenge | None]]:
authzs_with_challenges_to_wait_for = []
for authz in pending_authzs:
challenge_type = get_challenge(authz)
@@ -185,10 +211,12 @@ class ACMECertificateClient:
)
return authzs_with_challenges_to_wait_for
def wait_for_validation(self, authzs_to_wait_for):
def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None:
wait_for_validation(authzs_to_wait_for, self.client)
def _download_alternate_chains(self, cert):
def _download_alternate_chains(
self, cert: CertificateChain
) -> list[CertificateChain]:
alternate_chains = []
for alternate in cert.alternates:
try:
@@ -206,13 +234,30 @@ class ACMECertificateClient:
)
return alternate_chains
def download_certificate(self, order, download_all_chains=True):
@t.overload
def download_certificate(
self, order: Order, *, download_all_chains: t.Literal[True] = True
) -> tuple[CertificateChain, list[CertificateChain]]: ...
@t.overload
def download_certificate(
self, order: Order, *, download_all_chains: t.Literal[False]
) -> tuple[CertificateChain, None]: ...
@t.overload
def download_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
def download_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]:
"""
Download certificate from a valid oder.
"""
if order.status != "valid":
raise ModuleFailException(
f"The order must be valid, but has state {order.state!r}!"
f"The order must be valid, but has state {order.status!r}!"
)
if not order.certificate_uri:
@@ -232,7 +277,24 @@ class ACMECertificateClient:
return cert, alternate_chains
def get_certificate(self, order, download_all_chains=True):
@t.overload
def get_certificate(
self, order: Order, *, download_all_chains: t.Literal[True] = True
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
@t.overload
def get_certificate(
self, order: Order, *, download_all_chains: t.Literal[False]
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
@t.overload
def get_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
def get_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]:
"""
Request a new certificate and downloads it, and optionally all certificate chains.
First verifies whether all authorizations are valid; if not, aborts with an error.
@@ -250,7 +312,11 @@ class ACMECertificateClient:
return self.download_certificate(order, download_all_chains=download_all_chains)
def find_matching_chain(self, chains, select_chain_matcher):
def find_matching_chain(
self,
chains: list[CertificateChain],
select_chain_matcher: t.Iterable[ChainMatcher],
) -> CertificateChain | None:
for criterium_idx, matcher in enumerate(select_chain_matcher):
for chain in chains:
if matcher.match(chain):
@@ -261,9 +327,15 @@ class ACMECertificateClient:
return None
def write_cert_chain(
self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None
):
self,
cert: CertificateChain,
cert_dest: str | os.PathLike | None = None,
fullchain_dest: str | os.PathLike | None = None,
chain_dest: str | os.PathLike | None = None,
) -> bool:
changed = False
if cert.cert is None:
raise ValueError("Certificate is not present")
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
changed = True
@@ -282,7 +354,7 @@ class ACMECertificateClient:
return changed
def deactivate_authzs(self, order):
def deactivate_authzs(self, order: Order) -> None:
"""
Deactivates all valid authz's. Does not raise exceptions.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ModuleFailException,
@@ -19,20 +20,29 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
)
if t.TYPE_CHECKING:
from .acme import ACMEClient
_CertificateChain = t.TypeVar("_CertificateChain", bound="CertificateChain")
class CertificateChain:
"""
Download and parse the certificate chain.
https://tools.ietf.org/html/rfc8555#section-7.4.2
"""
def __init__(self, url):
def __init__(self, url: str):
self.url = url
self.cert = None
self.chain = []
self.alternates = []
self.cert: str | None = None
self.chain: list[str] = []
self.alternates: list[str] = []
@classmethod
def download(cls, client, url):
def download(
cls: t.Type[_CertificateChain], client: ACMEClient, url: str
) -> _CertificateChain:
content, info = client.get_request(
url,
parse_json_result=False,
@@ -43,7 +53,7 @@ class CertificateChain:
"application/pem-certificate-chain"
):
raise ModuleFailException(
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content} (headers: {info})"
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content!r} (headers: {info})"
)
result = cls(url)
@@ -60,12 +70,12 @@ class CertificateChain:
if result.cert is None:
raise ModuleFailException(
f"Failed to parse certificate chain download from {url}: {content} (headers: {info})"
f"Failed to parse certificate chain download from {url}: {content!r} (headers: {info})"
)
return result
def _process_links(self, client, link, relation):
def _process_links(self, client: ACMEClient, link: str, relation: str) -> None:
if relation == "up":
# Process link-up headers if there was no chain in reply
if not self.chain:
@@ -77,7 +87,9 @@ class CertificateChain:
elif relation == "alternate":
self.alternates.append(link)
def to_json(self):
def to_json(self) -> dict[str, bytes]:
if self.cert is None:
raise ValueError("Has no certificate")
cert = self.cert.encode("utf8")
chain = ("\n".join(self.chain)).encode("utf8")
return {
@@ -88,18 +100,22 @@ class CertificateChain:
class Criterium:
def __init__(self, criterium, index=None):
def __init__(self, criterium: dict[str, t.Any], index: int):
self.index = index
self.test_certificates = criterium["test_certificates"]
self.subject = criterium["subject"]
self.issuer = criterium["issuer"]
self.subject_key_identifier = criterium["subject_key_identifier"]
self.authority_key_identifier = criterium["authority_key_identifier"]
self.test_certificates: t.Literal["first", "last", "all"] = criterium[
"test_certificates"
]
self.subject: dict[str, t.Any] | None = criterium["subject"]
self.issuer: dict[str, t.Any] | None = criterium["issuer"]
self.subject_key_identifier: str | None = criterium["subject_key_identifier"]
self.authority_key_identifier: str | None = criterium[
"authority_key_identifier"
]
class ChainMatcher(metaclass=abc.ABCMeta):
@abc.abstractmethod
def match(self, certificate):
def match(self, certificate: CertificateChain) -> bool:
"""
Check whether a certificate chain (CertificateChain instance) matches.
"""

View File

@@ -11,6 +11,7 @@ import ipaddress
import json
import re
import time
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
@@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def create_key_authorization(client, token):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from .acme import ACMEClient
def create_key_authorization(client: ACMEClient, token: str) -> str:
"""
Returns the key authorization for the given token
https://tools.ietf.org/html/rfc8555#section-8.1
@@ -35,41 +42,49 @@ def create_key_authorization(client, token):
return f"{token}.{thumbprint}"
def combine_identifier(identifier_type, identifier):
def combine_identifier(identifier_type: str, identifier: str) -> str:
return f"{identifier_type}:{identifier}"
def normalize_combined_identifier(identifier):
def normalize_combined_identifier(identifier: str) -> str:
identifier_type, identifier = split_identifier(identifier)
# Normalize DNS names and IPs
identifier = identifier.lower()
return combine_identifier(identifier_type, identifier)
def split_identifier(identifier):
def split_identifier(identifier: str) -> tuple[str, str]:
parts = identifier.split(":", 1)
if len(parts) != 2:
raise ModuleFailException(
f'Identifier "{identifier}" is not of the form <type>:<identifier>'
)
return parts
return parts[0], parts[1]
_Challenge = t.TypeVar("_Challenge", bound="Challenge")
class Challenge:
def __init__(self, data, url):
def __init__(self, data: dict[str, t.Any], url: str) -> None:
self.data = data
self.type = data["type"]
self.type: str = data["type"]
self.url = url
self.status = data["status"]
self.token = data.get("token")
self.status: str = data["status"]
self.token: str | None = data.get("token")
@classmethod
def from_json(cls, client, data, url=None):
def from_json(
cls: t.Type[_Challenge],
client: ACMEClient,
data: dict[str, t.Any],
url: str | None = None,
) -> _Challenge:
return cls(data, url or data["url"])
def call_validate(self, client):
challenge_response = {}
def call_validate(self, client: ACMEClient) -> None:
challenge_response: dict[str, t.Any] = {}
client.send_signed_request(
self.url,
challenge_response,
@@ -77,10 +92,15 @@ class Challenge:
expected_status_codes=[200, 202],
)
def to_json(self):
def to_json(self) -> dict[str, t.Any]:
return self.data.copy()
def get_validation_data(self, client, identifier_type, identifier):
def get_validation_data(
self, client: ACMEClient, identifier_type: str, identifier: str
) -> dict[str, t.Any] | None:
if self.token is None:
return None
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
key_authorization = create_key_authorization(client, token)
@@ -113,21 +133,33 @@ class Challenge:
resource += "."
else:
resource = identifier
value = base64.b64encode(
b_value = base64.b64encode(
hashlib.sha256(to_bytes(key_authorization)).digest()
)
return {
"resource": resource,
"resource_original": combine_identifier(identifier_type, identifier),
"resource_value": value,
"resource_value": b_value,
}
# Unknown challenge type: ignore
return None
_Authorization = t.TypeVar("_Authorization", bound="Authorization")
class Authorization:
def _setup(self, client, data):
def __init__(self, url: str) -> None:
self.url = url
self.data: dict[str, t.Any] | None = None
self.challenges: list[Challenge] = []
self.status: str | None = None
self.identifier_type: str | None = None
self.identifier: str | None = None
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
data["uri"] = self.url
self.data = data
# While 'challenges' is a required field, apparently not every CA cares
@@ -145,29 +177,32 @@ class Authorization:
if data.get("wildcard", False):
self.identifier = f"*.{self.identifier}"
def __init__(self, url):
self.url = url
self.data = None
self.challenges = []
self.status = None
self.identifier_type = None
self.identifier = None
@classmethod
def from_json(cls, client, data, url):
def from_json(
cls: t.Type[_Authorization],
client: ACMEClient,
data: dict[str, t.Any],
url: str,
) -> _Authorization:
result = cls(url)
result._setup(client, data)
return result
@classmethod
def from_url(cls, client, url):
def from_url(
cls: t.Type[_Authorization], client: ACMEClient, url: str
) -> _Authorization:
result = cls(url)
result.refresh(client)
return result
@classmethod
def create(cls, client, identifier_type, identifier):
def create(
cls: t.Type[_Authorization],
client: ACMEClient,
identifier_type: str,
identifier: str,
) -> _Authorization:
"""
Create a new authorization for the given identifier.
Return the authorization object of the new authorization
@@ -194,23 +229,29 @@ class Authorization:
return cls.from_json(client, result, info["location"])
@property
def combined_identifier(self):
def combined_identifier(self) -> str:
if self.identifier_type is None or self.identifier is None:
raise ValueError("Data not present")
return combine_identifier(self.identifier_type, self.identifier)
def to_json(self):
def to_json(self) -> dict[str, t.Any]:
if self.data is None:
raise ValueError("Data not present")
return self.data.copy()
def refresh(self, client):
def refresh(self, client: ACMEClient) -> bool:
result, dummy = client.get_request(self.url)
changed = self.data != result
self._setup(client, result)
return changed
def get_challenge_data(self, client):
def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]:
"""
Returns a dict with the data for all proposed (and supported) challenges
of the given authorization.
"""
if self.identifier_type is None or self.identifier is None:
raise ValueError("Data not present")
data = {}
for challenge in self.challenges:
validation_data = challenge.get_validation_data(
@@ -220,7 +261,7 @@ class Authorization:
data[challenge.type] = validation_data
return data
def raise_error(self, error_msg, module=None):
def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn:
"""
Aborts with a specific error for a challenge.
"""
@@ -246,13 +287,13 @@ class Authorization:
),
)
def find_challenge(self, challenge_type):
def find_challenge(self, challenge_type: str) -> Challenge | None:
for challenge in self.challenges:
if challenge_type == challenge.type:
return challenge
return None
def wait_for_validation(self, client, callenge_type):
def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool:
while True:
self.refresh(client)
if self.status in ["valid", "invalid", "revoked"]:
@@ -264,7 +305,9 @@ class Authorization:
return self.status == "valid"
def call_validate(self, client, challenge_type, wait=True):
def call_validate(
self, client: ACMEClient, challenge_type: str, wait: bool = True
) -> bool:
"""
Validate the authorization provided in the auth dict. Returns True
when the validation was successful and False when it was not.
@@ -281,7 +324,7 @@ class Authorization:
return self.status == "valid"
return self.wait_for_validation(client, challenge_type)
def can_deactivate(self):
def can_deactivate(self) -> bool:
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
@@ -289,14 +332,14 @@ class Authorization:
"""
return self.status in ("valid", "pending")
def deactivate(self, client):
def deactivate(self, client: ACMEClient) -> bool | None:
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
"""
if not self.can_deactivate():
return
return None
authz_deactivate = {"status": "deactivated"}
result, info = client.send_signed_request(
self.url, authz_deactivate, fail_on_error=False
@@ -307,7 +350,9 @@ class Authorization:
return False
@classmethod
def deactivate_url(cls, client, url):
def deactivate_url(
cls: t.Type[_Authorization], client: ACMEClient, url: str
) -> _Authorization:
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
@@ -322,7 +367,7 @@ class Authorization:
return authz
def wait_for_validation(authzs, client):
def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None:
"""
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
"""

View File

@@ -5,19 +5,24 @@
from __future__ import annotations
import typing as t
from http.client import responses as http_responses
from ansible.module_utils.common.text.converters import to_text
def format_http_status(status_code):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def format_http_status(status_code: int) -> str:
expl = http_responses.get(status_code)
if not expl:
return str(status_code)
return f"{status_code} {expl}"
def format_error_problem(problem, subproblem_prefix=""):
def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str:
error_type = problem.get(
"type", "about:blank"
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
@@ -32,8 +37,10 @@ def format_error_problem(problem, subproblem_prefix=""):
msg = f"{msg} Subproblems:"
for index, problem in enumerate(subproblems):
index_str = f"{subproblem_prefix}{index}"
problem = format_error_problem(problem, subproblem_prefix=f"{index_str}.")
msg = f"{msg}\n({index_str}) {problem}"
problem_str = format_error_problem(
problem, subproblem_prefix=f"{index_str}."
)
msg = f"{msg}\n({index_str}) {problem_str}"
return msg
@@ -42,25 +49,25 @@ class ModuleFailException(Exception):
If raised, module.fail_json() will be called with the given parameters after cleanup.
"""
def __init__(self, msg, **args):
def __init__(self, msg: str, **args: t.Any) -> None:
super(ModuleFailException, self).__init__(self, msg)
self.msg = msg
self.module_fail_args = args
def do_fail(self, module, **arguments):
def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn:
module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments)
class ACMEProtocolException(ModuleFailException):
def __init__(
self,
module,
msg=None,
info=None,
module: AnsibleModule,
msg: str | None = None,
info: dict[str, t.Any] | None = None,
response=None,
content=None,
content_json=None,
extras=None,
content: bytes | None = None,
content_json: dict[str, t.Any] | None = None,
extras: dict[str, t.Any] | None = None,
):
# Try to get hold of content, if response is given and content is not provided
if content is None and content_json is None and response is not None:
@@ -71,7 +78,8 @@ class ACMEProtocolException(ModuleFailException):
raise TypeError
content = response.read()
except (AttributeError, TypeError):
content = info.pop("body", None)
if info is not None:
content = info.pop("body", None)
# Make sure that content_json is None or a dictionary
if content_json is not None and not isinstance(content_json, dict):
@@ -139,8 +147,8 @@ class ACMEProtocolException(ModuleFailException):
add_msg = f" The raw result: {to_text(content)}"
super(ACMEProtocolException, self).__init__(f"{msg}.{add_msg}", **extras)
self.problem = {}
self.subproblems = []
self.problem: dict[str, t.Any] = {}
self.subproblems: list[dict[str, t.Any]] = []
self.error_code = error_code
self.error_type = error_type
for k, v in extras.items():

View File

@@ -10,22 +10,27 @@ import os
import shutil
import tempfile
import traceback
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ModuleFailException,
)
def read_file(fn, mode="b"):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def read_file(fn: str | os.PathLike) -> bytes:
try:
with open(fn, "r" + mode) as f:
with open(fn, "rb") as f:
return f.read()
except Exception as e:
raise ModuleFailException(f'Error while reading file "{fn}": {e}')
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
def write_file(module, dest, content):
def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool:
"""
Write content to destination file dest, only if the content
has changed.

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
import time
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
@@ -13,14 +14,35 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.challenges i
)
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ACMEProtocolException,
ModuleFailException,
)
from ansible_collections.community.crypto.plugins.module_utils.acme.utils import (
nopad_b64,
)
if t.TYPE_CHECKING:
from .acme import ACMEClient
_Order = t.TypeVar("_Order", bound="Order")
class Order:
def _setup(self, client, data):
def __init__(self, url: str) -> None:
self.url = url
self.data: dict[str, t.Any] | None = None
self.status = None
self.identifiers: list[tuple[str, str]] = []
self.replaces_cert_id = None
self.finalize_uri = None
self.certificate_uri = None
self.authorization_uris: list[str] = []
self.authorizations: dict[str, Authorization] = {}
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
self.data = data
self.status = data["status"]
@@ -33,33 +55,28 @@ class Order:
self.authorization_uris = data["authorizations"]
self.authorizations = {}
def __init__(self, url):
self.url = url
self.data = None
self.status = None
self.identifiers = []
self.replaces_cert_id = None
self.finalize_uri = None
self.certificate_uri = None
self.authorization_uris = []
self.authorizations = {}
@classmethod
def from_json(cls, client, data, url):
def from_json(
cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str
) -> _Order:
result = cls(url)
result._setup(client, data)
return result
@classmethod
def from_url(cls, client, url):
def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order:
result = cls(url)
result.refresh(client)
return result
@classmethod
def create(cls, client, identifiers, replaces_cert_id=None, profile=None):
def create(
cls: t.Type[_Order],
client: ACMEClient,
identifiers: list[tuple[str, str]],
replaces_cert_id: str | None = None,
profile: str | None = None,
) -> _Order:
"""
Start a new certificate order (ACME v2 protocol).
https://tools.ietf.org/html/rfc8555#section-7.4
@@ -72,7 +89,7 @@ class Order:
"value": identifier,
}
)
new_order = {"identifiers": acme_identifiers}
new_order: dict[str, t.Any] = {"identifiers": acme_identifiers}
if replaces_cert_id is not None:
new_order["replaces"] = replaces_cert_id
if profile is not None:
@@ -87,15 +104,17 @@ class Order:
@classmethod
def create_with_error_handling(
cls,
client,
identifiers,
error_strategy="auto",
error_max_retries=3,
replaces_cert_id=None,
profile=None,
message_callback=None,
):
cls: t.Type[_Order],
client: ACMEClient,
identifiers: list[tuple[str, str]],
error_strategy: t.Literal[
"auto", "fail", "always", "retry_without_replaces_cert_id"
] = "auto",
error_max_retries: int = 3,
replaces_cert_id: str | None = None,
profile: str | None = None,
message_callback: t.Callable[[str], None] | None = None,
) -> _Order:
"""
error_strategy can be one of the following strings:
@@ -140,20 +159,20 @@ class Order:
raise
def refresh(self, client):
def refresh(self, client: ACMEClient) -> bool:
result, dummy = client.get_request(self.url)
changed = self.data != result
self._setup(client, result)
return changed
def load_authorizations(self, client):
def load_authorizations(self, client: ACMEClient) -> None:
for auth_uri in self.authorization_uris:
authz = Authorization.from_url(client, auth_uri)
self.authorizations[
normalize_combined_identifier(authz.combined_identifier)
] = authz
def wait_for_finalization(self, client):
def wait_for_finalization(self, client: ACMEClient) -> None:
while True:
self.refresh(client)
if self.status in ["valid", "invalid", "pending", "ready"]:
@@ -167,12 +186,14 @@ class Order:
content_json=self.data,
)
def finalize(self, client, csr_der, wait=True):
def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None:
"""
Create a new certificate based on the csr.
Return the certificate object as dict
https://tools.ietf.org/html/rfc8555#section-7.4
"""
if self.finalize_uri is None:
raise ModuleFailException("finalize_uri must be set")
new_cert = {
"csr": nopad_b64(csr_der),
}

View File

@@ -7,9 +7,11 @@ from __future__ import annotations
import base64
import datetime
import os
import re
import textwrap
import traceback
import typing as t
from urllib.parse import unquote
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
@@ -23,11 +25,15 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
def nopad_b64(data):
if t.TYPE_CHECKING:
from .backends import CertificateInformation, CryptoBackend
def nopad_b64(data: bytes) -> str:
return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "")
def der_to_pem(der_cert):
def der_to_pem(der_cert: bytes) -> str:
"""
Convert the DER format certificate in der_cert to a PEM format certificate and return it.
"""
@@ -35,7 +41,9 @@ def der_to_pem(der_cert):
return f"-----BEGIN CERTIFICATE-----\n{content}\n-----END CERTIFICATE-----\n"
def pem_to_der(pem_filename=None, pem_content=None):
def pem_to_der(
pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
) -> bytes:
"""
Load PEM file, or use PEM file's content, and convert to DER.
@@ -70,7 +78,9 @@ def pem_to_der(pem_filename=None, pem_content=None):
return base64.b64decode("".join(certificate_lines))
def process_links(info, callback):
def process_links(
info: dict[str, t.Any], callback: t.Callable[[str, str], None]
) -> None:
"""
Process link header, calls callback for every link header with the URL and relation as options.
@@ -82,7 +92,11 @@ def process_links(info, callback):
callback(unquote(url), relation)
def parse_retry_after(value, relative_with_timezone=True, now=None):
def parse_retry_after(
value: str,
relative_with_timezone: bool = True,
now: datetime.datetime | None = None,
) -> datetime.datetime:
"""
Parse the value of a Retry-After header and return a timestamp.
@@ -106,12 +120,12 @@ def parse_retry_after(value, relative_with_timezone=True, now=None):
def compute_cert_id(
backend,
cert_info=None,
cert_filename=None,
cert_content=None,
none_if_required_information_is_missing=False,
):
backend: CryptoBackend,
cert_info: CertificateInformation | None = None,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
none_if_required_information_is_missing: bool = False,
) -> str | None:
# Obtain certificate info if not provided
if cert_info is None:
cert_info = backend.get_cert_information(

View File

@@ -4,10 +4,15 @@
from __future__ import annotations
import typing as t
from ansible.module_utils.basic import AnsibleModule
def _ensure_list(value):
_T = t.TypeVar("_T")
def _ensure_list(value: list[_T] | tuple[_T] | None) -> list[_T]:
if value is None:
return []
return list(value)
@@ -16,13 +21,19 @@ def _ensure_list(value):
class ArgumentSpec:
def __init__(
self,
argument_spec=None,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
):
argument_spec: dict[str, t.Any] | None = None,
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
required_if: (
list[
tuple[str, t.Any, list[str] | tuple[str, ...]]
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
]
| None
) = None,
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
) -> None:
self.argument_spec = argument_spec or {}
self.mutually_exclusive = _ensure_list(mutually_exclusive)
self.required_together = _ensure_list(required_together)
@@ -30,17 +41,23 @@ class ArgumentSpec:
self.required_if = _ensure_list(required_if)
self.required_by = required_by or {}
def update_argspec(self, **kwargs):
def update_argspec(self, **kwargs) -> t.Self:
self.argument_spec.update(kwargs)
return self
def update(
self,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
required_if: (
list[
tuple[str, t.Any, list[str] | tuple[str, ...]]
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
]
| None
) = None,
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
):
if mutually_exclusive:
self.mutually_exclusive.extend(mutually_exclusive)
@@ -57,7 +74,7 @@ class ArgumentSpec:
self.required_by[k] = v
return self
def merge(self, other):
def merge(self, other: t.Self) -> t.Self:
self.update_argspec(**other.argument_spec)
self.update(
mutually_exclusive=other.mutually_exclusive,
@@ -68,8 +85,22 @@ class ArgumentSpec:
)
return self
def create_ansible_module_helper(self, clazz, args, **kwargs):
return clazz(
def create_ansible_module_helper(
self, clazz: type[_T], args: tuple, **kwargs: t.Any
) -> _T:
for forbidden_name in (
"argument_spec",
"mutually_exclusive",
"required_together",
"required_one_of",
"required_if",
"required_by",
):
if forbidden_name in kwargs:
raise ValueError(
f"You must not provide a {forbidden_name} keyword parameter to create_ansible_module_helper()"
)
instance = clazz( # type: ignore
*args,
argument_spec=self.argument_spec,
mutually_exclusive=self.mutually_exclusive,
@@ -79,8 +110,9 @@ class ArgumentSpec:
required_by=self.required_by,
**kwargs,
)
return instance
def create_ansible_module(self, **kwargs):
def create_ansible_module(self, **kwargs: t.Any) -> AnsibleModule:
return self.create_ansible_module_helper(AnsibleModule, (), **kwargs)

View File

@@ -4,6 +4,7 @@
from __future__ import annotations
import enum
import re
from ansible.module_utils.common.text.converters import to_bytes
@@ -32,7 +33,7 @@ ASN1_STRING_REGEX = re.compile(
)
class TagClass:
class TagClass(enum.Enum):
universal = 0
application = 1
context_specific = 2
@@ -40,11 +41,11 @@ class TagClass:
# Universal tag numbers that can be encoded.
class TagNumber:
class TagNumber(enum.Enum):
utf8_string = 12
def _pack_octet_integer(value):
def _pack_octet_integer(value: int) -> bytes:
"""Packs an integer value into 1 or multiple octets."""
# NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value.
octets = bytearray()
@@ -66,7 +67,7 @@ def _pack_octet_integer(value):
return bytes(octets)
def serialize_asn1_string_as_der(value):
def serialize_asn1_string_as_der(value: str) -> bytes:
"""Deserializes an ASN.1 string to a DER encoded byte string."""
asn1_match = ASN1_STRING_REGEX.match(value)
if not asn1_match:
@@ -92,7 +93,7 @@ def serialize_asn1_string_as_der(value):
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
if tag_type:
tag_class = {
tag_class_enum = {
"U": TagClass.universal,
"A": TagClass.application,
"P": TagClass.private,
@@ -100,13 +101,15 @@ def serialize_asn1_string_as_der(value):
}[tag_class]
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
constructed = tag_type == "EXPLICIT" and tag_class != TagClass.universal
b_value = pack_asn1(tag_class, constructed, int(tag_number), b_value)
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
return b_value
def pack_asn1(tag_class, constructed, tag_number, b_data):
def pack_asn1(
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
) -> bytes:
"""Pack the value into an ASN.1 data structure.
The structure for an ASN.1 element is
@@ -115,16 +118,15 @@ def pack_asn1(tag_class, constructed, tag_number, b_data):
"""
b_asn1_data = bytearray()
if tag_class < 0 or tag_class > 3:
raise ValueError(f"tag_class must be between 0 and 3 not {tag_class}")
# Bit 8 and 7 denotes the class.
identifier_octets = tag_class << 6
identifier_octets = tag_class.value << 6
# Bit 6 denotes whether the value is primitive or constructed.
identifier_octets |= (1 if constructed else 0) << 5
# Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits
# then they are set and another octet(s) is used to denote the tag number.
if isinstance(tag_number, TagNumber):
tag_number = tag_number.value
if tag_number < 31:
identifier_octets |= tag_number
b_asn1_data.append(identifier_octets)

View File

@@ -34,7 +34,7 @@ from __future__ import annotations
# cryptography versions!
def obj2txt(openssl_lib, openssl_ffi, obj):
def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
# Set to 80 on the recommendation of
# https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values
#

View File

@@ -7,9 +7,9 @@ from __future__ import annotations
from ._objects_data import OID_MAP
OID_LOOKUP = dict()
NORMALIZE_NAMES = dict()
NORMALIZE_NAMES_SHORT = dict()
OID_LOOKUP: dict[str, str] = dict()
NORMALIZE_NAMES: dict[str, str] = dict()
NORMALIZE_NAMES_SHORT: dict[str, str] = dict()
for dotted, names in OID_MAP.items():
for name in names:

View File

@@ -4,6 +4,8 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.version import (
LooseVersion as _LooseVersion,
)
@@ -21,6 +23,10 @@ from .basic import HAS_CRYPTOGRAPHY
from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_name
if t.TYPE_CHECKING:
import datetime
# TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this
# to True and adjust get_invalidity_date() accordingly.
# (https://github.com/pyca/cryptography/issues/10818)
@@ -55,7 +61,9 @@ else:
REVOCATION_REASON_MAP_INVERSE = dict()
def cryptography_decode_revoked_certificate(cert):
def cryptography_decode_revoked_certificate(
cert: x509.RevokedCertificate,
) -> dict[str, t.Any]:
result = {
"serial_number": cert.serial_number,
"revocation_date": get_revocation_date(cert),
@@ -67,27 +75,30 @@ def cryptography_decode_revoked_certificate(cert):
"invalidity_date_critical": False,
}
try:
ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
result["issuer"] = list(ext.value)
result["issuer_critical"] = ext.critical
ext_ci = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
result["issuer"] = list(ext_ci.value)
result["issuer_critical"] = ext_ci.critical
except x509.ExtensionNotFound:
pass
try:
ext = cert.extensions.get_extension_for_class(x509.CRLReason)
result["reason"] = ext.value.reason
result["reason_critical"] = ext.critical
ext_cr = cert.extensions.get_extension_for_class(x509.CRLReason)
result["reason"] = ext_cr.value.reason
result["reason_critical"] = ext_cr.critical
except x509.ExtensionNotFound:
pass
try:
ext = cert.extensions.get_extension_for_class(x509.InvalidityDate)
result["invalidity_date"] = get_invalidity_date(ext.value)
result["invalidity_date_critical"] = ext.critical
ext_id = cert.extensions.get_extension_for_class(x509.InvalidityDate)
result["invalidity_date"] = get_invalidity_date(ext_id.value)
result["invalidity_date_critical"] = ext_id.critical
except x509.ExtensionNotFound:
pass
return result
def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
def cryptography_dump_revoked(
entry: dict[str, t.Any],
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> dict[str, t.Any]:
return {
"serial_number": entry["serial_number"],
"revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT),
@@ -115,48 +126,56 @@ def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
}
def cryptography_get_signature_algorithm_oid_from_crl(crl):
def cryptography_get_signature_algorithm_oid_from_crl(
crl: x509.CertificateRevocationList,
) -> x509.oid.ObjectIdentifier:
try:
return crl.signature_algorithm_oid
except AttributeError:
# Older cryptography versions do not have signature_algorithm_oid yet
dotted = obj2txt(
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm # type: ignore
)
return x509.oid.ObjectIdentifier(dotted)
def get_next_update(obj):
def get_next_update(obj: x509.CertificateRevocationList) -> datetime.datetime | None:
if CRYPTOGRAPHY_TIMEZONE:
return obj.next_update_utc
return obj.next_update
def get_last_update(obj):
def get_last_update(obj: x509.CertificateRevocationList) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE:
return obj.last_update_utc
return obj.last_update
def get_revocation_date(obj):
def get_revocation_date(obj: x509.RevokedCertificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE:
return obj.revocation_date_utc
return obj.revocation_date
def get_invalidity_date(obj):
def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE:
return obj.invalidity_date_utc
return obj.invalidity_date
def set_next_update(builder, value):
def set_next_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.next_update(value)
def set_last_update(builder, value):
def set_last_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.last_update(value)
def set_revocation_date(builder, value):
def set_revocation_date(
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
) -> x509.RevokedCertificateBuilder:
return builder.revocation_date(value)

View File

@@ -9,6 +9,7 @@ import binascii
import ipaddress
import re
import traceback
import typing as t
from urllib.parse import (
ParseResult,
urlparse,
@@ -40,6 +41,7 @@ except ImportError:
pass
try:
import cryptography.hazmat.primitives.asymmetric.dh
import cryptography.hazmat.primitives.asymmetric.ed448
import cryptography.hazmat.primitives.asymmetric.ed25519
import cryptography.hazmat.primitives.asymmetric.rsa
@@ -55,7 +57,7 @@ try:
)
except ImportError:
# Error handled in the calling module.
_load_pkcs12 = None
_load_pkcs12 = None # type: ignore
try:
import idna
@@ -74,6 +76,50 @@ from .basic import (
)
if t.TYPE_CHECKING:
import datetime
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateKey, DHPublicKey
from cryptography.hazmat.primitives.asymmetric.dsa import (
DSAPrivateKey,
DSAPublicKey,
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey,
EllipticCurvePublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPublicKey,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
CertificateIssuerPublicKeyTypes,
CertificatePublicKeyTypes,
PrivateKeyTypes,
PublicKeyTypes,
)
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PKCS12KeyAndCertificates,
)
CertificatePrivateKeyTypes = (
CertificateIssuerPrivateKeyTypes
| cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey
| cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
)
PublicKeyTypesWOEdwards = (
DHPublicKey | DSAPublicKey | EllipticCurvePublicKey | RSAPublicKey
)
PrivateKeyTypesWOEdwards = (
DHPrivateKey | DSAPrivateKey | EllipticCurvePrivateKey | RSAPrivateKey
)
else:
PublicKeyTypesWOEdwards = None
PrivateKeyTypesWOEdwards = None
CRYPTOGRAPHY_TIMEZONE = False
_CRYPTOGRAPHY_36_0_OR_NEWER = False
if _HAS_CRYPTOGRAPHY:
@@ -88,7 +134,9 @@ if _HAS_CRYPTOGRAPHY:
DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$")
def cryptography_get_extensions_from_cert(cert):
def cryptography_get_extensions_from_cert(
cert: x509.Certificate,
) -> dict[str, dict[str, bool | str]]:
result = dict()
if _CRYPTOGRAPHY_36_0_OR_NEWER:
@@ -105,7 +153,7 @@ def cryptography_get_extensions_from_cert(cert):
backend = default_backend()
x509_obj = cert._x509
x509_obj = cert._x509 # type: ignore
# With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does
# not allow to get the raw value of an extension, so we have to use this ugly hack:
exts = list(cert.extensions)
@@ -135,7 +183,9 @@ def cryptography_get_extensions_from_cert(cert):
return result
def cryptography_get_extensions_from_csr(csr):
def cryptography_get_extensions_from_csr(
csr: x509.CertificateSigningRequest,
) -> dict[str, dict[str, bool | str]]:
result = dict()
if _CRYPTOGRAPHY_36_0_OR_NEWER:
@@ -153,7 +203,7 @@ def cryptography_get_extensions_from_csr(csr):
backend = default_backend()
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req)
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) # type: ignore
extensions = backend._ffi.gc(
extensions,
lambda ext: backend._lib.sk_X509_EXTENSION_pop_free(
@@ -175,7 +225,7 @@ def cryptography_get_extensions_from_csr(csr):
crit = backend._lib.X509_EXTENSION_get_critical(ext)
data = backend._lib.X509_EXTENSION_get_data(ext)
backend.openssl_assert(data != backend._ffi.NULL)
der = backend._ffi.buffer(data.data, data.length)[:]
der: bytes = backend._ffi.buffer(data.data, data.length)[:] # type: ignore
entry = dict(
critical=(crit == 1),
value=base64.b64encode(der).decode("ascii"),
@@ -193,7 +243,7 @@ def cryptography_get_extensions_from_csr(csr):
return result
def cryptography_name_to_oid(name):
def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
dotted = OID_LOOKUP.get(name)
if dotted is None:
if DOTTED_OID.match(name):
@@ -202,7 +252,9 @@ def cryptography_name_to_oid(name):
return x509.oid.ObjectIdentifier(dotted)
def cryptography_oid_to_name(oid, short=False):
def cryptography_oid_to_name(
oid: x509.oid.ObjectIdentifier, short: bool = False
) -> str:
dotted_string = oid.dotted_string
names = OID_MAP.get(dotted_string)
if names:
@@ -217,15 +269,22 @@ def cryptography_oid_to_name(oid, short=False):
return NORMALIZE_NAMES.get(name, name)
def _get_hex(bytesstr):
def _get_hex(bytesstr: bytes) -> str:
if bytesstr is None:
return bytesstr
data = binascii.hexlify(bytesstr)
data = to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
return data
return to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
def _parse_hex(bytesstr):
@t.overload
def _parse_hex(bytesstr: bytes | str) -> bytes: ...
@t.overload
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None: ...
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None:
if bytesstr is None:
return bytesstr
data = "".join(
@@ -234,19 +293,20 @@ def _parse_hex(bytesstr):
for p in to_text(bytesstr).split(":")
]
)
data = binascii.unhexlify(data)
return data
return binascii.unhexlify(data)
DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *")
DN_HEX_LETTER = b"0123456789abcdef"
def _int_to_byte(value):
def _int_to_byte(value: int) -> bytes:
return bytes((value,))
def _parse_dn_component(name, sep=b",", decode_remainder=True):
def _parse_dn_component(
name: bytes, sep: bytes = b",", decode_remainder: bool = True
) -> tuple[x509.NameAttribute, bytes]:
m = DN_COMPONENT_START_RE.match(name)
if not m:
raise OpenSSLObjectError(f'cannot start part in "{to_text(name)}"')
@@ -305,7 +365,7 @@ def _parse_dn_component(name, sep=b",", decode_remainder=True):
return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:]
def _parse_dn(name):
def _parse_dn(name: bytes) -> list[x509.NameAttribute]:
"""
Parse a Distinguished Name.
@@ -323,31 +383,33 @@ def _parse_dn(name):
attribute, name = _parse_dn_component(name, sep=sep)
except OpenSSLObjectError as e:
raise OpenSSLObjectError(
f'Error while parsing distinguished name "{to_text(original_name)}": {e}'
f"Error while parsing distinguished name {to_text(original_name)!r}: {e}"
)
result.append(attribute)
if name:
if name[0:1] != sep or len(name) < 2:
raise OpenSSLObjectError(
f'Error while parsing distinguished name "{to_text(original_name)}": unexpected end of string'
f"Error while parsing distinguished name {to_text(original_name)!r}: unexpected end of string"
)
name = name[1:]
return result
def cryptography_parse_relative_distinguished_name(rdn):
def cryptography_parse_relative_distinguished_name(
rdn: list[str | bytes],
) -> cryptography.x509.RelativeDistinguishedName:
names = []
for part in rdn:
try:
names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0])
except OpenSSLObjectError as e:
raise OpenSSLObjectError(
f'Error while parsing relative distinguished name "{part}": {e}'
f"Error while parsing relative distinguished name {to_text(part)!r}: {e}"
)
return cryptography.x509.RelativeDistinguishedName(names)
def _is_ascii(value):
def _is_ascii(value: str) -> bool:
"""Check whether the Unicode string `value` contains only ASCII characters."""
try:
value.encode("ascii")
@@ -356,7 +418,7 @@ def _is_ascii(value):
return False
def _adjust_idn(value, idn_rewrite):
def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
if idn_rewrite == "ignore" or not value:
return value
if idn_rewrite == "idna" and _is_ascii(value):
@@ -399,16 +461,20 @@ def _adjust_idn(value, idn_rewrite):
return ".".join(parts)
def _adjust_idn_email(value, idn_rewrite):
def _adjust_idn_email(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
idx = value.find("@")
if idx < 0:
return value
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
def _adjust_idn_url(value, idn_rewrite):
def _adjust_idn_url(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
url = urlparse(value)
host = _adjust_idn(url.hostname, idn_rewrite)
host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
if url.username is not None and url.password is not None:
host = f"{url.username}:{url.password}@{host}"
elif url.username is not None:
@@ -418,7 +484,7 @@ def _adjust_idn_url(value, idn_rewrite):
return urlunparse(
ParseResult(
scheme=url.scheme,
netloc=host,
netloc=host or "",
path=url.path,
params=url.params,
query=url.query,
@@ -427,7 +493,9 @@ def _adjust_idn_url(value, idn_rewrite):
)
def cryptography_get_name(name, what="Subject Alternative Name"):
def cryptography_get_name(
name: str, what: str = "Subject Alternative Name"
) -> x509.GeneralName:
"""
Given a name string, returns a cryptography x509.GeneralName object.
Raises an OpenSSLObjectError if the name is unknown or cannot be parsed.
@@ -490,7 +558,7 @@ def cryptography_get_name(name, what="Subject Alternative Name"):
)
def _dn_escape_value(value):
def _dn_escape_value(value: str) -> str:
"""
Escape Distinguished Name's attribute value.
"""
@@ -505,7 +573,10 @@ def _dn_escape_value(value):
return value
def cryptography_decode_name(name, idn_rewrite="ignore"):
def cryptography_decode_name(
name: x509.GeneralName,
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> str:
"""
Given a cryptography x509.GeneralName object, returns a string.
Raises an OpenSSLObjectError if the name is not supported.
@@ -529,7 +600,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
# list needs to be reversed, and joined by commas
return "dirName:" + ",".join(
[
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(attribute.value)}"
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(to_text(attribute.value))}"
for attribute in reversed(list(name.value))
]
)
@@ -540,7 +611,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
raise OpenSSLObjectError(f'Cannot decode name "{name}"')
def _cryptography_get_keyusage(usage):
def _cryptography_get_keyusage(usage: str) -> str:
"""
Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage().
Raises an OpenSSLObjectError if the identifier is unknown.
@@ -566,7 +637,7 @@ def _cryptography_get_keyusage(usage):
raise OpenSSLObjectError(f'Unknown key usage "{usage}"')
def cryptography_parse_key_usage_params(usages):
def cryptography_parse_key_usage_params(usages: t.Iterable[str]) -> dict[str, bool]:
"""
Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage().
Raises an OpenSSLObjectError if an identifier is unknown.
@@ -587,13 +658,15 @@ def cryptography_parse_key_usage_params(usages):
return params
def cryptography_get_basic_constraints(constraints):
def cryptography_get_basic_constraints(
constraints: t.Iterable[str] | None,
) -> tuple[bool, int | None]:
"""
Given a list of constraints, returns a tuple (ca, path_length).
Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed.
"""
ca = False
path_length = None
path_length: int | None = None
if constraints:
for constraint in constraints:
if constraint.startswith("CA:"):
@@ -618,7 +691,9 @@ def cryptography_get_basic_constraints(constraints):
return ca, path_length
def cryptography_key_needs_digest_for_signing(key):
def cryptography_key_needs_digest_for_signing(
key: CertificateIssuerPrivateKeyTypes,
) -> bool:
"""Tests whether the given private key requires a digest algorithm for signing.
Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm.
@@ -632,19 +707,27 @@ def cryptography_key_needs_digest_for_signing(key):
return True
def _compare_public_keys(key1, key2, clazz):
def _compare_public_keys(
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz)
b = isinstance(key2, clazz)
if not (a or b):
return None
if not a or not b:
return False
a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
return a == b
a_bytes = key1.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
b_bytes = key2.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
return a_bytes == b_bytes
def cryptography_compare_public_keys(key1, key2):
def cryptography_compare_public_keys(
key1: PublicKeyTypes, key2: PublicKeyTypes
) -> bool:
"""Tests whether two public keys are the same.
Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers().
@@ -654,6 +737,13 @@ def cryptography_compare_public_keys(key1, key2):
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
)
if res is not None:
return res
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
)
if res is not None:
return res
res = _compare_public_keys(
@@ -661,10 +751,20 @@ def cryptography_compare_public_keys(key1, key2):
)
if res is not None:
return res
return key1.public_numbers() == key2.public_numbers()
res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
)
if res is not None:
return res
return (
t.cast(PublicKeyTypesWOEdwards, key1).public_numbers()
== t.cast(PublicKeyTypesWOEdwards, key2).public_numbers()
)
def _compare_private_keys(key1, key2, clazz):
def _compare_private_keys(
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz)
b = isinstance(key2, clazz)
if not (a or b):
@@ -672,20 +772,22 @@ def _compare_private_keys(key1, key2, clazz):
if not a or not b:
return False
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
a = key1.private_bytes(
a_bytes = key1.private_bytes(
serialization.Encoding.Raw,
serialization.PrivateFormat.Raw,
encryption_algorithm=encryption_algorithm,
)
b = key2.private_bytes(
b_bytes = key2.private_bytes(
serialization.Encoding.Raw,
serialization.PrivateFormat.Raw,
encryption_algorithm=encryption_algorithm,
)
return a == b
return a_bytes == b_bytes
def cryptography_compare_private_keys(key1, key2):
def cryptography_compare_private_keys(
key1: PrivateKeyTypes, key2: PrivateKeyTypes
) -> bool:
"""Tests whether two private keys are the same.
Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers().
@@ -714,25 +816,39 @@ def cryptography_compare_private_keys(key1, key2):
)
if res is not None:
return res
return key1.private_numbers() == key2.private_numbers()
return (
t.cast(PrivateKeyTypesWOEdwards, key1).private_numbers()
== t.cast(PrivateKeyTypesWOEdwards, key2).private_numbers()
)
def parse_pkcs12(pkcs12_bytes, passphrase=None):
def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
"""Returns a tuple (private_key, certificate, additional_certificates, friendly_name)."""
passphrase_bytes = None
if passphrase is not None:
passphrase = to_bytes(passphrase)
passphrase_bytes = to_bytes(passphrase)
# Main code for cryptography 36.0.0 and forward
if _load_pkcs12 is not None:
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase)
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase)
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
# Requires cryptography 36.0.0 or newer
pkcs12 = _load_pkcs12(pkcs12_bytes, passphrase)
additional_certificates = [cert.certificate for cert in pkcs12.additional_certs]
@@ -745,7 +861,12 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
# Backwards compatibility code for cryptography 35.x
private_key, certificate, additional_certificates = _load_key_and_certificates(
pkcs12_bytes, passphrase
@@ -787,7 +908,12 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
# Backwards compatibility code for cryptography < 35.0.0
private_key, certificate, additional_certificates = _load_key_and_certificates(
pkcs12_bytes, passphrase
@@ -796,14 +922,19 @@ def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
friendly_name = None
if certificate:
# See https://github.com/pyca/cryptography/issues/5760#issuecomment-842687238
backend = certificate._backend
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL)
backend = certificate._backend # type: ignore
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) # type: ignore
if maybe_name != backend._ffi.NULL:
friendly_name = backend._ffi.string(maybe_name)
return private_key, certificate, additional_certificates, friendly_name
def cryptography_verify_signature(signature, data, hash_algorithm, signer_public_key):
def cryptography_verify_signature(
signature: bytes,
data: bytes,
hash_algorithm: hashes.HashAlgorithm | None,
signer_public_key: PublicKeyTypes,
) -> bool:
"""
Check whether the given signature of the given data was signed by the given public key object.
"""
@@ -812,6 +943,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
signer_public_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey,
):
if hash_algorithm is None:
raise OpenSSLObjectError("Need hash_algorithm for RSA keys")
signer_public_key.verify(
signature, data, padding.PKCS1v15(), hash_algorithm
)
@@ -820,6 +953,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
signer_public_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
):
if hash_algorithm is None:
raise OpenSSLObjectError("Need hash_algorithm for ECC keys")
signer_public_key.verify(
signature,
data,
@@ -830,6 +965,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
signer_public_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey,
):
if hash_algorithm is None:
raise OpenSSLObjectError("Need hash_algorithm for DSA keys")
signer_public_key.verify(signature, data, hash_algorithm)
return True
if isinstance(
@@ -851,7 +988,9 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
return False
def cryptography_verify_certificate_signature(certificate, signer_public_key):
def cryptography_verify_certificate_signature(
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
) -> bool:
"""
Check whether the given X509 certificate object was signed by the given public key object.
"""
@@ -863,21 +1002,65 @@ def cryptography_verify_certificate_signature(certificate, signer_public_key):
)
def get_not_valid_after(obj):
def get_not_valid_after(obj: x509.Certificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE:
return obj.not_valid_after_utc
return obj.not_valid_after
def get_not_valid_before(obj):
def get_not_valid_before(obj: x509.Certificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE:
return obj.not_valid_before_utc
return obj.not_valid_before
def set_not_valid_after(builder, value):
def set_not_valid_after(
builder: x509.CertificateBuilder, value: datetime.datetime
) -> x509.CertificateBuilder:
return builder.not_valid_after(value)
def set_not_valid_before(builder, value):
def set_not_valid_before(
builder: x509.CertificateBuilder, value: datetime.datetime
) -> x509.CertificateBuilder:
return builder.not_valid_before(value)
def is_potential_certificate_private_key(
key: PrivateKeyTypes,
) -> t.TypeGuard[CertificatePrivateKeyTypes]:
return not isinstance(
key, cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey
)
def is_potential_certificate_issuer_private_key(
key: PrivateKeyTypes,
) -> t.TypeGuard[CertificateIssuerPrivateKeyTypes]:
return not isinstance(
key,
(
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
),
)
def is_potential_certificate_public_key(
key: PublicKeyTypes,
) -> t.TypeGuard[CertificatePublicKeyTypes]:
return not isinstance(key, DHPublicKey)
def is_potential_certificate_issuer_public_key(
key: PublicKeyTypes,
) -> t.TypeGuard[CertificateIssuerPublicKeyTypes]:
return not isinstance(
key,
(
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey,
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
),
)

View File

@@ -5,7 +5,7 @@
from __future__ import annotations
def binary_exp_mod(f, e, m):
def binary_exp_mod(f: int, e: int, m: int) -> int:
"""Computes f^e mod m in O(log e) multiplications modulo m."""
# Compute len_e = floor(log_2(e))
len_e = -1
@@ -22,14 +22,14 @@ def binary_exp_mod(f, e, m):
return result
def simple_gcd(a, b):
def simple_gcd(a: int, b: int) -> int:
"""Compute GCD of its two inputs."""
while b != 0:
a, b = b, a % b
return a
def quick_is_not_prime(n):
def quick_is_not_prime(n: int) -> bool:
"""Does some quick checks to see if we can poke a hole into the primality of n.
A result of `False` does **not** mean that the number is prime; it just means
@@ -97,7 +97,7 @@ def quick_is_not_prime(n):
return False
def count_bytes(no):
def count_bytes(no: int) -> int:
"""
Given an integer, compute the number of bytes necessary to store its absolute value.
"""
@@ -107,7 +107,7 @@ def count_bytes(no):
return (no.bit_length() + 7) // 8
def count_bits(no):
def count_bits(no: int) -> int:
"""
Given an integer, compute the number of bits necessary to store its absolute value.
"""
@@ -117,19 +117,7 @@ def count_bits(no):
return no.bit_length()
def _convert_int_to_bytes(count, no):
return no.to_bytes(count, byteorder="big")
def _convert_bytes_to_int(data):
return int.from_bytes(data, byteorder="big", signed=False)
def _to_hex(no):
return f"{no:x}"
def convert_int_to_bytes(no, count=None):
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
"""
Convert the absolute value of an integer to a byte string in network byte order.
@@ -142,10 +130,10 @@ def convert_int_to_bytes(no, count=None):
no = abs(no)
if count is None:
count = count_bytes(no)
return _convert_int_to_bytes(count, no)
return no.to_bytes(count, byteorder="big")
def convert_int_to_hex(no, digits=None):
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
"""
Convert the absolute value of an integer to a string of hexadecimal digits.
@@ -154,14 +142,14 @@ def convert_int_to_hex(no, digits=None):
the string will be longer.
"""
no = abs(no)
value = _to_hex(no)
value = f"{no:x}"
if digits is not None and len(value) < digits:
value = "0" * (digits - len(value)) + value
return value
def convert_bytes_to_int(data):
def convert_bytes_to_int(data: bytes) -> int:
"""
Convert a byte string to an unsigned integer in network byte order.
"""
return _convert_bytes_to_int(data)
return int.from_bytes(data, byteorder="big", signed=False)

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
ArgumentSpec,
@@ -24,8 +25,8 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_certificate,
load_certificate_privatekey,
load_certificate_request,
load_privatekey,
)
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
from ..cryptography_support import CertificatePrivateKeyTypes
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -47,41 +59,45 @@ class CertificateError(OpenSSLObjectError):
class CertificateBackend(metaclass=abc.ABCMeta):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.force = module.params["force"]
self.ignore_timestamps = module.params["ignore_timestamps"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.csr_path = module.params["csr_path"]
self.csr_content = module.params["csr_content"]
if self.csr_content is not None:
self.csr_content = self.csr_content.encode("utf-8")
self.force: bool = module.params["force"]
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.csr_path: str | None = module.params["csr_path"]
csr_content = module.params["csr_content"]
if csr_content is not None:
self.csr_content: bytes | None = csr_content.encode("utf-8")
else:
self.csr_content = None
# The following are default values which make sure check() works as
# before if providers do not explicitly change these properties.
self.create_subject_key_identifier = "never_create"
self.create_authority_key_identifier = False
self.create_subject_key_identifier: str = "never_create"
self.create_authority_key_identifier: bool = False
self.privatekey = None
self.csr = None
self.cert = None
self.existing_certificate = None
self.existing_certificate_bytes = None
self.privatekey: CertificatePrivateKeyTypes | None = None
self.csr: x509.CertificateSigningRequest | None = None
self.cert: x509.Certificate | None = None
self.existing_certificate: x509.Certificate | None = None
self.existing_certificate_bytes: bytes | None = None
self.check_csr_subject = True
self.check_csr_extensions = True
self.check_csr_subject: bool = True
self.check_csr_extensions: bool = True
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
return {}
try:
result = get_certificate_info(
self.module, data, prefer_one_fingerprint=True
@@ -92,34 +108,34 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return dict(can_parse_certificate=False)
@abc.abstractmethod
def generate_certificate(self):
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
pass
@abc.abstractmethod
def get_certificate_data(self):
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
pass
def set_existing(self, certificate_bytes):
def set_existing(self, certificate_bytes: bytes | None) -> None:
"""Set existing certificate bytes. None indicates that the key does not exist."""
self.existing_certificate_bytes = certificate_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_certificate_bytes
)
def has_existing(self):
def has_existing(self) -> bool:
"""Query whether an existing certificate is/has been there."""
return self.existing_certificate_bytes is not None
def _ensure_private_key_loaded(self):
def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey."""
if self.privatekey is not None:
return
if self.privatekey_path is None and self.privatekey_content is None:
return
try:
self.privatekey = load_privatekey(
self.privatekey = load_certificate_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
@@ -127,7 +143,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
except OpenSSLBadPassphraseError as exc:
raise CertificateError(exc)
def _ensure_csr_loaded(self):
def _ensure_csr_loaded(self) -> None:
"""Load the CSR into self.csr."""
if self.csr is not None:
return
@@ -138,7 +154,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
content=self.csr_content,
)
def _ensure_existing_certificate_loaded(self):
def _ensure_existing_certificate_loaded(self) -> None:
"""Load the existing certificate into self.existing_certificate."""
if self.existing_certificate is not None:
return
@@ -149,14 +165,28 @@ class CertificateBackend(metaclass=abc.ABCMeta):
content=self.existing_certificate_bytes,
)
def _check_privatekey(self):
def _check_privatekey(self) -> bool:
"""Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.privatekey is None:
raise AssertionError(
"Contract violation: privatekey has not been populated"
)
return cryptography_compare_public_keys(
self.existing_certificate.public_key(), self.privatekey.public_key()
)
def _check_csr(self):
def _check_csr(self) -> bool:
"""Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
# Verify that CSR is signed by certificate's private key
if not self.csr.is_signature_valid:
return False
@@ -214,8 +244,14 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return False
return True
def _check_subject_key_identifier(self):
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate has been populated."""
def _check_subject_key_identifier(self) -> bool:
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate and self.csr have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
# Get hold of certificate's SKI
try:
ext = self.existing_certificate.extensions.get_extension_for_class(
@@ -247,7 +283,11 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return False
return True
def needs_regeneration(self, not_before=None, not_after=None):
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
"""Check whether a regeneration is necessary."""
if self.force or self.existing_certificate_bytes is None:
return True
@@ -256,6 +296,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
self._ensure_existing_certificate_loaded()
except Exception:
return True
assert self.existing_certificate is not None
# Check whether private key matches
self._ensure_private_key_loaded()
@@ -285,9 +326,12 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return True
return False
def dump(self, include_certificate):
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {"privatekey": self.privatekey_path, "csr": self.csr_path}
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"csr": self.csr_path,
}
# Get hold of certificate bytes
certificate_bytes = self.existing_certificate_bytes
if self.cert is not None:
@@ -299,35 +343,33 @@ class CertificateBackend(metaclass=abc.ABCMeta):
certificate_bytes.decode("utf-8") if certificate_bytes else None
)
result["diff"] = dict(
before=self.diff_before,
after=self.diff_after,
)
result["diff"] = {
"before": self.diff_before,
"after": self.diff_after,
}
return result
class CertificateProvider(metaclass=abc.ABCMeta):
@abc.abstractmethod
def validate_module_args(self, module):
def validate_module_args(self, module: AnsibleModule) -> None:
"""Check module arguments"""
@abc.abstractmethod
def needs_version_two_certs(self, module):
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
"""Whether the provider needs to create a version 2 certificate."""
@abc.abstractmethod
def create_backend(self, module):
def create_backend(self, module: AnsibleModule) -> CertificateBackend:
"""Create an implementation for a backend.
Return value must be instance of CertificateBackend.
"""
def select_backend(module, provider):
"""
:type module: AnsibleModule
:type provider: CertificateProvider
"""
def select_backend(
module: AnsibleModule, provider: CertificateProvider
) -> CertificateBackend:
provider.validate_module_args(module)
assert_required_cryptography_version(
@@ -343,7 +385,7 @@ def select_backend(module, provider):
return provider.create_backend(module)
def get_certificate_argument_spec():
def get_certificate_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
provider=dict(

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import os
import tempfile
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
@@ -17,22 +18,30 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module):
super(AcmeCertificateBackend, self).__init__(module)
self.accountkey_path = module.params["acme_accountkey_path"]
self.challenge_path = module.params["acme_challenge_path"]
self.use_chain = module.params["acme_chain"]
self.acme_directory = module.params["acme_directory"]
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
if self.csr_content is None and self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if self.csr_content is None and not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
from ...argspec import ArgumentSpec
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module)
self.accountkey_path: str = module.params["acme_accountkey_path"]
self.challenge_path: str = module.params["acme_challenge_path"]
self.use_chain: bool = module.params["acme_chain"]
self.acme_directory: str = module.params["acme_directory"]
self.cert_bytes: bytes | None = None
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if not os.path.exists(self.accountkey_path):
raise CertificateError(
@@ -46,7 +55,7 @@ class AcmeCertificateBackend(CertificateBackend):
self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True)
def generate_certificate(self):
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
command = [self.acme_tiny_path]
@@ -77,22 +86,26 @@ class AcmeCertificateBackend(CertificateBackend):
command.extend(["--directory-url", self.acme_directory])
try:
self.cert = to_bytes(self.module.run_command(command, check_rc=True)[1])
self.cert_bytes = to_bytes(
self.module.run_command(command, check_rc=True)[1]
)
except OSError as exc:
raise CertificateError(exc)
def get_certificate_data(self):
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
return self.cert
if self.cert_bytes is None:
raise AssertionError("Contract violation: cert_bytes is None")
return self.cert_bytes
def dump(self, include_certificate):
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(include_certificate)
result["accountkey"] = self.accountkey_path
return result
class AcmeCertificateProvider(CertificateProvider):
def validate_module_args(self, module):
def validate_module_args(self, module: AnsibleModule) -> None:
if module.params["acme_accountkey_path"] is None:
module.fail_json(
msg="The acme_accountkey_path option must be specified for the acme provider."
@@ -102,14 +115,14 @@ class AcmeCertificateProvider(CertificateProvider):
msg="The acme_challenge_path option must be specified for the acme provider."
)
def needs_version_two_certs(self, module):
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return False
def create_backend(self, module):
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
return AcmeCertificateBackend(module)
def add_acme_provider_to_argument_spec(argument_spec):
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("acme")
argument_spec.argument_spec.update(
dict(

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import datetime
import os
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -32,6 +33,12 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ...argspec import ArgumentSpec
try:
from cryptography.x509.oid import NameOID
except ImportError:
@@ -39,7 +46,7 @@ except ImportError:
class EntrustCertificateBackend(CertificateBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module)
self.trackingId = None
self.notAfter = get_relative_time_option(
@@ -48,16 +55,19 @@ class EntrustCertificateBackend(CertificateBackend):
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
if self.csr_content is None and self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for entrust provider"
)
if self.csr_content is None and not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for entrust provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
self._ensure_csr_loaded()
if self.csr is None:
raise CertificateError("CSR not provided")
# ECS API defaults to using the validated organization tied to the account.
# We want to always force behavior of trying to use the organization provided in the CSR.
@@ -93,9 +103,9 @@ class EntrustCertificateBackend(CertificateBackend):
],
)
except SessionConfigurationException as e:
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e.message}")
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
def generate_certificate(self):
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
body = {}
@@ -104,6 +114,7 @@ class EntrustCertificateBackend(CertificateBackend):
# csr_content contains bytes
body["csr"] = to_native(self.csr_content)
else:
assert self.csr_path is not None
with open(self.csr_path, "r") as csr_file:
body["csr"] = csr_file.read()
@@ -138,11 +149,15 @@ class EntrustCertificateBackend(CertificateBackend):
content=self.cert_bytes,
)
def get_certificate_data(self):
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
return self.cert_bytes
def needs_regeneration(self):
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
parent_check = super(EntrustCertificateBackend, self).needs_regeneration()
try:
@@ -167,12 +182,12 @@ class EntrustCertificateBackend(CertificateBackend):
return parent_check
def _get_cert_details(self):
cert_details = {}
def _get_cert_details(self) -> dict[str, t.Any]:
cert_details: dict[str, t.Any] = {}
try:
self._ensure_existing_certificate_loaded()
except Exception:
return
return cert_details
if self.existing_certificate:
serial_number = f"{self.existing_certificate.serial_number:X}"
expiry = get_not_valid_after(self.existing_certificate)
@@ -203,17 +218,17 @@ class EntrustCertificateBackend(CertificateBackend):
class EntrustCertificateProvider(CertificateProvider):
def validate_module_args(self, module):
def validate_module_args(self, module: AnsibleModule) -> None:
pass
def needs_version_two_certs(self, module):
def needs_version_two_certs(self, module: AnsibleModule) -> t.Literal[False]:
return False
def create_backend(self, module):
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
return EntrustCertificateBackend(module)
def add_entrust_provider_to_argument_spec(argument_spec):
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("entrust")
argument_spec.argument_spec.update(
dict(
@@ -248,7 +263,7 @@ def add_entrust_provider_to_argument_spec(argument_spec):
)
)
argument_spec.required_if.append(
[
(
"provider",
"entrust",
[
@@ -260,5 +275,5 @@ def add_entrust_provider_to_argument_spec(argument_spec):
"entrust_api_client_cert_path",
"entrust_api_client_cert_key_path",
],
]
)
)

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import abc
import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -34,6 +35,19 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
from ...argspec import ArgumentSpec
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -48,93 +62,97 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module, content):
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@abc.abstractmethod
def _get_der_bytes(self):
def _get_der_bytes(self) -> bytes:
pass
@abc.abstractmethod
def _get_signature_algorithm(self):
def _get_signature_algorithm(self) -> str:
pass
@abc.abstractmethod
def _get_subject_ordered(self):
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_issuer_ordered(self):
def _get_issuer_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_version(self):
def _get_version(self) -> int | str:
pass
@abc.abstractmethod
def _get_key_usage(self):
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self):
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self):
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self):
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self):
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def get_not_before(self):
def get_not_before(self) -> datetime.datetime:
pass
@abc.abstractmethod
def get_not_after(self):
def get_not_after(self) -> datetime.datetime:
pass
@abc.abstractmethod
def _get_public_key_pem(self):
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self):
def _get_public_key_object(self) -> PublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self):
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(self):
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_serial_number(self):
def _get_serial_number(self) -> int:
pass
@abc.abstractmethod
def _get_all_extensions(self):
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _get_ocsp_uri(self):
def _get_ocsp_uri(self) -> str | None:
pass
@abc.abstractmethod
def _get_issuer_uri(self):
def _get_issuer_uri(self) -> str | None:
pass
def get_info(self, prefer_one_fingerprint=False, der_support_enabled=False):
result = dict()
def get_info(
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
None,
content=self.content,
@@ -194,16 +212,20 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski = self._get_subject_key_identifier()
if ski is not None:
ski = binascii.hexlify(ski).decode("ascii")
ski_bytes = self._get_subject_key_identifier()
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski
aki, aci, acsn = self._get_authority_key_identifier()
if aki is not None:
aki = binascii.hexlify(aki).decode("ascii")
aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
@@ -219,36 +241,40 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module, content):
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self):
def _get_der_bytes(self) -> bytes:
return self.cert.public_bytes(serialization.Encoding.DER)
def _get_signature_algorithm(self):
def _get_signature_algorithm(self) -> str:
return cryptography_oid_to_name(self.cert.signature_algorithm_oid)
def _get_subject_ordered(self):
result = []
def _get_subject_ordered(self) -> list[list[str]]:
result: list[list[str]] = []
for attribute in self.cert.subject:
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result
def _get_issuer_ordered(self):
def _get_issuer_ordered(self) -> list[list[str]]:
result = []
for attribute in self.cert.issuer:
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result
def _get_version(self):
def _get_version(self) -> int | str:
if self.cert.version == x509.Version.v1:
return 1
if self.cert.version == x509.Version.v3:
return 3
return "unknown"
def _get_key_usage(self):
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try:
current_key_ext = self.cert.extensions.get_extension_for_class(
x509.KeyUsage
@@ -297,7 +323,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_extended_key_usage(self):
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.ExtendedKeyUsage
@@ -311,7 +337,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_basic_constraints(self):
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.BasicConstraints
@@ -324,7 +350,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_ocsp_must_staple(self):
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
try:
tlsfeature_ext = self.cert.extensions.get_extension_for_class(
x509.TLSFeature
@@ -336,7 +362,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_subject_alt_name(self):
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
try:
san_ext = self.cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
@@ -349,22 +375,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def get_not_before(self):
def get_not_before(self) -> datetime.datetime:
return get_not_valid_before(self.cert)
def get_not_after(self):
def get_not_after(self) -> datetime.datetime:
return get_not_valid_after(self.cert)
def _get_public_key_pem(self):
def _get_public_key_pem(self) -> bytes:
return self.cert.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_public_key_object(self):
def _get_public_key_object(self) -> PublicKeyTypes:
return self.cert.public_key()
def _get_subject_key_identifier(self):
def _get_subject_key_identifier(self) -> bytes | None:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
@@ -373,7 +399,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None
def _get_authority_key_identifier(self):
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
@@ -392,13 +420,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, None, None
def _get_serial_number(self):
def _get_serial_number(self) -> int:
return self.cert.serial_number
def _get_all_extensions(self):
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
return cryptography_get_extensions_from_cert(self.cert)
def _get_ocsp_uri(self):
def _get_ocsp_uri(self) -> str | None:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess
@@ -411,7 +439,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
pass
return None
def _get_issuer_uri(self):
def _get_issuer_uri(self) -> str | None:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess
@@ -428,12 +456,16 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
return None
def get_certificate_info(module, content, prefer_one_fingerprint=False):
def get_certificate_info(
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module, content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, content):
def select_backend(
module: GeneralAnsibleModule, content: bytes
) -> CertificateInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
import os
import typing as t
from random import randrange
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -18,6 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
cryptography_verify_certificate_signature,
get_not_valid_after,
get_not_valid_before,
is_potential_certificate_issuer_public_key,
set_not_valid_after,
set_not_valid_before,
)
@@ -28,7 +30,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_certificate,
load_privatekey,
load_certificate_issuer_privatekey,
select_message_digest,
)
from ansible_collections.community.crypto.plugins.module_utils.time import (
@@ -36,6 +38,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
from ...argspec import ArgumentSpec
try:
import cryptography
from cryptography import x509
@@ -45,13 +58,13 @@ except ImportError:
class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module)
self.create_subject_key_identifier = module.params[
"ownca_create_subject_key_identifier"
]
self.create_authority_key_identifier = module.params[
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["ownca_create_subject_key_identifier"]
self.create_authority_key_identifier: bool = module.params[
"ownca_create_authority_key_identifier"
]
self.notBefore = get_relative_time_option(
@@ -65,31 +78,40 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["ownca_digest"])
self.version = module.params["ownca_version"]
self.version: int = module.params["ownca_version"]
self.serial_number = x509.random_serial_number()
self.ca_cert_path = module.params["ownca_path"]
self.ca_cert_content = module.params["ownca_content"]
if self.ca_cert_content is not None:
self.ca_cert_content = self.ca_cert_content.encode("utf-8")
self.ca_privatekey_path = module.params["ownca_privatekey_path"]
self.ca_privatekey_content = module.params["ownca_privatekey_content"]
if self.ca_privatekey_content is not None:
self.ca_privatekey_content = self.ca_privatekey_content.encode("utf-8")
self.ca_privatekey_passphrase = module.params["ownca_privatekey_passphrase"]
self.ca_cert_path: str | None = module.params["ownca_path"]
ca_cert_content: str | None = module.params["ownca_content"]
if ca_cert_content is not None:
self.ca_cert_content: bytes | None = ca_cert_content.encode("utf-8")
else:
self.ca_cert_content = None
self.ca_privatekey_path: str | None = module.params["ownca_privatekey_path"]
ca_privatekey_content: str | None = module.params["ownca_privatekey_content"]
if ca_privatekey_content is not None:
self.ca_privatekey_content: bytes | None = ca_privatekey_content.encode(
"utf-8"
)
else:
self.ca_privatekey_content = None
self.ca_privatekey_passphrase: str | None = module.params[
"ownca_privatekey_passphrase"
]
if self.csr_content is None and self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if self.csr_content is None and not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if self.ca_cert_content is None and not os.path.exists(self.ca_cert_path):
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if self.ca_cert_path is not None and not os.path.exists(self.ca_cert_path):
raise CertificateError(
f"The CA certificate file {self.ca_cert_path} does not exist"
)
if self.ca_privatekey_content is None and not os.path.exists(
if self.ca_privatekey_path is not None and not os.path.exists(
self.ca_privatekey_path
):
raise CertificateError(
@@ -101,8 +123,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
path=self.ca_cert_path,
content=self.ca_cert_content,
)
if not is_potential_certificate_issuer_public_key(self.ca_cert.public_key()):
raise CertificateError(
"CA certificate's public key cannot be used to sign certificates"
)
try:
self.ca_private_key = load_privatekey(
self.ca_private_key = load_certificate_issuer_privatekey(
path=self.ca_privatekey_path,
content=self.ca_privatekey_content,
passphrase=self.ca_privatekey_passphrase,
@@ -125,8 +151,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
else:
self.digest = None
def generate_certificate(self):
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(self.csr.subject)
cert_builder = cert_builder.issuer_name(self.ca_cert.subject)
@@ -166,10 +194,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
critical=False,
)
except cryptography.x509.ExtensionNotFound:
public_key = self.ca_cert.public_key()
assert is_potential_certificate_issuer_public_key(public_key)
cert_builder = cert_builder.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key(
self.ca_cert.public_key()
),
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
critical=False,
)
@@ -180,17 +208,24 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
self.cert = certificate
def get_certificate_data(self):
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
if self.cert is None:
raise AssertionError("Contract violation: cert has not been populated")
return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(self):
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
if super(OwnCACertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter
):
return True
self._ensure_existing_certificate_loaded()
assert self.existing_certificate is not None
# Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature(
@@ -205,31 +240,33 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
# Check AuthorityKeyIdentifier
if self.create_authority_key_identifier:
try:
ext = self.ca_cert.extensions.get_extension_for_class(
ext_ski = self.ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
expected_ext = (
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext.value
ext_ski.value
)
)
except cryptography.x509.ExtensionNotFound:
public_key = self.ca_cert.public_key()
assert is_potential_certificate_issuer_public_key(public_key)
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
self.ca_cert.public_key()
public_key
)
try:
ext = self.existing_certificate.extensions.get_extension_for_class(
ext_aki = self.existing_certificate.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
if ext.value != expected_ext:
if ext_aki.value != expected_ext:
return True
except cryptography.x509.ExtensionNotFound:
return True
return False
def dump(self, include_certificate):
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate
)
@@ -251,6 +288,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
else:
if self.cert is None:
self.cert = self.existing_certificate
assert self.cert is not None
result.update(
{
"notBefore": get_not_valid_before(self.cert).strftime(
@@ -266,7 +304,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return result
def generate_serial_number():
def generate_serial_number() -> int:
"""Generate a serial number for a certificate"""
while True:
result = randrange(0, 1 << 160)
@@ -275,7 +313,7 @@ def generate_serial_number():
class OwnCACertificateProvider(CertificateProvider):
def validate_module_args(self, module):
def validate_module_args(self, module: AnsibleModule) -> None:
if (
module.params["ownca_path"] is None
and module.params["ownca_content"] is None
@@ -291,14 +329,16 @@ class OwnCACertificateProvider(CertificateProvider):
msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider."
)
def needs_version_two_certs(self, module):
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return module.params["ownca_version"] == 2
def create_backend(self, module):
def create_backend(
self, module: AnsibleModule
) -> OwnCACertificateBackendCryptography:
return OwnCACertificateBackendCryptography(module)
def add_ownca_provider_to_argument_spec(argument_spec):
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("ownca")
argument_spec.argument_spec.update(
dict(

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
import os
import typing as t
from random import randrange
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -14,6 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
cryptography_verify_certificate_signature,
get_not_valid_after,
get_not_valid_before,
is_potential_certificate_issuer_private_key,
set_not_valid_after,
set_not_valid_before,
)
@@ -30,6 +32,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
from ...argspec import ArgumentSpec
try:
import cryptography
from cryptography import x509
@@ -39,12 +52,14 @@ except ImportError:
class SelfSignedCertificateBackendCryptography(CertificateBackend):
def __init__(self, module):
privatekey: CertificateIssuerPrivateKeyTypes
def __init__(self, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
self.create_subject_key_identifier = module.params[
"selfsigned_create_subject_key_identifier"
]
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["selfsigned_create_subject_key_identifier"]
self.notBefore = get_relative_time_option(
module.params["selfsigned_not_before"],
"selfsigned_not_before",
@@ -56,14 +71,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["selfsigned_digest"])
self.version = module.params["selfsigned_version"]
self.version: int = module.params["selfsigned_version"]
self.serial_number = x509.random_serial_number()
if self.csr_path is not None and not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
raise CertificateError(
f"The private key file {self.privatekey_path} does not exist"
)
@@ -71,20 +88,10 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
self._module = module
self._ensure_private_key_loaded()
self._ensure_csr_loaded()
if self.csr is None:
# Create empty CSR on the fly
csr = cryptography.x509.CertificateSigningRequestBuilder()
csr = csr.subject_name(cryptography.x509.Name([]))
digest = None
if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = self.digest
if digest is None:
self.module.fail_json(
msg=f'Unsupported digest "{module.params["selfsigned_digest"]}"'
)
self.csr = csr.sign(self.privatekey, digest)
if self.privatekey is None:
raise CertificateError("Private key has not been provided")
if not is_potential_certificate_issuer_private_key(self.privatekey):
raise CertificateError("Private key cannot be used to sign certificates")
if cryptography_key_needs_digest_for_signing(self.privatekey):
if self.digest is None:
@@ -94,8 +101,21 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
else:
self.digest = None
def generate_certificate(self):
self._ensure_csr_loaded()
if self.csr is None:
# Create empty CSR on the fly
csr = cryptography.x509.CertificateSigningRequestBuilder()
csr = csr.subject_name(cryptography.x509.Name([]))
self.csr = csr.sign(self.privatekey, self.digest)
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
if self.privatekey is None:
raise AssertionError(
"Contract violation: privatekey has not been populated"
)
try:
cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(self.csr.subject)
@@ -130,17 +150,26 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
self.cert = certificate
def get_certificate_data(self):
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
if self.cert is None:
raise AssertionError("Contract violation: cert has not been populated")
return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(self):
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
assert self.privatekey is not None
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter
):
return True
self._ensure_existing_certificate_loaded()
assert self.existing_certificate is not None
# Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature(
@@ -150,7 +179,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
return False
def dump(self, include_certificate):
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate
)
@@ -166,6 +195,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
else:
if self.cert is None:
self.cert = self.existing_certificate
assert self.cert is not None
result.update(
{
"notBefore": get_not_valid_before(self.cert).strftime(
@@ -181,7 +211,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
return result
def generate_serial_number():
def generate_serial_number() -> int:
"""Generate a serial number for a certificate"""
while True:
result = randrange(0, 1 << 160)
@@ -190,7 +220,7 @@ def generate_serial_number():
class SelfSignedCertificateProvider(CertificateProvider):
def validate_module_args(self, module):
def validate_module_args(self, module: AnsibleModule) -> None:
if (
module.params["privatekey_path"] is None
and module.params["privatekey_content"] is None
@@ -199,14 +229,16 @@ class SelfSignedCertificateProvider(CertificateProvider):
msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider."
)
def needs_version_two_certs(self, module):
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return module.params["selfsigned_version"] == 2
def create_backend(self, module):
def create_backend(
self, module: AnsibleModule
) -> SelfSignedCertificateBackendCryptography:
return SelfSignedCertificateBackendCryptography(module)
def add_selfsigned_provider_to_argument_spec(argument_spec):
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("selfsigned")
argument_spec.argument_spec.update(
dict(

View File

@@ -4,6 +4,8 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import (
TIMESTAMP_FORMAT,
cryptography_decode_revoked_certificate,
@@ -22,6 +24,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
# crypto_utils
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
@@ -33,14 +47,19 @@ except ImportError:
class CRLInfoRetrieval:
def __init__(self, module, content, list_revoked_certificates=True):
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
) -> None:
# content must be a bytes string
self.module = module
self.content = content
self.list_revoked_certificates = list_revoked_certificates
self.name_encoding = module.params.get("name_encoding", "ignore")
def get_info(self):
def get_info(self) -> dict[str, t.Any]:
self.crl_pem = identify_pem_format(self.content)
try:
if self.crl_pem:
@@ -50,7 +69,7 @@ class CRLInfoRetrieval:
except ValueError as e:
self.module.fail_json(msg=f"Error while decoding CRL: {e}")
result = {
result: dict[str, t.Any] = {
"changed": False,
"format": "pem" if self.crl_pem else "der",
"last_update": None,
@@ -61,7 +80,11 @@ class CRLInfoRetrieval:
}
result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = (
self.crl.next_update.strftime(TIMESTAMP_FORMAT)
if self.crl.next_update
else None
)
result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
)
@@ -83,7 +106,9 @@ class CRLInfoRetrieval:
return result
def get_crl_info(module, content, list_revoked_certificates=True):
def get_crl_info(
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
) -> dict[str, t.Any]:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import abc
import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
@@ -26,13 +27,14 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
cryptography_name_to_oid,
cryptography_parse_key_usage_params,
cryptography_parse_relative_distinguished_name,
is_potential_certificate_issuer_public_key,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import (
get_csr_info,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_certificate_issuer_privatekey,
load_certificate_request,
load_privatekey,
parse_name_field,
parse_ordered_name_field,
select_message_digest,
@@ -43,6 +45,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
from ..cryptography_support import CertificatePrivateKeyTypes
_ET = t.TypeVar("_ET", bound="cryptography.x509.ExtensionType")
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -69,49 +83,58 @@ class CertificateSigningRequestError(OpenSSLObjectError):
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.digest = module.params["digest"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.version = module.params["version"]
self.subjectAltName = module.params["subject_alt_name"]
self.subjectAltName_critical = module.params["subject_alt_name_critical"]
self.keyUsage = module.params["key_usage"]
self.keyUsage_critical = module.params["key_usage_critical"]
self.extendedKeyUsage = module.params["extended_key_usage"]
self.extendedKeyUsage_critical = module.params["extended_key_usage_critical"]
self.basicConstraints = module.params["basic_constraints"]
self.basicConstraints_critical = module.params["basic_constraints_critical"]
self.ocspMustStaple = module.params["ocsp_must_staple"]
self.ocspMustStaple_critical = module.params["ocsp_must_staple_critical"]
self.name_constraints_permitted = (
self.digest: str = module.params["digest"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.version: t.Literal[1] = module.params["version"]
self.subjectAltName: list[str] | None = module.params["subject_alt_name"]
self.subjectAltName_critical: bool = module.params["subject_alt_name_critical"]
self.keyUsage: list[str] | None = module.params["key_usage"]
self.keyUsage_critical: bool = module.params["key_usage_critical"]
self.extendedKeyUsage: list[str] | None = module.params["extended_key_usage"]
self.extendedKeyUsage_critical: bool = module.params[
"extended_key_usage_critical"
]
self.basicConstraints: list[str] | None = module.params["basic_constraints"]
self.basicConstraints_critical: bool = module.params[
"basic_constraints_critical"
]
self.ocspMustStaple: bool = module.params["ocsp_must_staple"]
self.ocspMustStaple_critical: bool = module.params["ocsp_must_staple_critical"]
self.name_constraints_permitted: list[str] = (
module.params["name_constraints_permitted"] or []
)
self.name_constraints_excluded = (
self.name_constraints_excluded: list[str] = (
module.params["name_constraints_excluded"] or []
)
self.name_constraints_critical = module.params["name_constraints_critical"]
self.create_subject_key_identifier = module.params[
self.name_constraints_critical: bool = module.params[
"name_constraints_critical"
]
self.create_subject_key_identifier: bool = module.params[
"create_subject_key_identifier"
]
self.subject_key_identifier = module.params["subject_key_identifier"]
self.authority_key_identifier = module.params["authority_key_identifier"]
self.authority_cert_issuer = module.params["authority_cert_issuer"]
self.authority_cert_serial_number = module.params[
subject_key_identifier: str | None = module.params["subject_key_identifier"]
authority_key_identifier: str | None = module.params["authority_key_identifier"]
self.authority_cert_issuer: list[str] | None = module.params[
"authority_cert_issuer"
]
self.authority_cert_serial_number: int = module.params[
"authority_cert_serial_number"
]
self.crl_distribution_points = module.params["crl_distribution_points"]
self.csr = None
self.privatekey = None
self.crl_distribution_points: (
list[cryptography.x509.DistributionPoint] | None
) = None
self.csr: cryptography.x509.CertificateSigningRequest | None = None
self.privatekey: CertificateIssuerPrivateKeyTypes | None = None
if (
self.create_subject_key_identifier
and self.subject_key_identifier is not None
):
if self.create_subject_key_identifier and subject_key_identifier is not None:
module.fail_json(
msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true"
)
@@ -153,35 +176,37 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.using_common_name_for_san = True
break
if self.subject_key_identifier is not None:
self.subject_key_identifier: bytes | None = None
if subject_key_identifier is not None:
try:
self.subject_key_identifier = binascii.unhexlify(
self.subject_key_identifier.replace(":", "")
subject_key_identifier.replace(":", "")
)
except Exception as e:
raise CertificateSigningRequestError(
f"Cannot parse subject_key_identifier: {e}"
)
if self.authority_key_identifier is not None:
self.authority_key_identifier: bytes | None = None
if authority_key_identifier is not None:
try:
self.authority_key_identifier = binascii.unhexlify(
self.authority_key_identifier.replace(":", "")
authority_key_identifier.replace(":", "")
)
except Exception as e:
raise CertificateSigningRequestError(
f"Cannot parse authority_key_identifier: {e}"
)
self.existing_csr = None
self.existing_csr_bytes = None
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
self.existing_csr_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
return {}
try:
result = get_csr_info(
self.module,
@@ -195,30 +220,28 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return dict(can_parse_csr=False)
@abc.abstractmethod
def generate_csr(self):
def generate_csr(self) -> None:
"""(Re-)Generate CSR."""
pass
@abc.abstractmethod
def get_csr_data(self):
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
pass
def set_existing(self, csr_bytes):
def set_existing(self, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
def has_existing(self):
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
return self.existing_csr_bytes is not None
def _ensure_private_key_loaded(self):
def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey."""
if self.privatekey is not None:
return
try:
self.privatekey = load_privatekey(
self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
@@ -227,11 +250,10 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
raise CertificateSigningRequestError(exc)
@abc.abstractmethod
def _check_csr(self):
def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
pass
def needs_regeneration(self):
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
if self.existing_csr_bytes is None:
return True
@@ -245,9 +267,9 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, include_csr):
def dump(self, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"subject": self.subject,
"subjectAltName": self.subjectAltName,
@@ -274,44 +296,49 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return result
def parse_crl_distribution_points(module, crl_distribution_points):
def parse_crl_distribution_points(
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
try:
params = dict(
full_name=None,
relative_name=None,
crl_issuer=None,
reasons=None,
)
full_name = None
relative_name = None
crl_issuer = None
reasons = None
if parse_crl_distribution_point["full_name"] is not None:
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
params["full_name"] = [
full_name = [
cryptography_get_name(name, "full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
if not parse_crl_distribution_point["relative_name"]:
raise OpenSSLObjectError("relative_name must not be empty")
params["relative_name"] = (
cryptography_parse_relative_distinguished_name(
parse_crl_distribution_point["relative_name"]
)
relative_name = cryptography_parse_relative_distinguished_name(
parse_crl_distribution_point["relative_name"]
)
if parse_crl_distribution_point["crl_issuer"] is not None:
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
params["crl_issuer"] = [
crl_issuer = [
cryptography_get_name(name, "CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
reasons = []
reasons_list = []
for reason in parse_crl_distribution_point["reasons"]:
reasons.append(REVOCATION_REASON_MAP[reason])
params["reasons"] = frozenset(reasons)
result.append(cryptography.x509.DistributionPoint(**params))
reasons_list.append(REVOCATION_REASON_MAP[reason])
reasons = frozenset(reasons_list)
result.append(
cryptography.x509.DistributionPoint(
full_name=full_name,
relative_name=relative_name,
crl_issuer=crl_issuer,
reasons=reasons,
)
)
except (OpenSSLObjectError, ValueError) as e:
raise OpenSSLObjectError(
f"Error while parsing CRL distribution point #{index}: {e}"
@@ -321,21 +348,25 @@ def parse_crl_distribution_points(module, crl_distribution_points):
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
if self.version != 1:
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
)
if self.crl_distribution_points:
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
"crl_distribution_points"
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module, self.crl_distribution_points
module, crl_distribution_points
)
def generate_csr(self):
def generate_csr(self) -> None:
"""(Re-)Generate CSR."""
self._ensure_private_key_loaded()
assert self.privatekey is not None
csr = cryptography.x509.CertificateSigningRequestBuilder()
try:
@@ -412,6 +443,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
raise OpenSSLObjectError(f"Error while parsing name constraint: {e}")
if self.create_subject_key_identifier:
if not is_potential_certificate_issuer_public_key(
self.privatekey.public_key()
):
raise OpenSSLObjectError(
"Private key can not be used to create subject key identifier"
)
csr = csr.add_extension(
cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
@@ -450,7 +487,10 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
critical=False,
)
digest = None
# csr.sign() does not accept some digests we theoretically could have in digest.
# For that reason we use type t.Any here. csr.sign() will complain if
# the digest is not acceptable.
digest: t.Any | None = None
if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = select_message_digest(self.digest)
if digest is None:
@@ -482,16 +522,22 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
+ "This is probably caused by an invalid Subject Alternative DNS Name."
)
def get_csr_data(self):
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
if self.csr is None:
raise AssertionError("Violated contract: csr is not populated")
return self.csr.public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM
)
def _check_csr(self):
def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
if self.existing_csr is None:
raise AssertionError("Violated contract: existing_csr is not populated")
if self.privatekey is None:
raise AssertionError("Violated contract: privatekey is not populated")
def _check_subject(csr):
def _check_subject(csr: cryptography.x509.CertificateSigningRequest) -> bool:
subject = [
(cryptography_name_to_oid(entry[0]), to_text(entry[1]))
for entry in self.subject
@@ -502,12 +548,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else:
return set(subject) == set(current_subject)
def _find_extension(extensions, exttype):
def _find_extension(
extensions: cryptography.x509.Extensions, exttype: type[_ET]
) -> cryptography.x509.Extension[_ET] | None:
return next(
(ext for ext in extensions if isinstance(ext.value, exttype)), None
)
def _check_subjectAltName(extensions):
def _check_subjectAltName(extensions: cryptography.x509.Extensions) -> bool:
current_altnames_ext = _find_extension(
extensions, cryptography.x509.SubjectAlternativeName
)
@@ -526,12 +574,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
)
if set(altnames) != set(current_altnames):
return False
if altnames:
if altnames and current_altnames_ext:
if current_altnames_ext.critical != self.subjectAltName_critical:
return False
return True
def _check_keyUsage(extensions):
def _check_keyUsage(extensions: cryptography.x509.Extensions) -> bool:
current_keyusage_ext = _find_extension(
extensions, cryptography.x509.KeyUsage
)
@@ -547,7 +595,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
return False
return True
def _check_extenededKeyUsage(extensions):
def _check_extenededKeyUsage(extensions: cryptography.x509.Extensions) -> bool:
current_usages_ext = _find_extension(
extensions, cryptography.x509.ExtendedKeyUsage
)
@@ -566,12 +614,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
)
if set(current_usages) != set(usages):
return False
if usages:
if usages and current_usages_ext:
if current_usages_ext.critical != self.extendedKeyUsage_critical:
return False
return True
def _check_basicConstraints(extensions):
def _check_basicConstraints(extensions: cryptography.x509.Extensions) -> bool:
bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints)
current_ca = bc_ext.value.ca if bc_ext else False
current_path_length = bc_ext.value.path_length if bc_ext else None
@@ -591,7 +639,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else:
return bc_ext is None
def _check_ocspMustStaple(extensions):
def _check_ocspMustStaple(extensions: cryptography.x509.Extensions) -> bool:
tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
if self.ocspMustStaple:
if (
@@ -606,7 +654,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else:
return tlsfeature_ext is None
def _check_nameConstraints(extensions):
def _check_nameConstraints(extensions: cryptography.x509.Extensions) -> bool:
current_nc_ext = _find_extension(
extensions, cryptography.x509.NameConstraints
)
@@ -638,12 +686,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
current_nc_excl
):
return False
if nc_perm or nc_excl:
if (nc_perm or nc_excl) and current_nc_ext:
if current_nc_ext.critical != self.name_constraints_critical:
return False
return True
def _check_subject_key_identifier(extensions):
def _check_subject_key_identifier(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
if (
self.create_subject_key_identifier
@@ -652,6 +702,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
if not ext or ext.critical:
return False
if self.create_subject_key_identifier:
assert self.privatekey is not None
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
).digest
@@ -661,7 +712,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else:
return ext is None
def _check_authority_key_identifier(extensions):
def _check_authority_key_identifier(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
if (
self.authority_key_identifier is not None
@@ -688,7 +741,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else:
return ext is None
def _check_crl_distribution_points(extensions):
def _check_crl_distribution_points(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints)
if self.crl_distribution_points is None:
return ext is None
@@ -696,7 +751,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
return False
return list(ext.value) == self.crl_distribution_points
def _check_extensions(csr):
def _check_extensions(csr: cryptography.x509.CertificateSigningRequest) -> bool:
extensions = csr.extensions
return (
_check_subjectAltName(extensions)
@@ -710,7 +765,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
and _check_crl_distribution_points(extensions)
)
def _check_signature(csr):
def _check_signature(csr: cryptography.x509.CertificateSigningRequest) -> bool:
if not csr.is_signature_valid:
return False
# To check whether public key of CSR belongs to private key,
@@ -719,6 +774,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
)
assert self.privatekey is not None
key_b = self.privatekey.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
@@ -732,14 +788,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
)
def select_backend(module):
def select_backend(
module: AnsibleModule,
) -> CertificateSigningRequestCryptographyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateSigningRequestCryptographyBackend(module)
def get_csr_argument_spec():
def get_csr_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
digest=dict(type="str", default="sha256"),

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import abc
import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -27,6 +28,19 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificatePublicKeyTypes,
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -41,66 +55,69 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module, content, validate_signature):
# content must be a bytes string
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
self.module = module
self.content = content
self.validate_signature = validate_signature
@abc.abstractmethod
def _get_subject_ordered(self):
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_key_usage(self):
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self):
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self):
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self):
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self):
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_name_constraints(self):
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_public_key_pem(self):
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self):
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self):
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(self):
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_all_extensions(self):
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _is_signature_valid(self):
def _is_signature_valid(self) -> bool:
pass
def get_info(self, prefer_one_fingerprint=False):
result = dict()
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
None,
content=self.content,
@@ -145,15 +162,17 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
}
)
ski = self._get_subject_key_identifier()
if ski is not None:
ski = binascii.hexlify(ski).decode("ascii")
ski_bytes = self._get_subject_key_identifier()
ski = None
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski
aki, aci, acsn = self._get_authority_key_identifier()
if aki is not None:
aki = binascii.hexlify(aki).decode("ascii")
aki_bytes, aci, acsn = self._get_authority_key_identifier()
aki = None
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
@@ -170,19 +189,25 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(self, module, content, validate_signature):
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super(CSRInfoRetrievalCryptography, self).__init__(
module, content, validate_signature
)
self.name_encoding = module.params.get("name_encoding", "ignore")
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
)
def _get_subject_ordered(self):
result = []
def _get_subject_ordered(self) -> list[list[str]]:
result: list[list[str]] = []
for attribute in self.csr.subject:
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result
def _get_key_usage(self):
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try:
current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage)
current_key_usage = current_key_ext.value
@@ -229,7 +254,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_extended_key_usage(self):
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.ExtendedKeyUsage
@@ -243,7 +268,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_basic_constraints(self):
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.BasicConstraints
@@ -255,7 +280,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_ocsp_must_staple(self):
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
try:
# This only works with cryptography >= 2.1
tlsfeature_ext = self.csr.extensions.get_extension_for_class(
@@ -268,7 +293,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_subject_alt_name(self):
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
try:
san_ext = self.csr.extensions.get_extension_for_class(
x509.SubjectAlternativeName
@@ -281,7 +306,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_name_constraints(self):
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
try:
nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints)
permitted = [
@@ -296,23 +321,25 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, None, False
def _get_public_key_pem(self):
def _get_public_key_pem(self) -> bytes:
return self.csr.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_public_key_object(self):
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
return self.csr.public_key()
def _get_subject_key_identifier(self):
def _get_subject_key_identifier(self) -> bytes | None:
try:
ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
return ext.value.digest
except cryptography.x509.ExtensionNotFound:
return None
def _get_authority_key_identifier(self):
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
try:
ext = self.csr.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
@@ -331,23 +358,28 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound:
return None, None, None
def _get_all_extensions(self):
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
return cryptography_get_extensions_from_csr(self.csr)
def _is_signature_valid(self):
def _is_signature_valid(self) -> bool:
return self.csr.is_signature_valid
def get_csr_info(
module, content, validate_signature=True, prefer_one_fingerprint=False
):
module: GeneralAnsibleModule,
content: bytes,
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, content, validate_signature=True):
def select_backend(
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
) -> CSRInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import abc
import base64
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -64,29 +76,37 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyBackend(metaclass=abc.ABCMeta):
def __init__(self, module):
def __init__(self, module: GeneralAnsibleModule) -> None:
self.module = module
self.type = module.params["type"]
self.size = module.params["size"]
self.curve = module.params["curve"]
self.passphrase = module.params["passphrase"]
self.cipher = module.params["cipher"]
self.format = module.params["format"]
self.format_mismatch = module.params.get("format_mismatch", "regenerate")
self.regenerate = module.params.get("regenerate", "full_idempotence")
self.type: t.Literal[
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
] = module.params["type"]
self.size: int = module.params["size"]
self.curve: str | None = module.params["curve"]
self.passphrase: str | None = module.params["passphrase"]
self.cipher: str = module.params["cipher"]
self.format: t.Literal["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"] = (
module.params["format"]
)
self.format_mismatch: t.Literal["regenerate", "convert"] = module.params.get(
"format_mismatch", "regenerate"
)
self.regenerate: t.Literal[
"never", "fail", "partial_idempotence", "full_idempotence", "always"
] = module.params.get("regenerate", "full_idempotence")
self.private_key = None
self.private_key: PrivateKeyTypes | None = None
self.existing_private_key = None
self.existing_private_key_bytes = None
self.existing_private_key: PrivateKeyTypes | None = None
self.existing_private_key_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
result = dict(can_parse_key=False)
return {}
result: dict[str, t.Any] = {"can_parse_key": False}
try:
result.update(
get_privatekey_info(
@@ -106,11 +126,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return result
@abc.abstractmethod
def generate_private_key(self):
def generate_private_key(self) -> None:
"""(Re-)Generate private key."""
pass
def convert_private_key(self):
def convert_private_key(self) -> None:
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key).
This is effectively a copy without active conversion. The conversion is done
@@ -121,42 +141,37 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self.private_key = self.existing_private_key
@abc.abstractmethod
def get_private_key_data(self):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key."""
pass
def set_existing(self, privatekey_bytes):
def set_existing(self, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes
)
def has_existing(self):
def has_existing(self) -> bool:
"""Query whether an existing private key is/has been there."""
return self.existing_private_key_bytes is not None
@abc.abstractmethod
def _check_passphrase(self):
def _check_passphrase(self) -> bool:
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
pass
@abc.abstractmethod
def _ensure_existing_private_key_loaded(self):
def _ensure_existing_private_key_loaded(self) -> None:
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
pass
@abc.abstractmethod
def _check_size_and_type(self):
def _check_size_and_type(self) -> bool:
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
pass
@abc.abstractmethod
def _check_format(self):
def _check_format(self) -> bool:
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
pass
def needs_regeneration(self):
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
if self.regenerate == "always":
return True
@@ -194,7 +209,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
)
return False
def needs_conversion(self):
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
# During conversion step, convert if format does not match and format_mismatch == 'convert'
self._ensure_existing_private_key_loaded()
@@ -204,7 +219,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
and not self._check_format()
)
def _get_fingerprint(self):
def _get_fingerprint(self) -> dict[str, str] | None:
if self.private_key:
return get_fingerprint_of_privatekey(self.private_key)
try:
@@ -214,8 +229,9 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
pass
if self.existing_private_key:
return get_fingerprint_of_privatekey(self.existing_private_key)
return None
def dump(self, include_key):
def dump(self, include_key: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
if not self.private_key:
@@ -224,7 +240,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
except Exception:
# Ignore errors
pass
result = {
result: dict[str, t.Any] = {
"type": self.type,
"size": self.size,
"fingerprint": self._get_fingerprint(),
@@ -253,38 +269,57 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return result
# Implementation with using cryptography
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
class _Curve:
def __init__(
self,
name: str,
ectype: str,
deprecated: bool,
) -> None:
self.name = name
self.ectype = ectype
self.deprecated = deprecated
def _get_ec_class(self, ectype):
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype)
def _get_ec_class(
self, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
if ecclass is None:
self.module.fail_json(
msg=f"Your cryptography version does not support {ectype}"
module.fail_json(
msg=f"Your cryptography version does not support {self.ectype}"
)
return ecclass
def _add_curve(self, name, ectype, deprecated=False):
def create(size):
ecclass = self._get_ec_class(ectype)
return ecclass()
def create(
self, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module)
return ecclass()
def verify(privatekey):
ecclass = self._get_ec_class(ectype)
return isinstance(
privatekey.private_numbers().public_numbers.curve, ecclass
)
def verify(
self,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
self.curves[name] = {
"create": create,
"verify": verify,
"deprecated": deprecated,
}
def __init__(self, module):
# Implementation with using cryptography
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _add_curve(
self,
name: str,
ectype: str,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
def __init__(self, module: GeneralAnsibleModule) -> None:
super(PrivateKeyCryptographyBackend, self).__init__(module=module)
self.curves = dict()
self.curves: dict[str, _Curve] = {}
self._add_curve("secp224r1", "SECP224R1")
self._add_curve("secp256k1", "SECP256K1")
self._add_curve("secp256r1", "SECP256R1")
@@ -305,15 +340,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
def _get_wanted_format(self):
def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
if self.format not in ("auto", "auto_ignore"):
return self.format
return self.format # type: ignore
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
return "pkcs8"
else:
return "pkcs1"
def generate_private_key(self):
def generate_private_key(self) -> None:
"""(Re-)Generate private key."""
try:
if self.type == "RSA":
@@ -346,13 +381,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
)
if self.type == "ECC" and self.curve in self.curves:
if self.curves[self.curve]["deprecated"]:
if self.curves[self.curve].deprecated:
self.module.warn(
f"Elliptic curves of type {self.curve} should not be used for new keys!"
)
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve]["create"](self.size),
curve=self.curves[self.curve].create(
size=self.size, module=self.module
),
)
)
except cryptography.exceptions.UnsupportedAlgorithm:
@@ -360,22 +397,24 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
msg=f"Cryptography backend does not support the algorithm required for {self.type}"
)
def get_private_key_data(self):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key"""
if self.private_key is None:
raise AssertionError("private_key not set")
# Select export format and encoding
try:
export_format = self._get_wanted_format()
export_format_txt = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format == "pkcs1":
if export_format_txt == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif export_format == "pkcs8":
elif export_format_txt == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif export_format == "raw":
elif export_format_txt == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
@@ -388,9 +427,9 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
)
# Select key encryption
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.NoEncryption()
)
encryption_algorithm: (
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.cipher and self.passphrase:
if self.cipher == "auto":
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
@@ -418,8 +457,10 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
exception=traceback.format_exc(),
)
def _load_privatekey(self):
def _load_privatekey(self) -> PrivateKeyTypes:
data = self.existing_private_key_bytes
if data is None:
raise AssertionError("existing_private_key_bytes not set")
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
@@ -460,11 +501,13 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
except Exception as e:
raise PrivateKeyError(e)
def _ensure_existing_private_key_loaded(self):
def _ensure_existing_private_key_loaded(self) -> None:
if self.existing_private_key is None and self.has_existing():
self.existing_private_key = self._load_privatekey()
def _check_passphrase(self):
def _check_passphrase(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
try:
format = identify_private_key_format(self.existing_private_key_bytes)
if format == "raw":
@@ -475,7 +518,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
# provided.
return self.passphrase is None
else:
return (
return bool(
cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
@@ -484,7 +527,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
except Exception:
return False
def _check_size_and_type(self):
def _check_size_and_type(self) -> bool:
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
@@ -527,11 +570,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
return False
if self.curve not in self.curves:
return False
return self.curves[self.curve]["verify"](self.existing_private_key)
return self.curves[self.curve].verify(
self.existing_private_key, module=self.module
)
return False
def _check_format(self):
def _check_format(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
if self.format == "auto_ignore":
return True
try:
@@ -541,14 +588,14 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
return False
def select_backend(module):
def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyCryptographyBackend(module)
def get_privatekey_argument_spec():
def get_privatekey_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
size=dict(type="int", default=4096),
@@ -607,6 +654,6 @@ def get_privatekey_argument_spec():
),
),
required_if=[
["type", "ECC", ["curve"]],
("type", "ECC", ["curve"]),
],
)

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import abc
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
@@ -27,6 +28,13 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
from ansible_collections.community.crypto.plugins.module_utils.io import load_file
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -58,42 +66,48 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.src_path = module.params["src_path"]
self.src_content = module.params["src_content"]
self.src_passphrase = module.params["src_passphrase"]
self.format = module.params["format"]
self.dest_passphrase = module.params["dest_passphrase"]
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
self.src_passphrase: str | None = module.params["src_passphrase"]
self.format: t.Literal["pkcs1", "pkcs8", "raw"] = module.params["format"]
self.dest_passphrase: str | None = module.params["dest_passphrase"]
self.src_private_key = None
self.src_private_key: PrivateKeyTypes | None = None
if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module)
else:
if self.src_content is None:
raise AssertionError("src_content is None")
self.src_private_key_bytes = self.src_content.encode("utf-8")
self.dest_private_key = None
self.dest_private_key_bytes = None
self.dest_private_key: PrivateKeyTypes | None = None
self.dest_private_key_bytes: bytes | None = None
@abc.abstractmethod
def get_private_key_data(self):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format."""
pass
def set_existing_destination(self, privatekey_bytes):
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
def has_existing_destination(self):
def has_existing_destination(self) -> bool:
"""Query whether an existing private key is/has been there."""
return self.dest_private_key_bytes is not None
@abc.abstractmethod
def _load_private_key(self, data, passphrase, current_hint=None):
def _load_private_key(
self,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
pass
def needs_conversion(self):
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase
@@ -101,6 +115,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
if not self.has_existing_destination():
return True
assert self.dest_private_key_bytes is not None
try:
format, self.dest_private_key = self._load_private_key(
@@ -115,18 +130,20 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
self.dest_private_key, self.src_private_key
)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
return {}
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
def get_private_key_data(self):
def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format"""
if self.src_private_key is None:
raise AssertionError("src_private_key not set")
# Select export format and encoding
try:
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
@@ -152,9 +169,9 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
)
# Select key encryption
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.NoEncryption()
)
encryption_algorithm: (
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.dest_passphrase:
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
@@ -179,7 +196,12 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
exception=traceback.format_exc(),
)
def _load_private_key(self, data, passphrase, current_hint=None):
def _load_private_key(
self,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
@@ -247,14 +269,14 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
raise PrivateKeyError(e)
def select_backend(module):
def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyConvertCryptographyBackend(module)
def get_privatekey_argument_spec():
def get_privatekey_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
src_path=dict(type="path"),

View File

@@ -7,6 +7,7 @@
from __future__ import annotations
import abc
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -29,6 +30,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -40,38 +53,49 @@ except ImportError:
SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(key, need_private_key_data=False):
def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data = dict()
key_private_data: dict[str, t.Any] = {}
if need_private_key_data:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
private_numbers = key.private_numbers()
key_private_data["p"] = private_numbers.p
key_private_data["q"] = private_numbers.q
key_private_data["exponent"] = private_numbers.d
rsa_private_numbers = key.private_numbers()
key_private_data["p"] = rsa_private_numbers.p
key_private_data["q"] = rsa_private_numbers.q
key_private_data["exponent"] = rsa_private_numbers.d
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey
):
private_numbers = key.private_numbers()
key_private_data["x"] = private_numbers.x
dsa_private_numbers = key.private_numbers()
key_private_data["x"] = dsa_private_numbers.x
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
private_numbers = key.private_numbers()
key_private_data["multiplier"] = private_numbers.private_value
ecc_private_numbers = key.private_numbers()
key_private_data["multiplier"] = ecc_private_numbers.private_value
return key_type, key_public_data, key_private_data
def _check_dsa_consistency(key_public_data, key_private_data):
def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters
p = key_public_data.get("p")
q = key_public_data.get("q")
g = key_public_data.get("g")
y = key_public_data.get("y")
x = key_private_data.get("x")
for v in (p, q, g, y, x):
if v is None:
return None
p: int | None = key_public_data.get("p")
if p is None:
return None
q: int | None = key_public_data.get("q")
if q is None:
return None
g: int | None = key_public_data.get("g")
if g is None:
return None
y: int | None = key_public_data.get("y")
if y is None:
return None
x: int | None = key_private_data.get("x")
if x is None:
return None
# Make sure that g is not 0, 1 or -1 in Z/pZ
if g < 2 or g >= p - 1:
return False
@@ -94,13 +118,16 @@ def _check_dsa_consistency(key_public_data, key_private_data):
def _is_cryptography_key_consistent(
key, key_public_data, key_private_data, warn_func=None
):
key: PrivateKeyTypes,
key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
) -> bool | None:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
# key._backend was removed in cryptography 42.0.0
backend = getattr(key, "_backend", None)
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata))
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data)
if result is not None:
@@ -145,9 +172,9 @@ def _is_cryptography_key_consistent(
if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
has_simple_sign_function = True
if has_simple_sign_function:
signature = key.sign(SIGNATURE_TEST_DATA)
signature = key.sign(SIGNATURE_TEST_DATA) # type: ignore
try:
key.public_key().verify(signature, SIGNATURE_TEST_DATA)
key.public_key().verify(signature, SIGNATURE_TEST_DATA) # type: ignore
return True
except cryptography.exceptions.InvalidSignature:
return False
@@ -158,14 +185,14 @@ def _is_cryptography_key_consistent(
class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg, result):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg, result):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -174,13 +201,12 @@ class PrivateKeyParseError(OpenSSLObjectError):
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
module,
content,
passphrase=None,
return_private_key_data=False,
check_consistency=False,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
return_private_key_data: bool = False,
check_consistency: bool = False,
):
# content must be a bytes string
self.module = module
self.content = content
self.passphrase = passphrase
@@ -188,22 +214,26 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, binary):
def _get_public_key(self, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(self, need_private_key_data=False):
def _get_key_info(
self, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
@abc.abstractmethod
def _is_key_consistent(self, key_public_data, key_private_data):
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
def get_info(self, prefer_one_fingerprint=False):
result = dict(
can_parse_key=False,
key_is_consistent=None,
)
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"can_parse_key": False,
"key_is_consistent": None,
}
priv_key_detail = self.content
try:
self.key = load_privatekey(
@@ -252,35 +282,39 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module, content, **kwargs):
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs
)
def _get_public_key(self, binary):
def _get_public_key(self, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(self, need_private_key_data=False):
def _get_key_info(
self, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(self, key_public_data, key_private_data):
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
)
def get_privatekey_info(
module,
content,
passphrase=None,
return_private_key_data=False,
prefer_one_fingerprint=False,
):
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
return_private_key_data: bool = False,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
module,
content,
@@ -291,12 +325,12 @@ def get_privatekey_info(
def select_backend(
module,
content,
passphrase=None,
return_private_key_data=False,
check_consistency=False,
):
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
return_private_key_data: bool = False,
check_consistency: bool = False,
) -> PrivateKeyInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -19,6 +20,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
@@ -32,23 +45,25 @@ except ImportError:
pass
def _get_cryptography_public_key_info(key):
key_public_data = dict()
def _get_cryptography_public_key_info(
key: PublicKeyTypes,
) -> tuple[str, dict[str, t.Any]]:
key_public_data: dict[str, t.Any] = {}
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
key_type = "RSA"
public_numbers = key.public_numbers()
rsa_public_numbers = key.public_numbers()
key_public_data["size"] = key.key_size
key_public_data["modulus"] = public_numbers.n
key_public_data["exponent"] = public_numbers.e
key_public_data["modulus"] = rsa_public_numbers.n
key_public_data["exponent"] = rsa_public_numbers.e
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
key_type = "DSA"
parameter_numbers = key.parameters().parameter_numbers()
public_numbers = key.public_numbers()
dsa_parameter_numbers = key.parameters().parameter_numbers()
dsa_public_numbers = key.public_numbers()
key_public_data["size"] = key.key_size
key_public_data["p"] = parameter_numbers.p
key_public_data["q"] = parameter_numbers.q
key_public_data["g"] = parameter_numbers.g
key_public_data["y"] = public_numbers.y
key_public_data["p"] = dsa_parameter_numbers.p
key_public_data["q"] = dsa_parameter_numbers.q
key_public_data["g"] = dsa_parameter_numbers.g
key_public_data["y"] = dsa_public_numbers.y
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey
):
@@ -67,10 +82,10 @@ def _get_cryptography_public_key_info(key):
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey
):
key_type = "ECC"
public_numbers = key.public_numbers()
ecc_public_numbers = key.public_numbers()
key_public_data["curve"] = key.curve.name
key_public_data["x"] = public_numbers.x
key_public_data["y"] = public_numbers.y
key_public_data["x"] = ecc_public_numbers.x
key_public_data["y"] = ecc_public_numbers.y
key_public_data["exponent_size"] = key.curve.key_size
else:
key_type = f"unknown ({type(key)})"
@@ -78,29 +93,34 @@ def _get_cryptography_public_key_info(key):
class PublicKeyParseError(OpenSSLObjectError):
def __init__(self, msg, result):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PublicKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module, content=None, key=None):
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
# content must be a bytes string
self.module = module
self.content = content
self.key = key
@abc.abstractmethod
def _get_public_key(self, binary):
def _get_public_key(self, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(self):
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass
def get_info(self, prefer_one_fingerprint=False):
result = dict()
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if self.key is None:
try:
self.key = load_publickey(content=self.content)
@@ -123,27 +143,45 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
"""Validate the supplied public key, using the cryptography backend"""
def __init__(self, module, content=None, key=None):
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, content=content, key=key
)
def _get_public_key(self, binary):
def _get_public_key(self, binary: bool) -> bytes:
if self.key is None:
raise AssertionError("key must be set")
return self.key.public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(self):
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
if self.key is None:
raise AssertionError("key must be set")
return _get_cryptography_public_key_info(self.key)
def get_publickey_info(module, content=None, key=None, prefer_one_fingerprint=False):
def get_publickey_info(
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, content=None, key=None):
def select_backend(
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> PublicKeyInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)

View File

@@ -8,3 +8,6 @@ from __future__ import annotations
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( # noqa: F401, pylint: disable=unused-import
parse_openssh_version,
)
# TODO: delete!

View File

@@ -4,6 +4,8 @@
from __future__ import annotations
import typing as t
PEM_START = "-----BEGIN "
PEM_END_START = "-----END "
@@ -12,7 +14,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
def identify_pem_format(content, encoding="utf-8"):
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
"""Given the contents of a binary file, tests whether this could be a PEM file."""
try:
first_pem = extract_first_pem(content.decode(encoding))
@@ -30,7 +32,9 @@ def identify_pem_format(content, encoding="utf-8"):
return False
def identify_private_key_format(content, encoding="utf-8"):
def identify_private_key_format(
content: bytes, encoding: str = "utf-8"
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
"""Given the contents of a private key file, identifies its format."""
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
# (PEM_read_bio_PrivateKey)
@@ -59,12 +63,12 @@ def identify_private_key_format(content, encoding="utf-8"):
return "raw"
def split_pem_list(text, keep_inbetween=False):
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
"""
Split concatenated PEM objects into a list of strings, where each is one PEM object.
"""
result = []
current = [] if keep_inbetween else None
current: list[str] | None = [] if keep_inbetween else None
for line in text.splitlines(True):
if line.strip():
if not keep_inbetween and line.startswith("-----BEGIN "):
@@ -77,7 +81,7 @@ def split_pem_list(text, keep_inbetween=False):
return result
def extract_first_pem(text):
def extract_first_pem(text: str) -> str | None:
"""
Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none.
"""
@@ -87,7 +91,7 @@ def extract_first_pem(text):
return all_pems[0]
def _extract_type(line, start=PEM_START):
def _extract_type(line: str, start: str = PEM_START) -> str | None:
if not line.startswith(start):
return None
if not line.endswith(PEM_END):
@@ -95,7 +99,7 @@ def _extract_type(line, start=PEM_START):
return line[len(start) : -len(PEM_END)]
def extract_pem(content, strict=False):
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
lines = content.splitlines()
if len(lines) < 3:
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
@@ -117,5 +121,4 @@ def extract_pem(content, strict=False):
raise ValueError(
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
)
content = lines[1:-1]
return header_type, "".join(content)
return header_type, "".join(lines[1:-1])

View File

@@ -8,8 +8,13 @@ import abc
import errno
import hashlib
import os
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
is_potential_certificate_issuer_private_key,
is_potential_certificate_private_key,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
identify_pem_format,
)
@@ -34,6 +39,17 @@ except ImportError:
from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
PublicKeyTypes,
)
from .cryptography_support import CertificatePrivateKeyTypes
# This list of preferred fingerprints is used when prefer_one=True is supplied to the
# fingerprinting methods.
PREFERRED_FINGERPRINTS = (
@@ -48,18 +64,12 @@ PREFERRED_FINGERPRINTS = (
)
def get_fingerprint_of_bytes(source, prefer_one=False):
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
"""Generate the fingerprint of the given bytes."""
fingerprint = {}
try:
algorithms = hashlib.algorithms
except AttributeError:
try:
algorithms = hashlib.algorithms_guaranteed
except AttributeError:
return None
algorithms: t.Iterable[str] = hashlib.algorithms_guaranteed
if prefer_one:
# Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning
@@ -97,7 +107,9 @@ def get_fingerprint_of_bytes(source, prefer_one=False):
return fingerprint
def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
def get_fingerprint_of_privatekey(
privatekey: PrivateKeyTypes, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the public key."""
publickey = privatekey.public_key().public_bytes(
@@ -107,11 +119,16 @@ def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
def get_fingerprint(
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
content: bytes | None = None,
prefer_one: bool = False,
) -> dict[str, str]:
"""Generate the fingerprint of the public key."""
privatekey = load_privatekey(
path,
path=path,
passphrase=passphrase,
content=content,
check_passphrase=False,
@@ -121,11 +138,11 @@ def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
def load_privatekey(
path,
passphrase=None,
check_passphrase=True,
content=None,
):
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
content: bytes | None = None,
) -> PrivateKeyTypes:
"""Load the specified OpenSSL private key.
The content can also be specified via content; in that case,
@@ -134,6 +151,8 @@ def load_privatekey(
try:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as b_priv_key_fh:
priv_key_detail = b_priv_key_fh.read()
else:
@@ -154,7 +173,55 @@ def load_privatekey(
raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key")
def load_publickey(path=None, content=None):
def load_certificate_privatekey(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
) -> CertificatePrivateKeyTypes:
"""
Load the specified OpenSSL private key that can be used as a private key for certificates.
"""
private_key = load_privatekey(
path=path,
passphrase=passphrase,
check_passphrase=check_passphrase,
content=content,
)
if not is_potential_certificate_private_key(private_key):
raise OpenSSLObjectError(
f"Key of type {type(private_key)} not supported for certificates"
)
return private_key
def load_certificate_issuer_privatekey(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
) -> CertificateIssuerPrivateKeyTypes:
"""
Load the specified OpenSSL private key that can be used for issuing certificates.
"""
private_key = load_privatekey(
path=path,
passphrase=passphrase,
check_passphrase=check_passphrase,
content=content,
)
if not is_potential_certificate_issuer_private_key(private_key):
raise OpenSSLObjectError(
f"Key of type {type(private_key)} not supported for issuing certificates"
)
return private_key
def load_publickey(
path: os.PathLike | str | None = None, content: bytes | None = None
) -> PublicKeyTypes:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
@@ -170,11 +237,17 @@ def load_publickey(path=None, content=None):
raise OpenSSLObjectError(f"Error while deserializing key: {e}")
def load_certificate(path, content=None, der_support_enabled=False):
def load_certificate(
path: os.PathLike | str | None = None,
content: bytes | None = None,
der_support_enabled: bool = False,
) -> x509.Certificate:
"""Load the specified certificate."""
try:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as cert_fh:
cert_content = cert_fh.read()
else:
@@ -193,10 +266,14 @@ def load_certificate(path, content=None, der_support_enabled=False):
raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}")
def load_certificate_request(path, content=None):
def load_certificate_request(
path: os.PathLike | str | None = None, content: bytes | None = None
) -> x509.CertificateSigningRequest:
"""Load the specified certificate signing request."""
try:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as csr_fh:
csr_content = csr_fh.read()
else:
@@ -209,45 +286,44 @@ def load_certificate_request(path, content=None):
raise OpenSSLObjectError(exc)
def parse_name_field(input_dict, name_field_name=None):
def parse_name_field(
input_dict: dict[str, list[str | bytes] | str | bytes],
name_field_name: str | None = None,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
error_str = "{key}" if name_field_name is None else "{key} in {name}"
def error_str(key: str) -> str:
if name_field_name is None:
return f"{key}"
return f"{key} in {name_field_name}"
result = []
for key, value in input_dict.items():
if isinstance(value, list):
for entry in value:
if not isinstance(entry, (str, bytes)):
raise TypeError(
f"Values {error_str} must be strings".format(
key=key, name=name_field_name
)
)
raise TypeError(f"Values {error_str(key)} must be strings")
if not entry:
raise ValueError(
f"Values for {error_str} must not be empty strings".format(
key=key, name=name_field_name
)
f"Values for {error_str(key)} must not be empty strings"
)
result.append((key, entry))
elif isinstance(value, (str, bytes)):
if not value:
raise ValueError(
f"Value for {error_str} must not be an empty string".format(
key=key, name=name_field_name
)
f"Value for {error_str(key)} must not be an empty string"
)
result.append((key, value))
else:
raise TypeError(
(
f"Value for {error_str} must be either a string or a list of strings"
).format(key=key, name=name_field_name)
f"Value for {error_str(key)} must be either a string or a list of strings"
)
return result
def parse_ordered_name_field(input_list, name_field_name):
def parse_ordered_name_field(
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
result = []
@@ -265,24 +341,39 @@ def parse_ordered_name_field(input_list, name_field_name):
return result
def select_message_digest(digest_string):
digest = None
@t.overload
def select_message_digest(
digest_string: t.Literal["sha256", "sha384", "sha512", "sha1", "md5"],
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5: ...
@t.overload
def select_message_digest(
digest_string: str,
) -> (
hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None
): ...
def select_message_digest(
digest_string: str,
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None:
if digest_string == "sha256":
digest = hashes.SHA256()
elif digest_string == "sha384":
digest = hashes.SHA384()
elif digest_string == "sha512":
digest = hashes.SHA512()
elif digest_string == "sha1":
digest = hashes.SHA1()
elif digest_string == "md5":
digest = hashes.MD5()
return digest
return hashes.SHA256()
if digest_string == "sha384":
return hashes.SHA384()
if digest_string == "sha512":
return hashes.SHA512()
if digest_string == "sha1":
return hashes.SHA1()
if digest_string == "md5":
return hashes.MD5()
return None
class OpenSSLObject(metaclass=abc.ABCMeta):
def __init__(self, path, state, force, check_mode):
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
self.path = path
self.state = state
self.force = force
@@ -290,13 +381,13 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
self.changed = False
self.check_mode = check_mode
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
def _check_state():
def _check_state() -> bool:
return os.path.exists(self.path)
def _check_perms(module):
def _check_perms(module: AnsibleModule) -> bool:
file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args["path"]):
return False
@@ -308,18 +399,14 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
return _check_state() and _check_perms(module)
@abc.abstractmethod
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
pass
@abc.abstractmethod
def generate(self):
def generate(self, module: AnsibleModule) -> None:
"""Generate the resource."""
pass
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
"""Remove the resource from the filesystem."""
if self.check_mode:
if os.path.exists(self.path):

View File

@@ -11,6 +11,7 @@ Must be kept in sync with plugins/doc_fragments/cryptography_dep.py.
from __future__ import annotations
import traceback
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible_collections.community.crypto.plugins.module_utils.version import (
@@ -18,19 +19,29 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
_CRYPTOGRAPHY_IMP_ERR = None
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ..plugin_utils.action_module import AnsibleActionModule
from ..plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
_CRYPTOGRAPHY_IMP_ERR: str | None = None
_CRYPTOGRAPHY_FILE: str | None = None
try:
import cryptography
from cryptography import x509 # noqa: F401, pylint: disable=unused-import
_CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
_CRYPTOGRAPHY_FILE = cryptography.__file__
except ImportError:
_CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
_CRYPTOGRAPHY_FOUND = False
_CRYPTOGRAPHY_FILE = None
CRYPTOGRAPHY_FOUND = False
CRYPTOGRAPHY_VERSION = LooseVersion("0.0")
else:
_CRYPTOGRAPHY_FOUND = True
CRYPTOGRAPHY_FOUND = True
# Corresponds to the community.crypto.cryptography_dep.minimum doc fragment
@@ -38,25 +49,27 @@ COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION = "3.3"
def assert_required_cryptography_version(
module,
module: GeneralAnsibleModule,
*,
minimum_cryptography_version: str = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
) -> None:
if not _CRYPTOGRAPHY_FOUND:
if not CRYPTOGRAPHY_FOUND:
module.fail_json(
msg=missing_required_lib(f"cryptography >= {minimum_cryptography_version}"),
exception=_CRYPTOGRAPHY_IMP_ERR,
)
if _CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
if CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
module.fail_json(
msg=(
f"Cannot detect the required Python library cryptography (>= {minimum_cryptography_version})."
f" Only found a too old version ({_CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
f" Only found a too old version ({CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
),
)
__all__ = (
"COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION",
"CRYPTOGRAPHY_FOUND",
"CRYPTOGRAPHY_VERSION",
"assert_required_cryptography_version",
)

View File

@@ -14,6 +14,7 @@ import json
import os
import re
import traceback
import typing as t
from urllib.error import HTTPError
from urllib.parse import urlencode
@@ -34,7 +35,7 @@ else:
valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$")
def ecs_client_argument_spec():
def ecs_client_argument_spec() -> dict[str, t.Any]:
return dict(
entrust_api_user=dict(type="str", required=True),
entrust_api_key=dict(type="str", required=True, no_log=True),
@@ -50,19 +51,17 @@ def ecs_client_argument_spec():
class SessionConfigurationException(Exception):
"""Raised if we cannot configure a session with the API"""
pass
class RestOperationException(Exception):
"""Encapsulate a REST API error"""
def __init__(self, error):
def __init__(self, error: dict[str, t.Any]) -> None:
self.status = to_native(error.get("status", None))
self.errors = [to_native(err.get("message")) for err in error.get("errors", {})]
self.message = " ".join(self.errors)
def generate_docstring(operation_spec):
def generate_docstring(operation_spec: dict[str, t.Any]) -> str:
"""Generate a docstring for an operation defined in operation_spec (swagger)"""
# Description of the operation
docs = operation_spec.get("description", "No Description")

View File

@@ -14,7 +14,9 @@ class GPGError(Exception):
class GPGRunner(metaclass=abc.ABCMeta):
@abc.abstractmethod
def run_command(self, command, check_rc=True, data=None):
def run_command(
self, command: list[str], check_rc: bool = True, data: bytes | None = None
) -> tuple[int, str, str]:
"""
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
@@ -29,7 +31,7 @@ class GPGRunner(metaclass=abc.ABCMeta):
pass
def get_fingerprint_from_stdout(stdout):
def get_fingerprint_from_stdout(stdout: str) -> str:
lines = stdout.splitlines(False)
for line in lines:
if line.startswith("fpr:"):
@@ -42,7 +44,7 @@ def get_fingerprint_from_stdout(stdout):
raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"')
def get_fingerprint_from_file(gpg_runner, path):
def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
if not os.path.exists(path):
raise GPGError(f"{path} does not exist")
stdout = gpg_runner.run_command(
@@ -59,7 +61,7 @@ def get_fingerprint_from_file(gpg_runner, path):
return get_fingerprint_from_stdout(stdout)
def get_fingerprint_from_bytes(gpg_runner, content):
def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
stdout = gpg_runner.run_command(
[
"--no-keyring",

View File

@@ -7,9 +7,14 @@ from __future__ import annotations
import errno
import os
import tempfile
import typing as t
def load_file(path, module=None):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
"""
Load the file as a bytes string.
"""
@@ -22,7 +27,11 @@ def load_file(path, module=None):
module.fail_json(f"Error while loading {path} - {exc}")
def load_file_if_exists(path, module=None, ignore_errors=False):
def load_file_if_exists(
path: str | os.PathLike,
module: AnsibleModule | None = None,
ignore_errors: bool = False,
) -> bytes | None:
"""
Load the file as a bytes string. If the file does not exist, ``None`` is returned.
@@ -49,7 +58,12 @@ def load_file_if_exists(path, module=None, ignore_errors=False):
module.fail_json(f"Error while loading {path} - {exc}")
def write_file(module, content, default_mode=None, path=None):
def write_file(
module: AnsibleModule,
content: bytes,
default_mode: str | int | None = None,
path: str | os.PathLike | None = None,
) -> None:
"""
Writes content into destination file as securely as possible.
Uses file arguments from module.

View File

@@ -8,14 +8,31 @@ import abc
import os
import stat
import traceback
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
parse_openssh_version,
)
def restore_on_failure(f):
def backup_and_restore(module, path, *args, **kwargs):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
from ..certificate import OpensshCertificateTimeParameters
Param = t.ParamSpec("Param")
def restore_on_failure(
f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None],
) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]:
def backup_and_restore(
module: AnsibleModule, path: str | os.PathLike, *args, **kwargs
) -> None:
backup_file = module.backup_local(path) if os.path.exists(path) else None
try:
@@ -31,12 +48,31 @@ def restore_on_failure(f):
@restore_on_failure
def safe_atomic_move(module, path, destination):
def safe_atomic_move(
module: AnsibleModule, path: str | os.PathLike, destination: str | os.PathLike
) -> None:
module.atomic_move(os.path.abspath(path), os.path.abspath(destination))
def _restore_all_on_failure(f):
def backup_and_restore(self, sources_and_destinations, *args, **kwargs):
def _restore_all_on_failure(
f: t.Callable[
t.Concatenate[
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
],
None,
],
) -> t.Callable[
t.Concatenate[
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
],
None,
]:
def backup_and_restore(
self: OpensshModule,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
*args,
**kwargs,
) -> None:
backups = [
(d, self.module.backup_local(d))
for s, d in sources_and_destinations
@@ -59,13 +95,13 @@ def _restore_all_on_failure(f):
class OpensshModule(metaclass=abc.ABCMeta):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.changed = False
self.check_mode = self.module.check_mode
self.changed: bool = False
self.check_mode: bool = self.module.check_mode
def execute(self):
def execute(self) -> t.NoReturn:
try:
self._execute()
except Exception as e:
@@ -77,11 +113,11 @@ class OpensshModule(metaclass=abc.ABCMeta):
self.module.exit_json(**self.result)
@abc.abstractmethod
def _execute(self):
def _execute(self) -> None:
pass
@property
def result(self):
def result(self) -> dict[str, t.Any]:
result = self._result
result["changed"] = self.changed
@@ -93,31 +129,31 @@ class OpensshModule(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def _result(self):
def _result(self) -> dict[str, t.Any]:
pass
@property
@abc.abstractmethod
def diff(self):
def diff(self) -> dict[str, t.Any]:
pass
@staticmethod
def skip_if_check_mode(f):
def wrapper(self, *args, **kwargs):
def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs) -> None:
if not self.check_mode:
f(self, *args, **kwargs)
return wrapper
return wrapper # type: ignore
@staticmethod
def trigger_change(f):
def wrapper(self, *args, **kwargs):
def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs) -> None:
f(self, *args, **kwargs)
self.changed = True
return wrapper
return wrapper # type: ignore
def _check_if_base_dir(self, path):
def _check_if_base_dir(self, path: str | os.PathLike) -> None:
base_dir = os.path.dirname(path) or "."
if not os.path.isdir(base_dir):
self.module.fail_json(
@@ -125,16 +161,19 @@ class OpensshModule(metaclass=abc.ABCMeta):
msg=f"The directory {base_dir} does not exist or the file is not a directory",
)
def _get_ssh_version(self):
def _get_ssh_version(self) -> str | None:
ssh_bin = self.module.get_bin_path("ssh")
if not ssh_bin:
return ""
return None
return parse_openssh_version(
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
)
@_restore_all_on_failure
def _safe_secure_move(self, sources_and_destinations):
def _safe_secure_move(
self,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
) -> None:
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
@@ -148,7 +187,7 @@ class OpensshModule(metaclass=abc.ABCMeta):
else:
self.module.preserved_copy(source, destination)
def _update_permissions(self, path):
def _update_permissions(self, path: str | os.PathLike) -> None:
file_args = self.module.load_file_common_arguments(self.module.params)
file_args["path"] = path
@@ -161,25 +200,25 @@ class OpensshModule(metaclass=abc.ABCMeta):
class KeygenCommand:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self._bin_path = module.get_bin_path("ssh-keygen", True)
self._run_command = module.run_command
def generate_certificate(
self,
certificate_path,
identifier,
options,
pkcs11_provider,
principals,
serial_number,
signature_algorithm,
signing_key_path,
type,
time_parameters,
use_agent,
certificate_path: str,
identifier: str,
options: list[str] | None,
pkcs11_provider: str | None,
principals: list[str] | None,
serial_number: int | None,
signature_algorithm: str | None,
signing_key_path: str,
type: t.Literal["host", "user"] | None,
time_parameters: OpensshCertificateTimeParameters,
use_agent: bool,
**kwargs,
):
) -> tuple[int, str, str]:
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
if options:
@@ -203,7 +242,9 @@ class KeygenCommand:
return self._run_command(args, **kwargs)
def generate_keypair(self, private_key_path, size, type, comment, **kwargs):
def generate_keypair(
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
) -> tuple[int, str, str]:
args = [
self._bin_path,
"-q",
@@ -224,32 +265,40 @@ class KeygenCommand:
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(self, certificate_path, **kwargs):
def get_certificate_info(
self, certificate_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(self, private_key_path, **kwargs):
def get_matching_public_key(
self, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(self, private_key_path, **kwargs):
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
)
def update_comment(
self, private_key_path, comment, force_new_format=True, **kwargs
):
self,
private_key_path: str,
comment: str,
force_new_format: bool = True,
**kwargs,
) -> tuple[int, str, str]:
if os.path.exists(private_key_path) and not os.access(
private_key_path, os.W_OK
):
try:
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
except (IOError, OSError) as e:
raise e(
f"The private key at {private_key_path} is not writeable preventing a comment update"
raise ValueError(
f"The private key at {private_key_path} is not writeable preventing a comment update ({e})"
)
command = [self._bin_path, "-q"]
@@ -259,31 +308,36 @@ class KeygenCommand:
return self._run_command(command, **kwargs)
_PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
class PrivateKey:
def __init__(self, size, key_type, fingerprint, format=""):
def __init__(
self, size: int, key_type: str, fingerprint: str, format: str = ""
) -> None:
self._size = size
self._type = key_type
self._fingerprint = fingerprint
self._format = format
@property
def size(self):
def size(self) -> int:
return self._size
@property
def type(self):
def type(self) -> str:
return self._type
@property
def fingerprint(self):
def fingerprint(self) -> str:
return self._fingerprint
@property
def format(self):
def format(self) -> str:
return self._format
@classmethod
def from_string(cls, string):
def from_string(cls: t.Type[_PrivateKey], string: str) -> _PrivateKey:
properties = string.split()
return cls(
@@ -292,7 +346,7 @@ class PrivateKey:
fingerprint=properties[1],
)
def to_dict(self):
def to_dict(self) -> dict[str, t.Any]:
return {
"size": self._size,
"type": self._type,
@@ -301,13 +355,16 @@ class PrivateKey:
}
_PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
class PublicKey:
def __init__(self, type_string, data, comment):
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
self._type_string = type_string
self._data = data
self._comment = comment
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
@@ -323,30 +380,30 @@ class PublicKey:
]
)
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self == other
def __str__(self):
def __str__(self) -> str:
return f"{self._type_string} {self._data}"
@property
def comment(self):
def comment(self) -> str | None:
return self._comment
@comment.setter
def comment(self, value):
def comment(self, value: str | None) -> None:
self._comment = value
@property
def data(self):
def data(self) -> str:
return self._data
@property
def type_string(self):
def type_string(self) -> str:
return self._type_string
@classmethod
def from_string(cls, string):
def from_string(cls: t.Type[_PublicKey], string: str) -> _PublicKey:
properties = string.strip("\n").split(" ", 2)
return cls(
@@ -356,7 +413,7 @@ class PublicKey:
)
@classmethod
def load(cls, path):
def load(cls: t.Type[_PublicKey], path: str | os.PathLike) -> _PublicKey | None:
try:
with open(path, "r") as f:
properties = f.read().strip(" \n").split(" ", 2)
@@ -372,14 +429,16 @@ class PublicKey:
comment="" if len(properties) <= 2 else properties[2],
)
def to_dict(self):
def to_dict(self) -> dict[str, t.Any]:
return {
"comment": self._comment,
"public_key": self._data,
}
def parse_private_key_format(path):
def parse_private_key_format(
path: str | os.PathLike,
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
with open(path, "r") as file:
header = file.readline().strip()

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import abc
import os
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -39,31 +40,43 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module)
self.comment = self.module.params["comment"]
self.private_key_path = self.module.params["path"]
self.comment: str | None = self.module.params["comment"]
self.private_key_path: str = self.module.params["path"]
self.public_key_path = self.private_key_path + ".pub"
self.regenerate = (
self.regenerate: t.Literal[
"never", "fail", "partial_idempotence", "full_idempotence", "always"
] = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.state = self.module.params["state"]
self.type = self.module.params["type"]
self.state: t.Literal["present", "absent"] = self.module.params["state"]
self.type: t.Literal["rsa", "dsa", "rsa1", "ecdsa", "ed25519"] = (
self.module.params["type"]
)
self.size = self._get_size(self.module.params["size"])
self.size: int = self._get_size(self.module.params["size"])
self._validate_path()
self.original_private_key = None
self.original_public_key = None
self.private_key = None
self.public_key = None
self.original_private_key: PrivateKey | None = None
self.original_public_key: PublicKey | None = None
self.private_key: PrivateKey | None = None
self.public_key: PublicKey | None = None
def _get_size(self, size):
def _get_size(self, size: int | None) -> int:
if self.type in ("rsa", "rsa1"):
result = 4096 if size is None else size
if result < 1024:
@@ -96,7 +109,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return result
def _validate_path(self):
def _validate_path(self) -> None:
self._check_if_base_dir(self.private_key_path)
if os.path.isdir(self.private_key_path):
@@ -104,7 +117,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
msg=f"{self.private_key_path} is a directory. Please specify a path to a file."
)
def _execute(self):
def _execute(self) -> None:
self.original_private_key = self._load_private_key()
self.original_public_key = self._load_public_key()
@@ -125,7 +138,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
if self._should_remove():
self._remove()
def _load_private_key(self):
def _load_private_key(self) -> PrivateKey | None:
result = None
if self._private_key_exists():
try:
@@ -135,14 +148,14 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return result
def _private_key_exists(self):
def _private_key_exists(self) -> bool:
return os.path.exists(self.private_key_path)
@abc.abstractmethod
def _get_private_key(self):
def _get_private_key(self) -> PrivateKey:
pass
def _load_public_key(self):
def _load_public_key(self) -> PublicKey | None:
result = None
if self._public_key_exists():
try:
@@ -151,10 +164,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
pass
return result
def _public_key_exists(self):
def _public_key_exists(self) -> bool:
return os.path.exists(self.public_key_path)
def _validate_key_load(self):
def _validate_key_load(self) -> None:
if (
self._private_key_exists()
and self.regenerate in ("never", "fail", "partial_idempotence")
@@ -167,10 +180,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
)
@abc.abstractmethod
def _private_key_readable(self):
def _private_key_readable(self) -> bool:
pass
def _should_generate(self):
def _should_generate(self) -> bool:
if self.original_private_key is None:
return True
elif self.regenerate == "never":
@@ -188,7 +201,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
else:
return True
def _private_key_valid(self):
def _private_key_valid(self) -> bool:
if self.original_private_key is None:
return False
@@ -196,17 +209,17 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
[
self.size == self.original_private_key.size,
self.type == self.original_private_key.type,
self._private_key_valid_backend(),
self._private_key_valid_backend(self.original_private_key),
]
)
@abc.abstractmethod
def _private_key_valid_backend(self):
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
pass
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _generate(self):
def _generate(self) -> None:
temp_private_key, temp_public_key = self._generate_temp_keypair()
try:
@@ -219,7 +232,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
except OSError as e:
self.module.fail_json(msg=str(e))
def _generate_temp_keypair(self):
def _generate_temp_keypair(self) -> tuple[str, str]:
temp_private_key = os.path.join(
self.module.tmpdir, os.path.basename(self.private_key_path)
)
@@ -236,25 +249,26 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return temp_private_key, temp_public_key
@abc.abstractmethod
def _generate_keypair(self, private_key_path):
def _generate_keypair(self, private_key_path: str) -> None:
pass
def _public_key_valid(self):
def _public_key_valid(self) -> bool:
if self.original_public_key is None:
return False
valid_public_key = self._get_public_key()
valid_public_key.comment = self.comment
if valid_public_key:
valid_public_key.comment = self.comment
return self.original_public_key == valid_public_key
@abc.abstractmethod
def _get_public_key(self):
def _get_public_key(self) -> PublicKey | t.Literal[""]:
pass
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _restore_public_key(self):
def _restore_public_key(self) -> None:
try:
temp_public_key = self._create_temp_public_key(
str(self._get_public_key()) + "\n"
@@ -269,7 +283,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
if self.comment:
self._update_comment()
def _create_temp_public_key(self, content):
def _create_temp_public_key(self, content: str | bytes) -> str:
temp_public_key = os.path.join(
self.module.tmpdir, os.path.basename(self.public_key_path)
)
@@ -290,15 +304,15 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return temp_public_key
@abc.abstractmethod
def _update_comment(self):
def _update_comment(self) -> None:
pass
def _should_remove(self):
def _should_remove(self) -> bool:
return self._private_key_exists() or self._public_key_exists()
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _remove(self):
def _remove(self) -> None:
try:
if self._private_key_exists():
os.remove(self.private_key_path)
@@ -308,7 +322,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
self.module.fail_json(msg=str(e))
@property
def _result(self):
def _result(self) -> dict[str, t.Any]:
private_key = self.private_key or self.original_private_key
public_key = self.public_key or self.original_public_key
@@ -322,7 +336,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
}
@property
def diff(self):
def diff(self) -> dict[str, t.Any]:
before = (
self.original_private_key.to_dict() if self.original_private_key else {}
)
@@ -340,7 +354,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module)
if self.module.params["private_key_format"] != "auto":
@@ -350,12 +364,12 @@ class KeypairBackendOpensshBin(KeypairBackend):
self.ssh_keygen = KeygenCommand(self.module)
def _generate_keypair(self, private_key_path):
def _generate_keypair(self, private_key_path: str) -> None:
self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True
)
def _get_private_key(self):
def _get_private_key(self) -> PrivateKey:
rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False
)
@@ -363,13 +377,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
raise ValueError(err)
return PrivateKey.from_string(private_key_content)
def _get_public_key(self):
def _get_public_key(self) -> PublicKey | t.Literal[""]:
public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True
)[1]
return PublicKey.from_string(public_key_content)
def _private_key_readable(self):
def _private_key_readable(self) -> bool:
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False
)
@@ -383,7 +397,7 @@ class KeypairBackendOpensshBin(KeypairBackend):
)
)
def _update_comment(self):
def _update_comment(self) -> None:
try:
ssh_version = self._get_ssh_version() or "7.8"
force_new_format = (
@@ -391,19 +405,19 @@ class KeypairBackendOpensshBin(KeypairBackend):
)
self.ssh_keygen.update_comment(
self.private_key_path,
self.comment,
self.comment or "",
force_new_format=force_new_format,
check_rc=True,
)
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
def _private_key_valid_backend(self):
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
return True
class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module)
if self.type == "rsa1":
@@ -416,12 +430,15 @@ class KeypairBackendCryptography(KeypairBackend):
if module.params["passphrase"]
else None
)
self.private_key_format = self._get_key_format(
module.params["private_key_format"]
)
key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] = module.params[
"private_key_format"
]
self.private_key_format = self._get_key_format(key_format)
def _get_key_format(self, key_format):
result = "SSH"
def _get_key_format(
self, key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"]
) -> t.Literal["SSH", "PKCS1", "PKCS8"]:
result: t.Literal["SSH", "PKCS1", "PKCS8"] = "SSH"
if key_format == "auto":
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
@@ -435,11 +452,12 @@ class KeypairBackendCryptography(KeypairBackend):
# but still defaulted to PKCS1 format with the exception of ed25519 keys
result = "PKCS1"
else:
result = key_format.upper()
result = key_format.upper() # type: ignore
return result
def _generate_keypair(self, private_key_path):
def _generate_keypair(self, private_key_path: str) -> None:
assert self.type != "rsa1"
keypair = OpensshKeypair.generate(
keytype=self.type,
size=self.size,
@@ -455,7 +473,7 @@ class KeypairBackendCryptography(KeypairBackend):
public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key)
def _get_private_key(self):
def _get_private_key(self) -> PrivateKey:
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
@@ -467,7 +485,7 @@ class KeypairBackendCryptography(KeypairBackend):
format=parse_private_key_format(self.private_key_path),
)
def _get_public_key(self):
def _get_public_key(self) -> PublicKey | t.Literal[""]:
try:
keypair = OpensshKeypair.load(
path=self.private_key_path,
@@ -480,7 +498,7 @@ class KeypairBackendCryptography(KeypairBackend):
return PublicKey.from_string(to_text(keypair.public_key))
def _private_key_readable(self):
def _private_key_readable(self) -> bool:
try:
OpensshKeypair.load(
path=self.private_key_path,
@@ -504,7 +522,7 @@ class KeypairBackendCryptography(KeypairBackend):
return True
def _update_comment(self):
def _update_comment(self) -> None:
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
@@ -519,16 +537,18 @@ class KeypairBackendCryptography(KeypairBackend):
except (IOError, OSError) as e:
self.module.fail_json(msg=str(e))
def _private_key_valid_backend(self):
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
# avoids breaking behavior and prevents
# automatic conversions with OpenSSH upgrades
if self.module.params["private_key_format"] == "auto":
return True
return self.private_key_format == self.original_private_key.format
return self.private_key_format == original_private_key.format
def select_backend(module, backend):
def select_backend(
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
) -> KeypairBackend:
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
CRYPTOGRAPHY_VERSION
) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION)

View File

@@ -8,6 +8,7 @@ import abc
import binascii
import datetime as _datetime
import os
import typing as t
from base64 import b64encode
from datetime import datetime
from hashlib import sha256
@@ -26,6 +27,16 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
if t.TYPE_CHECKING:
from .cryptography import KeyType
DateFormat = t.Literal["human_readable", "openssh", "timestamp"]
DateFormatStr = t.Literal["human_readable", "openssh"]
DateFormatInt = t.Literal["timestamp"]
else:
KeyType = None
# Protocol References
# -------------------
# https://datatracker.ietf.org/doc/html/rfc4251
@@ -44,7 +55,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
_USER_TYPE = 1
_HOST_TYPE = 2
_SSH_TYPE_STRINGS = {
_SSH_TYPE_STRINGS: dict[KeyType | str, bytes] = {
"rsa": b"ssh-rsa",
"dsa": b"ssh-dss",
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
@@ -94,16 +105,18 @@ _EXTENSIONS = (
class OpensshCertificateTimeParameters:
def __init__(self, valid_from, valid_to):
def __init__(
self, valid_from: str | bytes | int, valid_to: str | bytes | int
) -> None:
self._valid_from = self.to_datetime(valid_from)
self._valid_to = self.to_datetime(valid_to)
if self._valid_from > self._valid_to:
raise ValueError(
f"Valid from: {valid_from} must not be greater than Valid to: {valid_to}"
f"Valid from: {valid_from!r} must not be greater than Valid to: {valid_to!r}"
)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
else:
@@ -112,55 +125,83 @@ class OpensshCertificateTimeParameters:
and self._valid_to == other._valid_to
)
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self == other
@property
def validity_string(self):
def validity_string(self) -> str:
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}"
return ""
def valid_from(self, date_format):
@t.overload
def valid_from(self, date_format: DateFormatStr) -> str: ...
@t.overload
def valid_from(self, date_format: DateFormatInt) -> int: ...
@t.overload
def valid_from(self, date_format: DateFormat) -> str | int: ...
def valid_from(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_from, date_format)
def valid_to(self, date_format):
@t.overload
def valid_to(self, date_format: DateFormatStr) -> str: ...
@t.overload
def valid_to(self, date_format: DateFormatInt) -> int: ...
@t.overload
def valid_to(self, date_format: DateFormat) -> str | int: ...
def valid_to(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_to, date_format)
def within_range(self, valid_at):
def within_range(self, valid_at: str | bytes | int | None) -> bool:
if valid_at is not None:
valid_at_datetime = self.to_datetime(valid_at)
return self._valid_from <= valid_at_datetime <= self._valid_to
return True
@t.overload
@staticmethod
def format_datetime(dt, date_format):
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
result = "always"
elif dt == _FOREVER:
result = "forever"
return "always"
if dt == _FOREVER:
return "forever"
else:
result = (
return (
dt.isoformat().replace("+00:00", "")
if date_format == "human_readable"
else dt.strftime("%Y%m%d%H%M%S")
)
elif date_format == "timestamp":
if date_format == "timestamp":
td = dt - _ALWAYS
result = int(
return int(
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
)
else:
raise ValueError(f"{date_format} is not a valid format")
return result
raise ValueError(f"{date_format} is not a valid format")
@staticmethod
def to_datetime(time_string_or_timestamp):
def to_datetime(time_string_or_timestamp: str | bytes | int) -> datetime:
try:
if isinstance(time_string_or_timestamp, (str, bytes)):
result = OpensshCertificateTimeParameters._time_string_to_datetime(
time_string_or_timestamp.strip()
to_text(time_string_or_timestamp.strip())
)
elif isinstance(time_string_or_timestamp, int):
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
@@ -175,43 +216,53 @@ class OpensshCertificateTimeParameters:
return result
@staticmethod
def _timestamp_to_datetime(timestamp):
def _timestamp_to_datetime(timestamp: int) -> datetime:
if timestamp == 0x0:
result = _ALWAYS
elif timestamp == 0xFFFFFFFFFFFFFFFF:
result = _FOREVER
else:
try:
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
except OverflowError:
raise ValueError
return result
return _ALWAYS
if timestamp == 0xFFFFFFFFFFFFFFFF:
return _FOREVER
try:
return datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
except OverflowError:
raise ValueError
@staticmethod
def _time_string_to_datetime(time_string):
result = None
def _time_string_to_datetime(time_string: str) -> datetime:
if time_string == "always":
result = _ALWAYS
elif time_string == "forever":
result = _FOREVER
elif is_relative_time_string(time_string):
return _ALWAYS
if time_string == "forever":
return _FOREVER
if is_relative_time_string(time_string):
result = convert_relative_to_datetime(time_string, with_timezone=True)
else:
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=True,
)
except ValueError:
pass
if result is None:
raise ValueError
return result
result = None
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=True,
)
except ValueError:
pass
if result is None:
raise ValueError
return result
_OpensshCertificateOption = t.TypeVar(
"_OpensshCertificateOption", bound="OpensshCertificateOption"
)
class OpensshCertificateOption:
def __init__(self, option_type, name, data):
def __init__(
self,
option_type: t.Literal["critical", "extension"],
name: str | bytes,
data: str | bytes,
):
if option_type not in ("critical", "extension"):
raise ValueError("type must be either 'critical' or 'extension'")
@@ -225,7 +276,7 @@ class OpensshCertificateOption:
self._name = name.lower()
self._data = data
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
@@ -237,32 +288,34 @@ class OpensshCertificateOption:
]
)
def __hash__(self):
def __hash__(self) -> int:
return hash((self._option_type, self._name, self._data))
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self == other
def __str__(self):
def __str__(self) -> str:
if self._data:
return f"{self._name}={self._data}"
return self._name
return f"{self._name!r}={self._data!r}"
return f"{self._name!r}"
@property
def data(self):
def data(self) -> str | bytes:
return self._data
@property
def name(self):
def name(self) -> str | bytes:
return self._name
@property
def type(self):
def type(self) -> t.Literal["critical", "extension"]:
return self._option_type
@classmethod
def from_string(cls, option_string):
if not isinstance(option_string, (str, bytes)):
def from_string(
cls: t.Type[_OpensshCertificateOption], option_string: str
) -> _OpensshCertificateOption:
if not isinstance(option_string, str):
raise ValueError(
f"option_string must be a string not {type(option_string)}"
)
@@ -280,7 +333,8 @@ class OpensshCertificateOption:
name, data = option_string.strip(), ""
return cls(
option_type=option_type or get_option_type(name.lower()),
# We have str, but we're expecting a specific literal:
option_type=option_type or get_option_type(name.lower()), # type: ignore
name=name,
data=data,
)
@@ -291,21 +345,21 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
def __init__(
self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None,
nonce: bytes | None = None,
serial: int | None = None,
cert_type: int | None = None,
key_id: bytes | None = None,
principals: list[bytes] | None = None,
valid_after: int | None = None,
valid_before: int | None = None,
critical_options: list[tuple[bytes, bytes]] | None = None,
extensions: list[tuple[bytes, bytes]] | None = None,
reserved: bytes | None = None,
signing_key: bytes | None = None,
):
self.nonce = nonce
self.serial = serial
self._cert_type = cert_type
self._cert_type: int | None = cert_type
self.key_id = key_id
self.principals = principals
self.valid_after = valid_after
@@ -315,10 +369,10 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
self.reserved = reserved
self.signing_key = signing_key
self.type_string = None
self.type_string: bytes | None = None
@property
def cert_type(self):
def cert_type(self) -> t.Literal["user", "host", ""]:
if self._cert_type == _USER_TYPE:
return "user"
elif self._cert_type == _HOST_TYPE:
@@ -327,7 +381,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
return ""
@cert_type.setter
def cert_type(self, cert_type):
def cert_type(self, cert_type: t.Literal["user", "host"] | int) -> None:
if cert_type == "user" or cert_type == _USER_TYPE:
self._cert_type = _USER_TYPE
elif cert_type == "host" or cert_type == _HOST_TYPE:
@@ -335,28 +389,30 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
else:
raise ValueError(f"{cert_type} is not a valid certificate type")
def signing_key_fingerprint(self):
def signing_key_fingerprint(self) -> bytes:
if self.signing_key is None:
raise ValueError("signing_key not present")
return fingerprint(self.signing_key)
@abc.abstractmethod
def public_key_fingerprint(self):
def public_key_fingerprint(self) -> bytes:
pass
@abc.abstractmethod
def parse_public_numbers(self, parser):
def parse_public_numbers(self, parser: OpensshParser) -> None:
pass
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e=None, n=None, **kwargs):
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
self.n = n
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.e is None, self.n is None]):
def public_key_fingerprint(self) -> bytes:
if self.e is None or self.n is None:
return b""
writer = _OpensshWriter()
@@ -366,13 +422,20 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser):
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.e = parser.mpint()
self.n = parser.mpint()
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
def __init__(
self,
p: int | None = None,
q: int | None = None,
g: int | None = None,
y: int | None = None,
**kwargs,
) -> None:
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
self.p = p
@@ -381,8 +444,8 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
self.y = y
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
def public_key_fingerprint(self) -> bytes:
if self.p is None or self.q is None or self.g is None or self.y is None:
return b""
writer = _OpensshWriter()
@@ -394,7 +457,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser):
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.p = parser.mpint()
self.q = parser.mpint()
self.g = parser.mpint()
@@ -402,7 +465,9 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, curve=None, public_key=None, **kwargs):
def __init__(
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
):
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
self._curve = None
if curve is not None:
@@ -411,11 +476,11 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
self.public_key = public_key
@property
def curve(self):
def curve(self) -> bytes | None:
return self._curve
@curve.setter
def curve(self, curve):
def curve(self, curve: bytes) -> None:
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
self._curve = curve
self.type_string = (
@@ -428,8 +493,8 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
)
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.curve is None, self.public_key is None]):
def public_key_fingerprint(self) -> bytes:
if self.curve is None or self.public_key is None:
return b""
writer = _OpensshWriter()
@@ -439,18 +504,18 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser):
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.curve = parser.string()
self.public_key = parser.string()
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk=None, **kwargs):
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
def public_key_fingerprint(self):
def public_key_fingerprint(self) -> bytes:
if self.pk is None:
return b""
@@ -460,21 +525,26 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes())
def parse_public_numbers(self, parser):
def parse_public_numbers(self, parser: OpensshParser) -> None:
self.pk = parser.string()
_OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate")
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate:
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info, signature):
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
self._cert_info = cert_info
self.signature = signature
@classmethod
def load(cls, path):
def load(
cls: t.Type[_OpensshCertificate], path: str | os.PathLike
) -> _OpensshCertificate:
if not os.path.exists(path):
raise ValueError(f"{path} is not a valid path.")
@@ -492,11 +562,11 @@ class OpensshCertificate:
for key_type, string in _SSH_TYPE_STRINGS.items():
if format_identifier == string + _CERT_SUFFIX_V01:
pub_key_type = key_type
pub_key_type = t.cast(KeyType, key_type)
break
else:
raise ValueError(
f"Invalid certificate format identifier: {format_identifier}"
f"Invalid certificate format identifier: {format_identifier!r}"
)
parser = OpensshParser(cert)
@@ -521,75 +591,97 @@ class OpensshCertificate:
)
@property
def type_string(self):
def type_string(self) -> str:
return to_text(self._cert_info.type_string)
@property
def nonce(self):
def nonce(self) -> bytes:
if self._cert_info.nonce is None:
raise ValueError
return self._cert_info.nonce
@property
def public_key(self):
def public_key(self) -> str:
return to_text(self._cert_info.public_key_fingerprint())
@property
def serial(self):
def serial(self) -> int:
if self._cert_info.serial is None:
raise ValueError
return self._cert_info.serial
@property
def type(self):
return self._cert_info.cert_type
def type(self) -> t.Literal["user", "host"]:
result = self._cert_info.cert_type
if result == "":
raise ValueError
return result
@property
def key_id(self):
def key_id(self) -> str:
return to_text(self._cert_info.key_id)
@property
def principals(self):
def principals(self) -> list[str]:
if self._cert_info.principals is None:
raise ValueError
return [to_text(p) for p in self._cert_info.principals]
@property
def valid_after(self):
def valid_after(self) -> int:
if self._cert_info.valid_after is None:
raise ValueError
return self._cert_info.valid_after
@property
def valid_before(self):
def valid_before(self) -> int:
if self._cert_info.valid_before is None:
raise ValueError
return self._cert_info.valid_before
@property
def critical_options(self):
def critical_options(self) -> list[OpensshCertificateOption]:
if self._cert_info.critical_options is None:
raise ValueError
return [
OpensshCertificateOption("critical", to_text(n), to_text(d))
for n, d in self._cert_info.critical_options
]
@property
def extensions(self):
def extensions(self) -> list[OpensshCertificateOption]:
if self._cert_info.extensions is None:
raise ValueError
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
for n, d in self._cert_info.extensions
]
@property
def reserved(self):
def reserved(self) -> bytes:
if self._cert_info.reserved is None:
raise ValueError
return self._cert_info.reserved
@property
def signing_key(self):
def signing_key(self) -> str:
return to_text(self._cert_info.signing_key_fingerprint())
@property
def signature_type(self):
def signature_type(self) -> str:
signature_data = OpensshParser.signature_data(self.signature)
return to_text(signature_data["signature_type"])
@staticmethod
def _parse_cert_info(pub_key_type, parser):
def _parse_cert_info(
pub_key_type: KeyType, parser: OpensshParser
) -> OpensshCertificateInfo:
cert_info = get_cert_info_object(pub_key_type)
cert_info.nonce = parser.string()
cert_info.parse_public_numbers(parser)
cert_info.serial = parser.uint64()
cert_info.cert_type = parser.uint32()
# mypy doesn't understand that the setter accepts other types than the getter:
cert_info.cert_type = parser.uint32() # type: ignore
cert_info.key_id = parser.string()
cert_info.principals = parser.string_list()
cert_info.valid_after = parser.uint64()
@@ -601,7 +693,7 @@ class OpensshCertificate:
return cert_info
def to_dict(self):
def to_dict(self) -> dict[str, t.Any]:
time_parameters = OpensshCertificateTimeParameters(
valid_from=self.valid_after, valid_to=self.valid_before
)
@@ -624,7 +716,7 @@ class OpensshCertificate:
}
def apply_directives(directives):
def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOption]:
if any(d not in _DIRECTIVES for d in directives):
raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}")
@@ -650,50 +742,47 @@ def apply_directives(directives):
)
def default_options():
def default_options() -> list[OpensshCertificateOption]:
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
def fingerprint(public_key):
def fingerprint(public_key: bytes) -> bytes:
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
h = sha256()
h.update(public_key)
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
def get_cert_info_object(key_type):
def get_cert_info_object(key_type: KeyType) -> OpensshCertificateInfo:
if key_type == "rsa":
cert_info = OpensshRSACertificateInfo()
elif key_type == "dsa":
cert_info = OpensshDSACertificateInfo()
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
cert_info = OpensshECDSACertificateInfo()
elif key_type == "ed25519":
cert_info = OpensshED25519CertificateInfo()
else:
raise ValueError(f"{key_type} is not a valid key type")
return cert_info
return OpensshRSACertificateInfo()
if key_type == "dsa":
return OpensshDSACertificateInfo()
if key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
return OpensshECDSACertificateInfo()
if key_type == "ed25519":
return OpensshED25519CertificateInfo()
raise ValueError(f"{key_type} is not a valid key type")
def get_option_type(name):
def get_option_type(name: str) -> t.Literal["critical", "extension"]:
if name in _CRITICAL_OPTIONS:
result = "critical"
elif name in _EXTENSIONS:
result = "extension"
else:
raise ValueError(
f"{name} is not a valid option. "
"Custom options must start with 'critical:' or 'extension:' to indicate type"
)
return result
return "critical"
if name in _EXTENSIONS:
return "extension"
raise ValueError(
f"{name} is not a valid option. "
"Custom options must start with 'critical:' or 'extension:' to indicate type"
)
def is_relative_time_string(time_string):
def is_relative_time_string(time_string: str) -> bool:
return time_string.startswith("+") or time_string.startswith("-")
def parse_option_list(option_list):
def parse_option_list(
option_list: t.Iterable[str],
) -> tuple[list[OpensshCertificateOption], list[OpensshCertificateOption]]:
critical_options = []
directives = []
extensions = []

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import os
import typing as t
from base64 import b64decode, b64encode
from getpass import getuser
from socket import gethostname
@@ -64,6 +65,27 @@ except ImportError:
CRYPTOGRAPHY_VERSION = "0.0"
_ALGORITHM_PARAMETERS = {}
if t.TYPE_CHECKING:
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"]
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"]
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"]
PrivateKeyTypes = t.Union[
rsa.RSAPrivateKey,
dsa.DSAPrivateKey,
ec.EllipticCurvePrivateKey,
Ed25519PrivateKey,
]
PublicKeyTypes = t.Union[
rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey
]
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes as AllPublicKeyTypes,
)
_TEXT_ENCODING = "UTF-8"
@@ -111,11 +133,19 @@ class InvalidSignatureError(OpenSSHError):
pass
_AsymmetricKeypair = t.TypeVar("_AsymmetricKeypair", bound="AsymmetricKeypair")
class AsymmetricKeypair:
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
@classmethod
def generate(cls, keytype="rsa", size=None, passphrase=None):
def generate(
cls: t.Type[_AsymmetricKeypair],
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
) -> _AsymmetricKeypair:
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
or defaults to an unencrypted RSA-2048 key
@@ -124,19 +154,21 @@ class AsymmetricKeypair:
:passphrase: Secret of type Bytes used to encrypt the private key being generated
"""
if keytype not in _ALGORITHM_PARAMETERS.keys():
if keytype not in _ALGORITHM_PARAMETERS:
raise InvalidKeyTypeError(
f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}"
)
if not size:
size = _ALGORITHM_PARAMETERS[keytype]["default_size"]
size = _ALGORITHM_PARAMETERS[keytype]["default_size"] # type: ignore
else:
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]:
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: # type: ignore
raise InvalidKeySizeError(
f"{size} is not a valid key size for {keytype} keys"
)
size = t.cast(int, size)
privatekey: PrivateKeyTypes
if passphrase:
encryption_algorithm = get_encryption_algorithm(passphrase)
else:
@@ -157,7 +189,7 @@ class AsymmetricKeypair:
privatekey = Ed25519PrivateKey.generate()
elif keytype == "ecdsa":
privatekey = ec.generate_private_key(
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size],
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], # type: ignore
)
publickey = privatekey.public_key()
@@ -172,13 +204,13 @@ class AsymmetricKeypair:
@classmethod
def load(
cls,
path,
passphrase=None,
private_key_format="PEM",
public_key_format="PEM",
no_public_key=False,
):
cls: t.Type[_AsymmetricKeypair],
path: str | os.PathLike,
passphrase: bytes | None = None,
private_key_format: KeySerializationFormat = "PEM",
public_key_format: KeySerializationFormat = "PEM",
no_public_key: bool = False,
) -> _AsymmetricKeypair:
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
@@ -197,14 +229,17 @@ class AsymmetricKeypair:
if no_public_key:
publickey = privatekey.public_key()
else:
publickey = load_publickey(path + ".pub", public_key_format)
# TODO: BUG: load_publickey() can return unsupported key types
# (Also we should check whether the public key fits the private key...)
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
# Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey):
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"]
size: int = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] # type: ignore
else:
size = privatekey.key_size
keytype: KeyType
if isinstance(privatekey, rsa.RSAPrivateKey):
keytype = "rsa"
elif isinstance(privatekey, dsa.DSAPrivateKey):
@@ -224,7 +259,14 @@ class AsymmetricKeypair:
encryption_algorithm=encryption_algorithm,
)
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm):
def __init__(
self,
keytype: KeyType,
size: int,
privatekey: PrivateKeyTypes,
publickey: PublicKeyTypes,
encryption_algorithm: serialization.KeySerializationEncryption,
) -> None:
"""
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair
@@ -246,7 +288,7 @@ class AsymmetricKeypair:
"The private key and public key of this keypair do not match"
)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, AsymmetricKeypair):
return NotImplemented
@@ -256,55 +298,53 @@ class AsymmetricKeypair:
self.encryption_algorithm, other.encryption_algorithm
)
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self == other
@property
def private_key(self):
def private_key(self) -> PrivateKeyTypes:
"""Returns the private key of this key pair"""
return self.__privatekey
@property
def public_key(self):
def public_key(self) -> PublicKeyTypes:
"""Returns the public key of this key pair"""
return self.__publickey
@property
def size(self):
def size(self) -> int:
"""Returns the size of the private key of this key pair"""
return self.__size
@property
def key_type(self):
def key_type(self) -> KeyType:
"""Returns the key type of this key pair"""
return self.__keytype
@property
def encryption_algorithm(self):
def encryption_algorithm(self) -> serialization.KeySerializationEncryption:
"""Returns the key encryption algorithm of this key pair"""
return self.__encryption_algorithm
def sign(self, data):
def sign(self, data: bytes) -> bytes:
"""Returns signature of data signed with the private key of this key pair
:data: byteslike data to sign
"""
try:
signature = self.__privatekey.sign(
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
return self.__privatekey.sign(
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] # type: ignore
)
except TypeError as e:
raise InvalidDataError(e)
return signature
def verify(self, signature, data):
def verify(self, signature: bytes, data: bytes) -> None:
"""Verifies that the signature associated with the provided data was signed
by the private key of this key pair.
@@ -312,15 +352,15 @@ class AsymmetricKeypair:
:data: byteslike data signed by the provided signature
"""
try:
return self.__publickey.verify(
self.__publickey.verify(
signature,
data,
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"],
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], # type: ignore
)
except InvalidSignature:
raise InvalidSignatureError
def update_passphrase(self, passphrase=None):
def update_passphrase(self, passphrase: bytes | None = None) -> None:
"""Updates the encryption algorithm of this key pair
:passphrase: Byte secret used to encrypt this key pair
@@ -332,11 +372,20 @@ class AsymmetricKeypair:
self.__encryption_algorithm = serialization.NoEncryption()
_OpensshKeypair = t.TypeVar("_OpensshKeypair", bound="OpensshKeypair")
class OpensshKeypair:
"""Container for OpenSSH encoded asymmetric key pairs"""
@classmethod
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None):
def generate(
cls: t.Type[_OpensshKeypair],
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
comment: str | None = None,
) -> _OpensshKeypair:
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
@@ -362,7 +411,12 @@ class OpensshKeypair:
)
@classmethod
def load(cls, path, passphrase=None, no_public_key=False):
def load(
cls: t.Type[_OpensshKeypair],
path: str | os.PathLike,
passphrase: bytes | None = None,
no_public_key: bool = False,
) -> _OpensshKeypair:
"""Returns an Openssh_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
@@ -373,7 +427,7 @@ class OpensshKeypair:
if no_public_key:
comment = ""
else:
comment = extract_comment(path + ".pub")
comment = extract_comment(str(path) + ".pub")
asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key
@@ -391,7 +445,9 @@ class OpensshKeypair:
)
@staticmethod
def encode_openssh_privatekey(asym_keypair, key_format):
def encode_openssh_privatekey(
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
) -> bytes:
"""Returns an OpenSSH encoded private key for a given keypair
:asym_keypair: Asymmetric_Keypair from the private key is extracted
@@ -422,7 +478,9 @@ class OpensshKeypair:
return encoded_privatekey
@staticmethod
def encode_openssh_publickey(asym_keypair, comment):
def encode_openssh_publickey(
asym_keypair: AsymmetricKeypair, comment: str
) -> bytes:
"""Returns an OpenSSH encoded public key for a given keypair
:asym_keypair: Asymmetric_Keypair from the public key is extracted
@@ -436,14 +494,19 @@ class OpensshKeypair:
validate_comment(comment)
encoded_publickey += (
f" {comment}".encode(encoding=_TEXT_ENCODING) if comment else b""
(b" " + comment.encode(encoding=_TEXT_ENCODING)) if comment else b""
)
return encoded_publickey
def __init__(
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment
):
self,
asym_keypair: AsymmetricKeypair,
openssh_privatekey: bytes,
openssh_publickey: bytes,
fingerprint: str,
comment: str | None,
) -> None:
"""
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key
@@ -458,7 +521,7 @@ class OpensshKeypair:
self.__fingerprint = fingerprint
self.__comment = comment
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, OpensshKeypair):
return NotImplemented
@@ -468,49 +531,49 @@ class OpensshKeypair:
)
@property
def asymmetric_keypair(self):
def asymmetric_keypair(self) -> AsymmetricKeypair:
"""Returns the underlying asymmetric key pair of this OpenSSH encoded key pair"""
return self.__asym_keypair
@property
def private_key(self):
def private_key(self) -> bytes:
"""Returns the OpenSSH formatted private key of this key pair"""
return self.__openssh_privatekey
@property
def public_key(self):
def public_key(self) -> bytes:
"""Returns the OpenSSH formatted public key of this key pair"""
return self.__openssh_publickey
@property
def size(self):
def size(self) -> int:
"""Returns the size of the private key of this key pair"""
return self.__asym_keypair.size
@property
def key_type(self):
def key_type(self) -> KeyType:
"""Returns the key type of this key pair"""
return self.__asym_keypair.key_type
@property
def fingerprint(self):
def fingerprint(self) -> str:
"""Returns the fingerprint (SHA256 Hash) of the public key of this key pair"""
return self.__fingerprint
@property
def comment(self):
def comment(self) -> str | None:
"""Returns the comment applied to the OpenSSH formatted public key of this key pair"""
return self.__comment
@comment.setter
def comment(self, comment):
def comment(self, comment: str) -> bytes:
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
:comment: Text to update the OpenSSH public key comment
@@ -529,7 +592,7 @@ class OpensshKeypair:
)
return self.__openssh_publickey
def update_passphrase(self, passphrase):
def update_passphrase(self, passphrase: bytes | None) -> None:
"""Updates the passphrase used to encrypt the private key of this keypair
:passphrase: Text secret used for encryption
@@ -541,18 +604,17 @@ class OpensshKeypair:
)
def load_privatekey(path, passphrase, key_format):
def load_privatekey(
path: str | os.PathLike,
passphrase: bytes | None,
key_format: KeySerializationFormat,
) -> PrivateKeyTypes:
privatekey_loaders = {
"PEM": serialization.load_pem_private_key,
"DER": serialization.load_der_private_key,
"SSH": serialization.load_ssh_private_key,
}
# OpenSSH formatted private keys are not available in Cryptography <3.0
if hasattr(serialization, "load_ssh_private_key"):
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
else:
privatekey_loaders["SSH"] = serialization.load_pem_private_key
try:
privatekey_loader = privatekey_loaders[key_format]
except KeyError:
@@ -567,16 +629,16 @@ def load_privatekey(path, passphrase, key_format):
with open(path, "rb") as f:
content = f.read()
privatekey = privatekey_loader(
privatekey = privatekey_loader( # type: ignore
data=content,
password=passphrase,
)
except ValueError as e:
except ValueError as exc:
# Revert to PEM if key could not be loaded in SSH format
if key_format == "SSH":
try:
privatekey = privatekey_loaders["PEM"](
privatekey = privatekey_loaders["PEM"]( # type: ignore
data=content,
password=passphrase,
)
@@ -587,7 +649,7 @@ def load_privatekey(path, passphrase, key_format):
except UnsupportedAlgorithm as e:
raise InvalidAlgorithmError(e)
else:
raise InvalidPrivateKeyFileError(e)
raise InvalidPrivateKeyFileError(exc)
except TypeError as e:
raise InvalidPassphraseError(e)
except UnsupportedAlgorithm as e:
@@ -596,7 +658,9 @@ def load_privatekey(path, passphrase, key_format):
return privatekey
def load_publickey(path, key_format):
def load_publickey(
path: str | os.PathLike, key_format: KeySerializationFormat
) -> AllPublicKeyTypes:
publickey_loaders = {
"PEM": serialization.load_pem_public_key,
"DER": serialization.load_der_public_key,
@@ -628,20 +692,27 @@ def load_publickey(path, key_format):
return publickey
def compare_publickeys(pk1, pk2):
def compare_publickeys(pk1: PublicKeyTypes, pk2: PublicKeyTypes) -> bool:
a = isinstance(pk1, Ed25519PublicKey)
b = isinstance(pk2, Ed25519PublicKey)
if a or b:
if not a or not b:
return False
a = pk1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
b = pk2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
return a == b
a_bytes = pk1.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
b_bytes = pk2.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
return a_bytes == b_bytes
else:
return pk1.public_numbers() == pk2.public_numbers()
return pk1.public_numbers() == pk2.public_numbers() # type: ignore
def compare_encryption_algorithms(ea1, ea2):
def compare_encryption_algorithms(
ea1: serialization.KeySerializationEncryption,
ea2: serialization.KeySerializationEncryption,
) -> bool:
if isinstance(ea1, serialization.NoEncryption) and isinstance(
ea2, serialization.NoEncryption
):
@@ -654,19 +725,21 @@ def compare_encryption_algorithms(ea1, ea2):
return False
def get_encryption_algorithm(passphrase):
def get_encryption_algorithm(
passphrase: bytes,
) -> serialization.KeySerializationEncryption:
try:
return serialization.BestAvailableEncryption(passphrase)
except ValueError as e:
raise InvalidPassphraseError(e)
def validate_comment(comment):
def validate_comment(comment: str) -> None:
if not hasattr(comment, "encode"):
raise InvalidCommentError(f"{comment} cannot be encoded to text")
def extract_comment(path):
def extract_comment(path: str | os.PathLike) -> str:
if not os.path.exists(path):
raise InvalidPublicKeyFileError(f"No file was found at {path}")
@@ -684,7 +757,7 @@ def extract_comment(path):
return comment
def calculate_fingerprint(openssh_publickey):
def calculate_fingerprint(openssh_publickey: bytes) -> str:
digest = hashes.Hash(hashes.SHA256())
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
digest.update(decoded_pubkey)

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import os
import re
import typing as t
from contextlib import contextmanager
from struct import Struct
@@ -38,17 +39,20 @@ _UINT64 = Struct(b"!Q")
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
def any_in(sequence, *elements):
_T = t.TypeVar("_T")
def any_in(sequence: t.Iterable[_T], *elements: _T) -> bool:
return any(e in sequence for e in elements)
def file_mode(path):
def file_mode(path: str | os.PathLike) -> int:
if not os.path.exists(path):
return 0o000
return os.stat(path).st_mode & 0o777
def parse_openssh_version(version_string):
def parse_openssh_version(version_string: str) -> str | None:
"""Parse the version output of ssh -V and return version numbers that can be compared"""
parsed_result = re.match(
@@ -63,7 +67,7 @@ def parse_openssh_version(version_string):
@contextmanager
def secure_open(path, mode):
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
try:
yield fd
@@ -71,7 +75,7 @@ def secure_open(path, mode):
os.close(fd)
def secure_write(path, mode, content):
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path, mode) as fd:
os.write(fd, content)
@@ -84,35 +88,35 @@ class OpensshParser:
UINT32_OFFSET = 4
UINT64_OFFSET = 8
def __init__(self, data):
def __init__(self, data: bytes | bytearray) -> None:
if not isinstance(data, (bytes, bytearray)):
raise TypeError(f"Data must be bytes-like not {type(data)}")
self._data = memoryview(data)
self._pos = 0
def boolean(self):
def boolean(self) -> bool:
next_pos = self._check_position(self.BOOLEAN_OFFSET)
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint32(self):
def uint32(self) -> int:
next_pos = self._check_position(self.UINT32_OFFSET)
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint64(self):
def uint64(self) -> int:
next_pos = self._check_position(self.UINT64_OFFSET)
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def string(self):
def string(self) -> bytes:
length = self.uint32()
next_pos = self._check_position(length)
@@ -122,15 +126,15 @@ class OpensshParser:
# Cast to bytes is required as a memoryview slice is itself a memoryview
return bytes(value)
def mpint(self):
def mpint(self) -> int:
return self._big_int(self.string(), "big", signed=True)
def name_list(self):
def name_list(self) -> list[str]:
raw_string = self.string()
return raw_string.decode("ASCII").split(",")
# Convenience function, but not an official data type from SSH
def string_list(self):
def string_list(self) -> list[bytes]:
result = []
raw_string = self.string()
@@ -142,7 +146,7 @@ class OpensshParser:
return result
# Convenience function, but not an official data type from SSH
def option_list(self):
def option_list(self) -> list[tuple[bytes, bytes]]:
result = []
raw_string = self.string()
@@ -159,15 +163,15 @@ class OpensshParser:
return result
def seek(self, offset):
def seek(self, offset: int) -> int:
self._pos = self._check_position(offset)
return self._pos
def remaining_bytes(self):
def remaining_bytes(self) -> int:
return len(self._data) - self._pos
def _check_position(self, offset):
def _check_position(self, offset: int) -> int:
if self._pos + offset > len(self._data):
raise ValueError(f"Insufficient data remaining at position: {self._pos}")
elif self._pos + offset < 0:
@@ -176,8 +180,8 @@ class OpensshParser:
return self._pos + offset
@classmethod
def signature_data(cls, signature_string):
signature_data = {}
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
signature_data: dict[str, bytes | int] = {}
parser = cls(signature_string)
signature_type = parser.string()
@@ -205,14 +209,19 @@ class OpensshParser:
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
else:
raise ValueError(f"{signature_type} is not a valid signature type")
raise ValueError(f"{signature_type!r} is not a valid signature type")
signature_data["signature_type"] = signature_type
return signature_data
@classmethod
def _big_int(cls, raw_string, byte_order, signed=False):
def _big_int(
cls,
raw_string: bytes,
byte_order: t.Literal["big", "little"],
signed: bool = False,
) -> int:
if byte_order not in ("big", "little"):
raise ValueError(
f"Byte_order must be one of (big, little) not {byte_order}"
@@ -230,18 +239,16 @@ class _OpensshWriter:
in validating parsed material.
"""
def __init__(self, buffer=None):
def __init__(self, buffer: bytearray | None = None):
if buffer is not None:
if not isinstance(buffer, (bytes, bytearray)):
raise TypeError(
f"Buffer must be a bytes-like object not {type(buffer)}"
)
if not isinstance(buffer, bytearray):
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
else:
buffer = bytearray()
self._buff = buffer
self._buff: bytearray = buffer
def boolean(self, value):
def boolean(self, value: bool) -> t.Self:
if not isinstance(value, bool):
raise TypeError(f"Value must be of type bool not {type(value)}")
@@ -249,7 +256,7 @@ class _OpensshWriter:
return self
def uint32(self, value):
def uint32(self, value: int) -> t.Self:
if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}")
if value < 0 or value > _UINT32_MAX:
@@ -261,7 +268,7 @@ class _OpensshWriter:
return self
def uint64(self, value):
def uint64(self, value: int) -> t.Self:
if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}")
if value < 0 or value > _UINT64_MAX:
@@ -273,7 +280,7 @@ class _OpensshWriter:
return self
def string(self, value):
def string(self, value: bytes | bytearray) -> t.Self:
if not isinstance(value, (bytes, bytearray)):
raise TypeError(f"Value must be bytes-like not {type(value)}")
self.uint32(len(value))
@@ -281,7 +288,7 @@ class _OpensshWriter:
return self
def mpint(self, value):
def mpint(self, value: int) -> t.Self:
if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}")
@@ -289,7 +296,7 @@ class _OpensshWriter:
return self
def name_list(self, value):
def name_list(self, value: list[str]) -> t.Self:
if not isinstance(value, list):
raise TypeError(f"Value must be a list of byte strings not {type(value)}")
@@ -300,7 +307,7 @@ class _OpensshWriter:
return self
def string_list(self, value):
def string_list(self, value: list[bytes]) -> t.Self:
if not isinstance(value, list):
raise TypeError(f"Value must be a list of byte string not {type(value)}")
@@ -312,7 +319,7 @@ class _OpensshWriter:
return self
def option_list(self, value):
def option_list(self, value: list[tuple[bytes, bytes]]) -> t.Self:
if not isinstance(value, list) or (value and not isinstance(value[0], tuple)):
raise TypeError("Value must be a list of tuples")
@@ -327,7 +334,7 @@ class _OpensshWriter:
return self
@staticmethod
def _int_to_mpint(num):
def _int_to_mpint(num: int) -> bytes:
byte_length = (num.bit_length() + 7) // 8
try:
return num.to_bytes(byte_length, "big", signed=True)
@@ -335,5 +342,5 @@ class _OpensshWriter:
except OverflowError:
return num.to_bytes(byte_length + 1, "big", signed=True)
def bytes(self):
def bytes(self) -> bytes:
return bytes(self._buff)

View File

@@ -10,7 +10,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor
)
def th(number):
def th(number: int) -> str:
abs_number = abs(number)
mod_10 = abs_number % 10
mod_100 = abs_number % 100
@@ -24,13 +24,13 @@ def th(number):
return "th"
def parse_serial(value):
def parse_serial(value: str | bytes) -> int:
"""
Given a colon-separated string of hexadecimal byte values, converts it to an integer.
"""
value = to_native(value)
value_str = to_native(value)
result = 0
for i, part in enumerate(value.split(":")):
for i, part in enumerate(value_str.split(":")):
try:
part_value = int(part, 16)
if part_value < 0 or part_value > 255:
@@ -43,11 +43,11 @@ def parse_serial(value):
return result
def to_serial(value):
def to_serial(value: int) -> str:
"""
Given an integer, converts its absolute value to a colon-separated string of hexadecimal byte values.
"""
value = convert_int_to_hex(value).upper()
if len(value) % 2 != 0:
value = "0" + value
return ":".join(value[i : i + 2] for i in range(0, len(value), 2))
value_str = convert_int_to_hex(value).upper()
if len(value_str) % 2 != 0:
value_str = f"0{value_str}"
return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2))

View File

@@ -13,37 +13,16 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.basic impo
)
try:
UTC = datetime.timezone.utc
except AttributeError:
_DURATION_ZERO = datetime.timedelta(0)
class _UTCClass(datetime.tzinfo):
def utcoffset(self, dt):
return _DURATION_ZERO
def dst(self, dt):
return _DURATION_ZERO
def tzname(self, dt):
return "UTC"
def fromutc(self, dt):
return dt
def __repr__(self):
return "UTC"
UTC = _UTCClass()
UTC = datetime.timezone.utc
def get_now_datetime(with_timezone):
def get_now_datetime(with_timezone: bool) -> datetime.datetime:
if with_timezone:
return datetime.datetime.now(tz=UTC)
return datetime.datetime.utcnow()
def ensure_utc_timezone(timestamp):
def ensure_utc_timezone(timestamp: datetime.datetime) -> datetime.datetime:
if timestamp.tzinfo is UTC:
return timestamp
if timestamp.tzinfo is None:
@@ -52,7 +31,7 @@ def ensure_utc_timezone(timestamp):
return timestamp.astimezone(UTC)
def remove_timezone(timestamp):
def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime:
# Convert to native datetime object
if timestamp.tzinfo is None:
return timestamp
@@ -61,26 +40,34 @@ def remove_timezone(timestamp):
return timestamp.replace(tzinfo=None)
def add_or_remove_timezone(timestamp, with_timezone):
def add_or_remove_timezone(
timestamp: datetime.datetime, with_timezone: bool
) -> datetime.datetime:
return (
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
)
def get_epoch_seconds(timestamp):
def get_epoch_seconds(timestamp: datetime.datetime) -> float:
if timestamp.tzinfo is None:
# timestamp.timestamp() is offset by the local timezone if timestamp has no timezone
timestamp = ensure_utc_timezone(timestamp)
return timestamp.timestamp()
def from_epoch_seconds(timestamp, with_timezone):
def from_epoch_seconds(
timestamp: int | float, with_timezone: bool
) -> datetime.datetime:
if with_timezone:
return datetime.datetime.fromtimestamp(timestamp, UTC)
return datetime.datetime.utcfromtimestamp(timestamp)
def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=None):
def convert_relative_to_datetime(
relative_time_string: str,
with_timezone: bool = False,
now: datetime.datetime | None = None,
) -> datetime.datetime | None:
"""Get a datetime.datetime or None from a string in the time format described in sshd_config(5)"""
parsed_result = re.match(
@@ -115,7 +102,12 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
return now - offset
def get_relative_time_option(input_string, input_name, with_timezone=False, now=None):
def get_relative_time_option(
input_string: str,
input_name: str,
with_timezone: bool = False,
now: datetime.datetime | None = None,
) -> datetime.datetime:
"""
Return an absolute timespec if a relative timespec or an ASN1 formatted
string is provided.
@@ -129,9 +121,12 @@ def get_relative_time_option(input_string, input_name, with_timezone=False, now=
)
# Relative time
if result.startswith("+") or result.startswith("-"):
return convert_relative_to_datetime(
result, with_timezone=with_timezone, now=now
)
res = convert_relative_to_datetime(result, with_timezone=with_timezone, now=now)
if res is None:
raise OpenSSLObjectError(
f'The timespec "{input_string}" for {input_name} is invalid'
)
return res
# Absolute time
for date_fmt, length in [
(

View File

@@ -165,6 +165,7 @@ account_uri:
"""
import base64
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
@@ -180,7 +181,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec()
argument_spec.update_argspec(
terms_agreed=dict(type="bool", default=False),
@@ -204,24 +205,24 @@ def main():
),
)
argument_spec.update(
mutually_exclusive=(["new_account_key_src", "new_account_key_content"],),
required_if=(
mutually_exclusive=[("new_account_key_src", "new_account_key_content")],
required_if=[
# Make sure that for state == changed_key, one of
# new_account_key_src and new_account_key_content are specified
[
(
"state",
"changed_key",
["new_account_key_src", "new_account_key_content"],
True,
],
),
),
],
)
module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True)
if module.params["external_account_binding"]:
# Make sure padding is there
key = module.params["external_account_binding"]["key"]
key: str = module.params["external_account_binding"]["key"]
if len(key) % 4 != 0:
key = key + ("=" * (4 - (len(key) % 4)))
# Make sure key is Base64 encoded
@@ -237,24 +238,25 @@ def main():
client = ACMEClient(module, backend)
account = ACMEAccount(client)
changed = False
state = module.params.get("state")
diff_before = {}
diff_after = {}
state: t.Literal["present", "absent", "changed_key"] = module.params["state"]
diff_before: dict[str, t.Any] = {}
diff_after: dict[str, t.Any] = {}
if state == "absent":
created, account_data = account.setup_account(allow_creation=False)
if account_data:
diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
if created:
raise AssertionError("Unwanted account creation")
if account_data is not None:
# Account is not yet deactivated
if not module.check_mode:
# Deactivate it
payload = {"status": "deactivated"}
deactivate_payload = {"status": "deactivated"}
result, info = client.send_signed_request(
client.account_uri,
payload,
t.cast(str, client.account_uri),
deactivate_payload,
error_msg="Failed to deactivate account",
expected_status_codes=[200],
)
@@ -278,13 +280,15 @@ def main():
diff_before = {}
else:
diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
updated = False
if not created:
updated, account_data = account.update_account(account_data, contact)
changed = created or updated
diff_after = dict(account_data)
diff_after["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_after["public_account_key"] = client.account_key_data["jwk"]
elif state == "changed_key":
# Parse new account key
try:
@@ -306,7 +310,8 @@ def main():
msg="Account does not exist or is deactivated."
)
diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
# Now we can start the account key rollover
if not module.check_mode:
# Compose inner signed message
@@ -317,12 +322,12 @@ def main():
"jwk": new_key_data["jwk"],
"url": url,
}
payload = {
change_key_payload = {
"account": client.account_uri,
"newKey": new_key_data["jwk"], # specified in draft 12 and older
"oldKey": client.account_jwk, # specified in draft 13 and newer
}
data = client.sign_request(protected, payload, new_key_data)
data = client.sign_request(protected, change_key_payload, new_key_data)
# Send request and verify result
result, info = client.send_signed_request(
url,
@@ -332,8 +337,9 @@ def main():
)
if module._diff:
client.account_key_data = new_key_data
client.account_jws_header["alg"] = new_key_data["alg"]
diff_after = account.get_account_data()
if client.account_jws_header:
client.account_jws_header["alg"] = new_key_data["alg"]
diff_after = account.get_account_data() or {}
elif module._diff:
# Kind of fake diff_after
diff_after = dict(diff_before)

View File

@@ -204,6 +204,8 @@ order_uris:
version_added: 1.5.0
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
)
@@ -220,48 +222,55 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def get_orders_list(module, client, orders_url):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def get_orders_list(
module: AnsibleModule, client: ACMEClient, orders_url: str
) -> list[str]:
"""
Retrieves orders list (handles pagination).
Retrieves order URL list (handles pagination).
"""
orders = []
while orders_url:
orders: list[str] = []
next_orders_url: str | None = orders_url
while next_orders_url:
# Get part of orders list
res, info = client.get_request(
orders_url, parse_json_result=True, fail_on_error=True
next_orders_url, parse_json_result=True, fail_on_error=True
)
if not res.get("orders"):
if orders:
module.warn(
f"When retrieving orders list part {orders_url}, got empty result list"
f"When retrieving orders list part {next_orders_url}, got empty result list"
)
break
# Add order URLs to result list
orders.extend(res["orders"])
# Extract URL of next part of results list
new_orders_url = []
new_orders_url: list[str | None] = []
def f(link, relation):
def f(link: str, relation: str) -> None:
if relation == "next":
new_orders_url.append(link)
process_links(info, f)
new_orders_url.append(None)
previous_orders_url, orders_url = orders_url, new_orders_url.pop(0)
if orders_url == previous_orders_url:
previous_orders_url, next_orders_url = next_orders_url, new_orders_url.pop(0)
if next_orders_url == previous_orders_url:
# Prevent infinite loop
orders_url = None
next_orders_url = None
return orders
def get_order(client, order_url):
def get_order(client: ACMEClient, order_url: str) -> dict[str, t.Any]:
"""
Retrieve order data.
"""
return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0]
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec()
argument_spec.update_argspec(
retrieve_orders=dict(
@@ -282,16 +291,19 @@ def main():
)
if created:
raise AssertionError("Unwanted account creation")
result = {
result: dict[str, t.Any] = {
"changed": False,
"exists": client.account_uri is not None,
"account_uri": client.account_uri,
"exists": False,
"account_uri": None,
}
if client.account_uri is not None:
if client.account_uri is not None and account_data:
result["account_uri"] = client.account_uri
result["exists"] = True
# Make sure promised data is there
if "contact" not in account_data:
account_data["contact"] = []
account_data["public_account_key"] = client.account_key_data["jwk"]
if client.account_key_data:
account_data["public_account_key"] = client.account_key_data["jwk"]
result["account"] = account_data
# Retrieve orders list
if (

View File

@@ -94,6 +94,8 @@ renewal_info:
sample: '2024-04-29T01:17:10.236921+00:00'
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient,
create_backend,
@@ -104,15 +106,15 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_account=False)
argument_spec.update_argspec(
certificate_path=dict(type="path"),
certificate_content=dict(type="str"),
)
argument_spec.update(
required_one_of=(["certificate_path", "certificate_content"],),
mutually_exclusive=(["certificate_path", "certificate_content"],),
required_one_of=[("certificate_path", "certificate_content")],
mutually_exclusive=[("certificate_path", "certificate_content")],
)
module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True)

View File

@@ -562,6 +562,7 @@ all_chains:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
@@ -592,6 +593,17 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
CryptoBackend,
)
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
)
NO_CHALLENGE = "no challenge"
@@ -602,7 +614,7 @@ class ACMECertificateClient:
certificates.
"""
def __init__(self, module, backend):
def __init__(self, module: AnsibleModule, backend: CryptoBackend):
self.module = module
self.version = module.params["acme_version"]
self.challenge = module.params["challenge"]
@@ -618,9 +630,9 @@ class ACMECertificateClient:
self.account = ACMEAccount(self.client)
self.directory = self.client.directory
self.data = module.params["data"]
self.authorizations = None
self.authorizations: dict[str, Authorization] | None = None
self.cert_days = -1
self.order = None
self.order: Order | None = None
self.order_uri = self.data.get("order_uri") if self.data else None
self.all_chains = None
self.select_chain_matcher = []
@@ -662,7 +674,6 @@ class ACMECertificateClient:
contact.append("mailto:" + module.params["account_email"])
created, account_data = self.account.setup_account(
contact,
agreement=module.params.get("agreement"),
terms_agreed=module.params.get("terms_agreed"),
allow_creation=modify_account,
)
@@ -681,7 +692,7 @@ class ACMECertificateClient:
csr_filename=self.csr, csr_content=self.csr_content
)
def is_first_step(self):
def is_first_step(self) -> bool:
"""
Return True if this is the first execution of this module, i.e. if a
sufficient data object from a first run has not been provided.
@@ -692,7 +703,7 @@ class ACMECertificateClient:
# stored in self.order_uri by the constructor).
return self.order_uri is None
def _get_cert_info_or_none(self):
def _get_cert_info_or_none(self) -> CertificateInformation | None:
if self.module.params.get("dest"):
filename = self.module.params["dest"]
else:
@@ -701,7 +712,7 @@ class ACMECertificateClient:
return None
return self.client.backend.get_cert_information(cert_filename=filename)
def start_challenges(self):
def start_challenges(self) -> None:
"""
Create new authorizations for all identifiers of the CSR,
respectively start a new order for ACME v2.
@@ -733,13 +744,16 @@ class ACMECertificateClient:
self.authorizations.update(self.order.authorizations)
self.changed = True
def get_challenges_data(self, first_step):
def get_challenges_data(
self, first_step: bool
) -> tuple[dict[str, t.Any], dict[str, list[str]]]:
"""
Get challenge details for the chosen challenge type.
Return a tuple of generic challenge details, and specialized DNS challenge details.
"""
# Get general challenge data
data = {}
assert self.authorizations is not None
data: dict[str, t.Any] = {}
data_dns: dict[str, list[str]] = {}
for type_identifier, authz in self.authorizations.items():
identifier_type, identifier = split_identifier(type_identifier)
# Skip valid authentications: their challenges are already valid
@@ -747,7 +761,9 @@ class ACMECertificateClient:
if authz.status == "valid":
continue
# We drop the type from the key to preserve backwards compatibility
data[authz.identifier] = authz.get_challenge_data(self.client)
challenges = authz.get_challenge_data(self.client)
assert authz.identifier is not None
data[authz.identifier] = challenges
if (
first_step
and self.challenge is not None
@@ -756,10 +772,7 @@ class ACMECertificateClient:
raise ModuleFailException(
f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!"
)
# Get DNS challenge data
data_dns = {}
if self.challenge == "dns-01":
for identifier, challenges in data.items():
if self.challenge == "dns-01":
if self.challenge in challenges:
values = data_dns.get(challenges[self.challenge]["record"])
if values is None:
@@ -768,7 +781,7 @@ class ACMECertificateClient:
values.append(challenges[self.challenge]["resource_value"])
return data, data_dns
def finish_challenges(self):
def finish_challenges(self) -> None:
"""
Verify challenges for all identifiers of the CSR.
"""
@@ -777,6 +790,7 @@ class ACMECertificateClient:
# Step 1: obtain challenge information
# For ACME v2, we obtain the order object by fetching the
# order URI, and extract the information from there.
assert self.order_uri is not None
self.order = Order.from_url(self.client, self.order_uri)
self.order.load_authorizations(self.client)
self.authorizations.update(self.order.authorizations)
@@ -799,7 +813,9 @@ class ACMECertificateClient:
# Step 3: wait for authzs to validate
wait_for_validation(authzs_to_wait_for, self.client)
def download_alternate_chains(self, cert):
def download_alternate_chains(
self, cert: CertificateChain
) -> list[CertificateChain]:
alternate_chains = []
for alternate in cert.alternates:
try:
@@ -812,7 +828,9 @@ class ACMECertificateClient:
alternate_chains.append(alt_cert)
return alternate_chains
def find_matching_chain(self, chains):
def find_matching_chain(
self, chains: t.Iterable[CertificateChain]
) -> CertificateChain | None:
for criterium_idx, matcher in enumerate(self.select_chain_matcher):
for chain in chains:
if matcher.match(chain):
@@ -822,12 +840,13 @@ class ACMECertificateClient:
return chain
return None
def get_certificate(self):
def get_certificate(self) -> None:
"""
Request a new certificate and write it to the destination file.
First verifies whether all authorizations are valid; if not, aborts
with an error.
"""
assert self.authorizations is not None
for identifier_type, identifier in self.identifiers:
authz = self.authorizations.get(
normalize_combined_identifier(
@@ -844,7 +863,9 @@ class ACMECertificateClient:
module=self.module,
)
assert self.order is not None
self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
assert self.order.certificate_uri is not None
cert = CertificateChain.download(self.client, self.order.certificate_uri)
if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher:
# Retrieve alternate chains
@@ -887,12 +908,13 @@ class ACMECertificateClient:
):
self.changed = True
def deactivate_authzs(self):
def deactivate_authzs(self) -> None:
"""
Deactivates all valid authz's. Does not raise exceptions.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
"""
assert self.authorizations is not None
for authz in self.authorizations.values():
try:
authz.deactivate(self.client)
@@ -905,7 +927,7 @@ class ACMECertificateClient:
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True)
argument_spec.argument_spec["csr"]["aliases"] = ["src"]
argument_spec.update_argspec(
@@ -981,7 +1003,7 @@ def main():
else:
client = ACMECertificateClient(module, backend)
client.cert_days = cert_days
other = dict()
other: dict[str, t.Any] = {}
is_first_step = client.is_first_step()
if is_first_step:
# First run: start challenges / start new order
@@ -998,6 +1020,7 @@ def main():
client.deactivate_authzs()
data, data_dns = client.get_challenges_data(first_step=is_first_step)
auths = dict()
assert client.authorizations is not None
for k, v in client.authorizations.items():
# Remove "type:" from key
auths[v.identifier] = v.to_json()

View File

@@ -51,6 +51,8 @@ EXAMPLES = r"""
RETURN = """#"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
)
@@ -65,7 +67,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec()
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),

View File

@@ -371,6 +371,8 @@ account_uri:
type: str
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -383,7 +385,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True)
argument_spec.update_argspec(
deactivate_authzs=dict(type="bool", default=True),

View File

@@ -317,6 +317,8 @@ selected_chain:
returned: always
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -329,7 +331,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.certificates import (
CertificateChain,
)
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True)
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),
@@ -375,6 +383,7 @@ def main():
or module.params["retrieve_all_alternates"]
)
changed = False
alternate_chains: list[CertificateChain] | None
if order.status == "valid":
# Step 2 and 3: download certificate(s) and chain(s)
cert, alternate_chains = client.download_certificate(

View File

@@ -357,6 +357,8 @@ authorizations_by_status:
returned: always
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -369,7 +371,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=False)
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),
@@ -381,8 +383,8 @@ def main():
try:
client = ACMECertificateClient(module, backend)
order = client.load_order()
authorizations_by_identifier = dict()
authorizations_by_status = {
authorizations_by_identifier: dict[str, dict[str, t.Any]] = {}
authorizations_by_status: dict[str, list[str]] = {
"pending": [],
"invalid": [],
"valid": [],
@@ -392,7 +394,8 @@ def main():
}
for identifier, authz in order.authorizations.items():
authorizations_by_identifier[identifier] = authz.to_json()
authorizations_by_status[authz.status].append(identifier)
if authz.status is not None:
authorizations_by_status[authz.status].append(identifier)
module.exit_json(
changed=False,
account_uri=client.client.account_uri,

View File

@@ -229,6 +229,8 @@ validating_challenges:
returned: always
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend,
create_default_argspec,
@@ -241,7 +243,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
)
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=False)
argument_spec.update_argspec(
order_uri=dict(type="str", required=True),
@@ -271,10 +279,12 @@ def main():
missing_challenge_authzs = [k for k, v in challenges.items() if v is None]
if missing_challenge_authzs:
missing_challenge_authzs = ", ".join(sorted(missing_challenge_authzs))
missing_challenge_authzs_str = ", ".join(
sorted(missing_challenge_authzs)
)
raise ModuleFailException(
"The challenge parameter must be supplied if there are pending authorizations."
f" The following authorizations are pending: {missing_challenge_authzs}"
f" The following authorizations are pending: {missing_challenge_authzs_str}"
)
bad_challenge_authzs = [
@@ -293,11 +303,13 @@ def main():
f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}"
)
def is_pending(authz: Authorization) -> bool:
challenge_name = challenges[authz.combined_identifier]
challenge_obj = authz.find_challenge(challenge_name)
return challenge_obj is not None and challenge_obj.status == "pending"
really_pending_authzs = [
authz
for authz in pending_authzs
if authz.find_challenge(challenges[authz.combined_identifier]).status
== "pending"
authz for authz in pending_authzs if is_pending(authz)
]
# Step 4: validate pending authorizations
@@ -320,7 +332,7 @@ def main():
identifier_type=authz.identifier_type,
authz_url=authz.url,
challenge_type=challenge_type,
challenge_url=challenge.url,
challenge_url=challenge.url if challenge else None,
)
for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for
],

View File

@@ -160,6 +160,7 @@ cert_id:
import os
import random
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient,
@@ -175,7 +176,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_account=False)
argument_spec.update_argspec(
certificate_path=dict(type="path"),
@@ -190,7 +191,7 @@ def main():
treat_parsing_error_as_non_existing=dict(type="bool", default=False),
)
argument_spec.update(
mutually_exclusive=(["certificate_path", "certificate_content"],),
mutually_exclusive=[("certificate_path", "certificate_content")],
)
module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True)
@@ -203,7 +204,7 @@ def main():
supports_ari=False,
)
def complete(should_renew, **kwargs):
def complete(should_renew: bool, **kwargs) -> t.NoReturn:
result["should_renew"] = should_renew
result.update(kwargs)
module.exit_json(**result)

View File

@@ -110,6 +110,8 @@ EXAMPLES = r"""
RETURN = """#"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount,
)
@@ -129,7 +131,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(require_account_key=False)
argument_spec.update_argspec(
private_key_src=dict(type="path"),
@@ -139,22 +141,22 @@ def main():
revoke_reason=dict(type="int"),
)
argument_spec.update(
required_one_of=(
[
required_one_of=[
(
"account_key_src",
"account_key_content",
"private_key_src",
"private_key_content",
],
),
mutually_exclusive=(
[
),
],
mutually_exclusive=[
(
"account_key_src",
"account_key_content",
"private_key_src",
"private_key_content",
],
),
),
],
)
module = argument_spec.create_ansible_module()
backend = create_backend(module, False)
@@ -164,9 +166,9 @@ def main():
account = ACMEAccount(client)
# Load certificate
certificate = pem_to_der(module.params.get("certificate"))
certificate = nopad_b64(certificate)
certificate_b64 = nopad_b64(certificate)
# Construct payload
payload = {"certificate": certificate}
payload = {"certificate": certificate_b64}
if module.params.get("revoke_reason") is not None:
payload["reason"] = module.params.get("revoke_reason")
endpoint = client.directory["revokeCert"]

View File

@@ -149,6 +149,7 @@ regular_certificate:
import base64
import datetime
import ipaddress
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -173,10 +174,13 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
try:
import cryptography
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.dh
import cryptography.hazmat.primitives.asymmetric.ec
import cryptography.hazmat.primitives.asymmetric.padding
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.asymmetric.x448
import cryptography.hazmat.primitives.asymmetric.x25519
import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.serialization
import cryptography.x509
@@ -186,7 +190,7 @@ except ImportError:
# Convert byte string to ASN1 encoded octet string
def encode_octet_string(octet_string):
def encode_octet_string(octet_string: bytes) -> bytes:
if len(octet_string) >= 128:
raise ModuleFailException(
"Cannot handle octet strings with more than 128 bytes"
@@ -194,7 +198,7 @@ def encode_octet_string(octet_string):
return bytes([0x4, len(octet_string)]) + octet_string
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
challenge=dict(type="str", required=True, choices=["tls-alpn-01"]),
@@ -213,16 +217,16 @@ def main():
try:
# Get parameters
challenge = module.params["challenge"]
challenge_data = module.params["challenge_data"]
challenge: t.Literal["tls-alpn-01"] = module.params["challenge"]
challenge_data: dict[str, t.Any] = module.params["challenge_data"]
# Get hold of private key
private_key_content = module.params.get("private_key_content")
private_key_passphrase = module.params.get("private_key_passphrase")
if private_key_content is None:
private_key_content_str: str | None = module.params["private_key_content"]
private_key_passphrase: str | None = module.params["private_key_passphrase"]
if private_key_content_str is None:
private_key_content = read_file(module.params["private_key_src"])
else:
private_key_content = to_bytes(private_key_content)
private_key_content = to_bytes(private_key_content_str)
try:
private_key = (
cryptography.hazmat.primitives.serialization.load_pem_private_key(
@@ -236,6 +240,17 @@ def main():
)
except Exception as e:
raise ModuleFailException(f"Error while loading private key: {e}")
if isinstance(
private_key,
(
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
),
):
raise ModuleFailException(
f"Cannot use private key type {type(private_key)}"
)
# Some common attributes
domain = to_text(challenge_data["resource"])
@@ -246,6 +261,7 @@ def main():
now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
not_valid_before = now
not_valid_after = now + datetime.timedelta(days=10)
san: cryptography.x509.GeneralName
if identifier_type == "dns":
san = cryptography.x509.DNSName(identifier)
elif identifier_type == "ip":

View File

@@ -223,6 +223,8 @@ output_json:
- '...'
"""
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient,
@@ -235,7 +237,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def main():
def main() -> t.NoReturn:
argument_spec = create_default_argspec(require_account_key=False)
argument_spec.update_argspec(
url=dict(type="str"),
@@ -246,17 +248,17 @@ def main():
fail_on_acme_error=dict(type="bool", default=True),
)
argument_spec.update(
required_if=(
["method", "get", ["url"]],
["method", "post", ["url", "content"]],
["method", "get", ["account_key_src", "account_key_content"], True],
["method", "post", ["account_key_src", "account_key_content"], True],
),
required_if=[
("method", "get", ["url"]),
("method", "post", ["url", "content"]),
("method", "get", ["account_key_src", "account_key_content"], True),
("method", "post", ["account_key_src", "account_key_content"], True),
],
)
module = argument_spec.create_ansible_module()
backend = create_backend(module, False)
result = dict()
result: dict[str, t.Any] = {}
changed = False
try:
# Get hold of ACMEClient and ACMEAccount objects (includes directory)

View File

@@ -121,6 +121,7 @@ complete_chain:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes
@@ -153,14 +154,18 @@ class Certificate:
Stores PEM with parsed certificate.
"""
def __init__(self, pem, cert):
def __init__(self, pem: str, cert: cryptography.x509.Certificate) -> None:
if not (pem.endswith("\n") or pem.endswith("\r")):
pem = pem + "\n"
self.pem = pem
self.cert = cert
def is_parent(module, cert, potential_parent):
def is_parent(
module: AnsibleModule,
cert: Certificate,
potential_parent: Certificate,
) -> bool:
"""
Tests whether the given certificate has been issued by the potential parent certificate.
"""
@@ -173,6 +178,10 @@ def is_parent(module, cert, potential_parent):
if isinstance(
public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey
):
if cert.cert.signature_hash_algorithm is None:
raise AssertionError(
"signature_hash_algorithm should be present for RSA certificates"
)
public_key.verify(
cert.cert.signature,
cert.cert.tbs_certificate_bytes,
@@ -183,6 +192,10 @@ def is_parent(module, cert, potential_parent):
public_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
):
if cert.cert.signature_hash_algorithm is None:
raise AssertionError(
"signature_hash_algorithm should be present for EC certificates"
)
public_key.verify(
cert.cert.signature,
cert.cert.tbs_certificate_bytes,
@@ -213,11 +226,16 @@ def is_parent(module, cert, potential_parent):
module.fail_json(msg=f"Unknown error on signature validation: {e}")
def parse_PEM_list(module, text, source, fail_on_error=True):
def parse_PEM_list(
module: AnsibleModule,
text: str,
source: str | os.PathLike,
fail_on_error: bool = True,
) -> list[Certificate]:
"""
Parse concatenated PEM certificates. Return list of ``Certificate`` objects.
"""
result = []
result: list[Certificate] = []
for cert_pem in split_pem_list(text):
# Try to load PEM certificate
try:
@@ -232,7 +250,9 @@ def parse_PEM_list(module, text, source, fail_on_error=True):
return result
def load_PEM_list(module, path, fail_on_error=True):
def load_PEM_list(
module: AnsibleModule, path: str | os.PathLike, fail_on_error: bool = True
) -> list[Certificate]:
"""
Load concatenated PEM certificates from file. Return list of ``Certificate`` objects.
"""
@@ -258,13 +278,15 @@ class CertificateSet:
Stores a set of certificates. Allows to search for parent (issuer of a certificate).
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.certificates = set()
self.certificates_by_issuer = dict()
self.certificate_by_cert = dict()
self.certificates: set[Certificate] = set()
self.certificates_by_issuer: dict[cryptography.x509.Name, list[Certificate]] = (
{}
)
self.certificate_by_cert: dict[cryptography.x509.Certificate, Certificate] = {}
def _load_file(self, path):
def _load_file(self, path: str | os.PathLike) -> None:
certs = load_PEM_list(self.module, path, fail_on_error=False)
for cert in certs:
self.certificates.add(cert)
@@ -273,7 +295,7 @@ class CertificateSet:
self.certificates_by_issuer[cert.cert.subject].append(cert)
self.certificate_by_cert[cert.cert] = cert
def load(self, path):
def load(self, path: str | os.PathLike) -> None:
"""
Load lists of PEM certificates from a file or a directory.
"""
@@ -285,7 +307,7 @@ class CertificateSet:
else:
self._load_file(b_path)
def find_parent(self, cert):
def find_parent(self, cert: Certificate) -> Certificate | None:
"""
Search for the parent (issuer) of a certificate. Return ``None`` if none was found.
"""
@@ -296,14 +318,18 @@ class CertificateSet:
return None
def format_cert(cert):
def format_cert(cert: Certificate) -> str:
"""
Return human readable representation of certificate for error messages.
"""
return str(cert.cert)
def check_cycle(module, occured_certificates, next):
def check_cycle(
module: AnsibleModule,
occured_certificates: set[cryptography.x509.Certificate],
next: Certificate,
) -> None:
"""
Make sure that next is not in occured_certificates so far, and add it.
"""
@@ -313,7 +339,7 @@ def check_cycle(module, occured_certificates, next):
occured_certificates.add(next_cert)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
input_chain=dict(type="str", required=True),
@@ -354,10 +380,10 @@ def main():
roots.load(path)
# Try to complete chain
current = chain[-1]
current: Certificate | None = chain[-1]
completed = []
occured_certificates = set([cert.cert for cert in chain])
if current.cert in roots.certificate_by_cert:
if current and current.cert in roots.certificate_by_cert:
# Do not try to complete the chain when it is already ending with a root certificate
current = None
while current:

View File

@@ -152,10 +152,13 @@ openssl:
"""
import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule
CRYPTOGRAPHY_VERSION: str | None
CRYPTOGRAPHY_IMP_ERR: str | None
try:
import cryptography
from cryptography.exceptions import UnsupportedAlgorithm
@@ -165,10 +168,10 @@ try:
# only got added in 0.2, so let's guard the import
from cryptography.exceptions import InternalError as CryptographyInternalError
except ImportError:
CryptographyInternalError = Exception
CryptographyInternalError = Exception # type: ignore
except ImportError:
UnsupportedAlgorithm = Exception
CryptographyInternalError = Exception
UnsupportedAlgorithm = Exception # type: ignore
CryptographyInternalError = Exception # type: ignore
HAS_CRYPTOGRAPHY = False
CRYPTOGRAPHY_VERSION = None
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -201,8 +204,8 @@ CURVES = (
)
def add_crypto_information(module):
result = {}
def add_crypto_information(module: AnsibleModule) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY
if not HAS_CRYPTOGRAPHY:
result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR
@@ -397,9 +400,9 @@ def add_crypto_information(module):
return result
def add_openssl_information(module):
def add_openssl_information(module: AnsibleModule) -> dict[str, t.Any]:
openssl_binary = module.get_bin_path("openssl")
result = {
result: dict[str, t.Any] = {
"openssl_present": openssl_binary is not None,
}
if openssl_binary is None:
@@ -426,9 +429,9 @@ INFO_FUNCTIONS = (
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(argument_spec={}, supports_check_mode=True)
result = {}
result: dict[str, t.Any] = {}
for fn in INFO_FUNCTIONS:
result.update(fn(module))
module.exit_json(**result)

View File

@@ -550,6 +550,7 @@ import datetime
import os
import re
import time
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes
@@ -572,7 +573,7 @@ from ansible_collections.community.crypto.plugins.module_utils.io import write_f
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
def validate_cert_expiry(cert_expiry):
def validate_cert_expiry(cert_expiry: str) -> bool:
search_string_partial = re.compile(
r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z"
)
@@ -587,7 +588,7 @@ def validate_cert_expiry(cert_expiry):
return False
def calculate_cert_days(expires_after):
def calculate_cert_days(expires_after: str | None) -> int:
cert_days = 0
if expires_after:
expires_after_datetime = datetime.datetime.strptime(
@@ -600,7 +601,9 @@ def calculate_cert_days(expires_after):
# Populate the value of body[dict_param_name] with the JSON equivalent of
# module parameter of param_name if that parameter is present, otherwise leave field
# out of resulting dict
def convert_module_param_to_json_bool(module, dict_param_name, param_name):
def convert_module_param_to_json_bool(
module: AnsibleModule, dict_param_name: str, param_name: str
) -> dict[str, str]:
body = {}
if module.params[param_name] is not None:
if module.params[param_name]:
@@ -886,7 +889,7 @@ class EcsCertificate:
return result
def custom_fields_spec():
def custom_fields_spec() -> dict[str, dict[str, str]]:
return dict(
text1=dict(type="str"),
text2=dict(type="str"),
@@ -926,7 +929,7 @@ def custom_fields_spec():
)
def ecs_certificate_argument_spec():
def ecs_certificate_argument_spec() -> dict[str, dict[str, t.Any]]:
return dict(
backup=dict(type="bool", default=False),
force=dict(type="bool", default=False),
@@ -979,7 +982,7 @@ def ecs_certificate_argument_spec():
)
def main():
def main() -> t.NoReturn:
ecs_argument_spec = ecs_client_argument_spec()
ecs_argument_spec.update(ecs_certificate_argument_spec())
module = AnsibleModule(

View File

@@ -218,6 +218,7 @@ ev_days_remaining:
import datetime
import time
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
@@ -228,7 +229,7 @@ from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
)
def calculate_days_remaining(expiry_date):
def calculate_days_remaining(expiry_date: str | None) -> int | None:
days_remaining = None
if expiry_date:
expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ")
@@ -403,8 +404,8 @@ class EcsDomain:
msg=f"Failed to request domain validation from Entrust (ECS) {e.message}"
)
def dump(self):
result = {
def dump(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"changed": self.changed,
"client_id": self.client_id,
"domain_status": self.domain_status,
@@ -436,7 +437,7 @@ class EcsDomain:
return result
def ecs_domain_argument_spec():
def ecs_domain_argument_spec() -> dict[str, dict[str, t.Any]]:
return dict(
client_id=dict(type="int", default=1),
domain_name=dict(type="str", required=True),
@@ -447,7 +448,7 @@ def ecs_domain_argument_spec():
)
def main():
def main() -> t.NoReturn:
ecs_argument_spec = ecs_client_argument_spec()
ecs_argument_spec.update(ecs_domain_argument_spec())
module = AnsibleModule(

View File

@@ -268,6 +268,7 @@ import atexit
import base64
import ssl
import sys
import typing as t
from os.path import isfile
from socket import create_connection, setdefaulttimeout, socket
from ssl import (
@@ -305,7 +306,7 @@ except ImportError:
pass
def send_starttls_packet(sock, server_type):
def send_starttls_packet(sock: socket, server_type: t.Literal["mysql"]) -> None:
if server_type == "mysql":
ssl_request_packet = (
b"\x20\x00\x00\x01\x85\xae\x7f\x00"
@@ -321,7 +322,7 @@ def send_starttls_packet(sock, server_type):
sock.send(ssl_request_packet)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
ca_cert=dict(type="path"),
@@ -342,18 +343,18 @@ def main():
),
)
ca_cert = module.params.get("ca_cert")
host = module.params.get("host")
port = module.params.get("port")
proxy_host = module.params.get("proxy_host")
proxy_port = module.params.get("proxy_port")
timeout = module.params.get("timeout")
server_name = module.params.get("server_name")
start_tls_server_type = module.params.get("starttls")
ciphers = module.params.get("ciphers")
asn1_base64 = module.params["asn1_base64"]
tls_ctx_options = module.params["tls_ctx_options"]
get_certificate_chain = module.params["get_certificate_chain"]
ca_cert: str | None = module.params.get("ca_cert")
host: str = module.params.get("host")
port: int = module.params.get("port")
proxy_host: str | None = module.params.get("proxy_host")
proxy_port: int | None = module.params.get("proxy_port")
timeout: int = module.params.get("timeout")
server_name: str | None = module.params.get("server_name")
start_tls_server_type: t.Literal["mysql"] | None = module.params.get("starttls")
ciphers: list[str] | None = module.params.get("ciphers")
asn1_base64: bool = module.params["asn1_base64"]
tls_ctx_options: list[str | bytes | int] | None = module.params["tls_ctx_options"]
get_certificate_chain: bool = module.params["get_certificate_chain"]
if get_certificate_chain and sys.version_info < (3, 10):
module.fail_json(
@@ -365,9 +366,9 @@ def main():
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
result = dict(
changed=False,
)
result: dict[str, t.Any] = {
"changed": False,
}
if timeout:
setdefaulttimeout(timeout)
@@ -409,7 +410,7 @@ def main():
if tls_ctx_options is not None:
# Clear default ctx options
ctx.options = 0
ctx.options = 0 # type: ignore
# For each item in the tls_ctx_options list
for tls_ctx_option in tls_ctx_options:
@@ -450,8 +451,10 @@ def main():
)
tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host)
cert = tls_sock.getpeercert(True)
cert = DER_cert_to_PEM_cert(cert)
cert_der = tls_sock.getpeercert(True)
if cert_der is None:
raise Exception("Unexpected error: no peer certificate has been returned")
cert: str = DER_cert_to_PEM_cert(cert_der)
if get_certificate_chain:
if sys.version_info < (3, 13):
@@ -474,7 +477,7 @@ def main():
# Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to
# be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves
# if they are not byte strings to work around this.
def _convert_chain(chain):
def _convert_chain(chain: list[bytes]) -> list[bytes]:
return [
(
c
@@ -514,13 +517,13 @@ def main():
result["extensions"] = []
for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items():
oid = cryptography.x509.oid.ObjectIdentifier(dotted_number)
ext = {
ext: dict[str, t.Any] = {
"critical": entry["critical"],
"asn1_data": entry["value"],
"name": cryptography_oid_to_name(oid, short=True),
}
if not asn1_base64:
ext["asn1_data"] = base64.b64decode(ext["asn1_data"])
ext["asn1_data"] = base64.b64decode(entry["value"]) # type: ignore
result["extensions"].append(ext)
result["issuer"] = {}

View File

@@ -420,16 +420,13 @@ name:
import os
import re
import stat
import typing as t
from base64 import b64decode
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native
RETURN_CODE = 0
STDOUT = 1
STDERR = 2
# used to get <luks-name> out of lsblk output in format 'crypt <luks-name>'
# regex takes care of any possible blank characters
LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$")
@@ -456,7 +453,7 @@ LUKS2_HEADER_OFFSETS = [
LUKS2_HEADER2 = b"SKUL\xba\xbe"
def wipe_luks_headers(device):
def wipe_luks_headers(device: str) -> None:
wipe_offsets = []
with open(device, "rb") as f:
# f.seek(0)
@@ -478,12 +475,12 @@ def wipe_luks_headers(device):
class Handler:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self._module = module
self._lsblk_bin = self._module.get_bin_path("lsblk", True)
self._passphrase_encoding = module.params["passphrase_encoding"]
def get_passphrase_from_module_params(self, parameter_name):
def get_passphrase_from_module_params(self, parameter_name: str) -> bytes | None:
passphrase = self._module.params[parameter_name]
if passphrase is None:
return None
@@ -496,88 +493,91 @@ class Handler:
f"Error while base64-decoding '{parameter_name}': {exc}"
)
def _run_command(self, command, data=None):
def _run_command(
self, command: list[str], data: bytes | None = None
) -> tuple[int, str, str]:
return self._module.run_command(command, data=data, binary_data=True)
def get_device_by_uuid(self, uuid):
def get_device_by_uuid(self, uuid: str | None) -> str | None:
"""Returns the device that holds UUID passed by user"""
self._blkid_bin = self._module.get_bin_path("blkid", True)
uuid = self._module.params["uuid"]
if uuid is None:
return None
result = self._run_command([self._blkid_bin, "--uuid", uuid])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command([self._blkid_bin, "--uuid", uuid])
if rc != 0:
return None
return result[STDOUT].strip()
return stdout.strip()
def get_device_by_label(self, label):
def get_device_by_label(self, label: str) -> str | None:
"""Returns the device that holds label passed by user"""
self._blkid_bin = self._module.get_bin_path("blkid", True)
label = self._module.params["label"]
if label is None:
return None
result = self._run_command([self._blkid_bin, "--label", label])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command([self._blkid_bin, "--label", label])
if rc != 0:
return None
return result[STDOUT].strip()
return stdout.strip()
def generate_luks_name(self, device):
def generate_luks_name(self, device: str) -> str:
"""Generate name for luks based on device UUID ('luks-<UUID>').
Raises ValueError when obtaining of UUID fails.
"""
result = self._run_command([self._lsblk_bin, "-n", device, "-o", "UUID"])
rc, stdout, stderr = self._run_command(
[self._lsblk_bin, "-n", device, "-o", "UUID"]
)
if result[RETURN_CODE] != 0:
raise ValueError(
f"Error while generating LUKS name for {device}: {result[STDERR]}"
)
dev_uuid = result[STDOUT].strip()
if rc != 0:
raise ValueError(f"Error while generating LUKS name for {device}: {stderr}")
dev_uuid = stdout.strip()
return f"luks-{dev_uuid}"
class CryptHandler(Handler):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CryptHandler, self).__init__(module)
self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True)
def get_container_name_by_device(self, device):
def get_container_name_by_device(self, device: str) -> str | None:
"""obtain LUKS container name based on the device where it is located
return None if not found
raise ValueError if lsblk command fails
"""
result = self._run_command([self._lsblk_bin, device, "-nlo", "type,name"])
if result[RETURN_CODE] != 0:
raise ValueError(
f"Error while obtaining LUKS name for {device}: {result[STDERR]}"
)
rc, stdout, stderr = self._run_command(
[self._lsblk_bin, device, "-nlo", "type,name"]
)
if rc != 0:
raise ValueError(f"Error while obtaining LUKS name for {device}: {stderr}")
for line in result[STDOUT].splitlines(False):
for line in stdout.splitlines(False):
m = LUKS_NAME_REGEX.match(line)
if m:
return m.group(1)
return None
def get_container_device_by_name(self, name):
def get_container_device_by_name(self, name: str) -> str | None:
"""obtain device name based on the LUKS container name
return None if not found
raise ValueError if lsblk command fails
"""
# apparently each device can have only one LUKS container on it
result = self._run_command([self._cryptsetup_bin, "status", name])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command([self._cryptsetup_bin, "status", name])
if rc != 0:
return None
m = LUKS_DEVICE_REGEX.search(result[STDOUT])
m = LUKS_DEVICE_REGEX.search(stdout)
if not m:
return None
device = m.group(1)
return device
def is_luks(self, device):
def is_luks(self, device: str) -> bool:
"""check if the LUKS container does exist"""
result = self._run_command([self._cryptsetup_bin, "isLuks", device])
return result[RETURN_CODE] == 0
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "isLuks", device])
return rc == 0
def get_luks_type(self, device):
def get_luks_type(self, device: str) -> t.Literal["luks1", "luks2"] | None:
"""get the luks type of a device"""
if self.is_luks(device):
with open(device, "rb") as f:
@@ -589,16 +589,18 @@ class CryptHandler(Handler):
return "luks1"
return None
def is_luks_slot_set(self, device, keyslot):
def is_luks_slot_set(self, device: str, keyslot: int) -> bool:
"""check if a keyslot is set"""
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command(
[self._cryptsetup_bin, "luksDump", device]
)
if rc != 0:
raise ValueError(f"Error while dumping LUKS header from {device}")
result_luks1 = f"Key Slot {keyslot}: ENABLED" in result[STDOUT]
result_luks2 = f" {keyslot}: luks2" in result[STDOUT]
result_luks1 = f"Key Slot {keyslot}: ENABLED" in stdout
result_luks2 = f" {keyslot}: luks2" in stdout
return result_luks1 or result_luks2
def _add_pbkdf_options(self, options, pbkdf):
def _add_pbkdf_options(self, options: list[str], pbkdf: dict[str, t.Any]) -> None:
if pbkdf["iteration_time"] is not None:
options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))])
if pbkdf["iteration_count"] is not None:
@@ -612,16 +614,16 @@ class CryptHandler(Handler):
def run_luks_create(
self,
device,
keyfile,
passphrase,
keyslot,
keysize,
cipher,
hash_,
sector_size,
pbkdf,
):
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None,
keysize: int | None,
cipher: str | None,
hash_: str | None,
sector_size: str | None,
pbkdf: dict[str, t.Any] | None,
) -> None:
# create a new luks container; use batch mode to auto confirm
luks_type = self._module.params["type"]
label = self._module.params["label"]
@@ -653,23 +655,23 @@ class CryptHandler(Handler):
else:
args.append("-")
result = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0:
raise ValueError(f"Error while creating LUKS on {device}: {result[STDERR]}")
rc, dummy, stderr = self._run_command(args, data=passphrase)
if rc != 0:
raise ValueError(f"Error while creating LUKS on {device}: {stderr}")
def run_luks_open(
self,
device,
keyfile,
passphrase,
perf_same_cpu_crypt,
perf_submit_from_crypt_cpus,
perf_no_read_workqueue,
perf_no_write_workqueue,
persistent,
allow_discards,
name,
):
device: str,
keyfile: str | None,
passphrase: bytes | None,
perf_same_cpu_crypt: bool,
perf_submit_from_crypt_cpus: bool,
perf_no_read_workqueue: bool,
perf_no_write_workqueue: bool,
persistent: bool,
allow_discards: bool,
name: str,
) -> None:
args = [self._cryptsetup_bin]
if keyfile:
args.extend(["--key-file", keyfile])
@@ -689,27 +691,27 @@ class CryptHandler(Handler):
args.extend(["--allow-discards"])
args.extend(["open", "--type", "luks", device, name])
result = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0:
rc, dummy, stderr = self._run_command(args, data=passphrase)
if rc != 0:
raise ValueError(
f"Error while opening LUKS container on {device}: {result[STDERR]}"
f"Error while opening LUKS container on {device}: {stderr}"
)
def run_luks_close(self, name):
result = self._run_command([self._cryptsetup_bin, "close", name])
if result[RETURN_CODE] != 0:
def run_luks_close(self, name: str) -> None:
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "close", name])
if rc != 0:
raise ValueError(f"Error while closing LUKS container {name}")
def run_luks_remove(self, device):
def run_luks_remove(self, device: str) -> None:
wipefs_bin = self._module.get_bin_path("wipefs", True)
name = self.get_container_name_by_device(device)
if name is not None:
self.run_luks_close(name)
result = self._run_command([wipefs_bin, "--all", device])
if result[RETURN_CODE] != 0:
rc, dummy, stderr = self._run_command([wipefs_bin, "--all", device])
if rc != 0:
raise ValueError(
f"Error while wiping LUKS container signatures for {device}: {result[STDERR]}"
f"Error while wiping LUKS container signatures for {device}: {stderr}"
)
# For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not**
@@ -724,14 +726,14 @@ class CryptHandler(Handler):
def run_luks_add_key(
self,
device,
keyfile,
passphrase,
new_keyfile,
new_passphrase,
new_keyslot,
pbkdf,
):
device: str,
keyfile: str | None,
passphrase: bytes | None,
new_keyfile: str | None,
new_passphrase: bytes | None,
new_keyslot: int | None,
pbkdf: dict[str, t.Any] | None,
) -> None:
"""Add new key from a keyfile or passphrase to given 'device';
authentication done using 'keyfile' or 'passphrase'.
Raises ValueError when command fails.
@@ -746,36 +748,47 @@ class CryptHandler(Handler):
if keyfile:
args.extend(["--key-file", keyfile])
else:
elif passphrase is not None:
args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))])
data.append(passphrase)
else:
raise ValueError("Need passphrase or keyfile")
if new_keyfile:
args.append(new_keyfile)
else:
elif new_passphrase is not None:
args.append("-")
data.append(new_passphrase)
else:
raise ValueError("Need new passphrase or new keyfile")
result = self._run_command(args, data=b"".join(data) or None)
if result[RETURN_CODE] != 0:
rc, dummy, stderr = self._run_command(args, data=b"".join(data) or None)
if rc != 0:
raise ValueError(
f"Error while adding new LUKS keyslot to {device}: {result[STDERR]}"
f"Error while adding new LUKS keyslot to {device}: {stderr}"
)
def run_luks_remove_key(
self, device, keyfile, passphrase, keyslot, force_remove_last_key=False
):
self,
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None,
force_remove_last_key: bool = False,
) -> None:
"""Remove key from given device
Raises ValueError when command fails
"""
if not force_remove_last_key:
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
if result[RETURN_CODE] != 0:
rc, stdout, dummy = self._run_command(
[self._cryptsetup_bin, "luksDump", device]
)
if rc != 0:
raise ValueError(f"Error while dumping LUKS header from {device}")
keyslot_count = 0
keyslot_area = False
keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED")
for line in result[STDOUT].splitlines():
for line in stdout.splitlines():
if line.startswith("Keyslots:"):
keyslot_area = True
elif line.startswith(" "):
@@ -808,13 +821,17 @@ class CryptHandler(Handler):
# Since we supply -q no passphrase is needed
args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)]
passphrase = None
result = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0:
raise ValueError(
f"Error while removing LUKS key from {device}: {result[STDERR]}"
)
rc, dummy, stderr = self._run_command(args, data=passphrase)
if rc != 0:
raise ValueError(f"Error while removing LUKS key from {device}: {stderr}")
def luks_test_key(self, device, keyfile, passphrase, keyslot=None):
def luks_test_key(
self,
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None = None,
) -> bool:
"""Check whether the keyfile or passphrase works.
Raises ValueError when command fails.
"""
@@ -830,42 +847,37 @@ class CryptHandler(Handler):
if keyslot is not None:
args.extend(["--key-slot", str(keyslot)])
result = self._run_command(args, data=data)
if result[RETURN_CODE] == 0:
rc, stdout, stderr = self._run_command(args, data=data)
if rc == 0:
return True
for output in (STDOUT, STDERR):
if "No key available with this passphrase" in result[output]:
for output in (stdout, stderr):
if "No key available with this passphrase" in output:
return False
if "No usable keyslot is available." in result[output]:
if "No usable keyslot is available." in output:
return False
# This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available'
# when using the --key-slot parameter in combination with --test-passphrase
if (
result[RETURN_CODE] == 1
and keyslot is not None
and result[STDOUT] == ""
and result[STDERR] == ""
):
if rc == 1 and keyslot is not None and stdout == "" and stderr == "":
return False
raise ValueError(
f"Error while testing whether keyslot exists on {device}: {result[STDERR]}"
f"Error while testing whether keyslot exists on {device}: {stderr}"
)
class ConditionsHandler(Handler):
def __init__(self, module, crypthandler):
def __init__(self, module: AnsibleModule, crypthandler: CryptHandler) -> None:
super(ConditionsHandler, self).__init__(module)
self._crypthandler = crypthandler
self.device = self.get_device_name()
def get_device_name(self):
device = self._module.params.get("device")
label = self._module.params.get("label")
uuid = self._module.params.get("uuid")
name = self._module.params.get("name")
def get_device_name(self) -> str | None:
device: str | None = self._module.params.get("device")
label: str | None = self._module.params.get("label")
uuid: str | None = self._module.params.get("uuid")
name: str | None = self._module.params.get("name")
if device is None and label is not None:
device = self.get_device_by_label(label)
@@ -876,7 +888,7 @@ class ConditionsHandler(Handler):
return device
def luks_create(self):
def luks_create(self) -> bool:
return (
self.device is not None
and (
@@ -887,7 +899,7 @@ class ConditionsHandler(Handler):
and not self._crypthandler.is_luks(self.device)
)
def opened_luks_name(self):
def opened_luks_name(self, device: str) -> str | None:
"""If luks is already opened, return its name.
If 'name' parameter is specified and differs
from obtained value, fail.
@@ -897,7 +909,7 @@ class ConditionsHandler(Handler):
return None
# try to obtain luks name - it may be already opened
name = self._crypthandler.get_container_name_by_device(self.device)
name = self._crypthandler.get_container_name_by_device(device)
if name is None:
# container is not open
@@ -917,7 +929,7 @@ class ConditionsHandler(Handler):
# container is opened and the names match
return name
def luks_open(self):
def luks_open(self) -> bool:
if (
(
self._module.params["keyfile"] is None
@@ -929,13 +941,13 @@ class ConditionsHandler(Handler):
# conditions for open not fulfilled
return False
name = self.opened_luks_name()
name = self.opened_luks_name(self.device)
if name is None:
return True
return False
def luks_close(self):
def luks_close(self) -> bool:
if (
self._module.params["name"] is None and self.device is None
) or self._module.params["state"] != "closed":
@@ -948,15 +960,17 @@ class ConditionsHandler(Handler):
luks_is_open = name is not None
if self._module.params["name"] is not None:
self.device = self._crypthandler.get_container_device_by_name(
device = self._crypthandler.get_container_device_by_name(
self._module.params["name"]
)
# successfully getting device based on name means that luks is open
luks_is_open = self.device is not None
luks_is_open = device is not None
if device is not None:
self.device = device
return luks_is_open
def luks_add_key(self):
def luks_add_key(self) -> bool:
if (
self.device is None
or (
@@ -995,7 +1009,7 @@ class ConditionsHandler(Handler):
return not key_present
def luks_remove_key(self):
def luks_remove_key(self) -> bool:
if self.device is None or (
self._module.params["remove_keyfile"] is None
and self._module.params["remove_passphrase"] is None
@@ -1037,14 +1051,16 @@ class ConditionsHandler(Handler):
self.get_passphrase_from_module_params("remove_passphrase"),
)
def luks_remove(self):
def luks_remove(self) -> bool:
return (
self.device is not None
and self._module.params["state"] == "absent"
and self._crypthandler.is_luks(self.device)
)
def validate_keyslot(self, param, luks_type):
def validate_keyslot(
self, param: str, luks_type: t.Literal["luks1", "luks2"] | None
) -> None:
if self._module.params[param] is not None:
if luks_type is None and param == "keyslot":
if 8 <= self._module.params[param] <= 31:
@@ -1066,7 +1082,7 @@ class ConditionsHandler(Handler):
)
def run_module():
def run_module() -> t.NoReturn:
# available arguments/parameters that a user can pass
module_args = dict(
state=dict(
@@ -1122,7 +1138,7 @@ def run_module():
]
# seed the result dict in the object
result = dict(changed=False, name=None)
result: dict[str, t.Any] = {"changed": False, "name": None}
module = AnsibleModule(
argument_spec=module_args,
@@ -1142,19 +1158,26 @@ def run_module():
except Exception as e:
module.fail_json(msg=str(e))
crypt = CryptHandler(module)
conditions = ConditionsHandler(module, crypt)
# conditions not allowed to run
if module.params["label"] is not None and module.params["type"] == "luks1":
module.fail_json(msg="You cannot combine type luks1 with the label option.")
crypt = CryptHandler(module)
try:
conditions = ConditionsHandler(module, crypt)
except ValueError as exc:
module.fail_json(msg=str(exc))
if (
module.params["keyslot"] is not None
or module.params["new_keyslot"] is not None
or module.params["remove_keyslot"] is not None
):
luks_type = crypt.get_luks_type(conditions.get_device_name())
luks_type = (
crypt.get_luks_type(conditions.device)
if conditions.device is not None
else None
)
if luks_type is None and module.params["type"] is not None:
luks_type = module.params["type"]
for param in ["keyslot", "new_keyslot", "remove_keyslot"]:
@@ -1175,6 +1198,7 @@ def run_module():
# luks create
if conditions.luks_create():
assert conditions.device # ensured in conditions.luks_create()
if not module.check_mode:
try:
crypt.run_luks_create(
@@ -1196,11 +1220,13 @@ def run_module():
# luks open
name = conditions.opened_luks_name()
if name is not None:
result["name"] = name
if conditions.device is not None:
name = conditions.opened_luks_name(conditions.device)
if name is not None:
result["name"] = name
if conditions.luks_open():
assert conditions.device # ensured in conditions.luks_open()
name = module.params["name"]
if name is None:
try:
@@ -1237,6 +1263,8 @@ def run_module():
module.fail_json(msg=f"luks_device error: {e}")
else:
name = module.params["name"]
if name is None:
module.fail_json(msg="Cannot determine name to close device")
if not module.check_mode:
try:
crypt.run_luks_close(name)
@@ -1249,6 +1277,7 @@ def run_module():
# luks add key
if conditions.luks_add_key():
assert conditions.device # ensured in conditions.luks_add_key()
if not module.check_mode:
try:
crypt.run_luks_add_key(
@@ -1268,6 +1297,7 @@ def run_module():
# luks remove key
if conditions.luks_remove_key():
assert conditions.device # ensured in conditions.luks_remove_key()
if not module.check_mode:
try:
last_key = module.params["force_remove_last_key"]
@@ -1286,6 +1316,7 @@ def run_module():
# luks remove
if conditions.luks_remove():
assert conditions.device # ensured in conditions.luks_remove()
if not module.check_mode:
try:
crypt.run_luks_remove(conditions.device)
@@ -1299,7 +1330,7 @@ def run_module():
module.exit_json(**result)
def main():
def main() -> t.NoReturn:
run_module()

View File

@@ -284,6 +284,7 @@ info:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
class Certificate(OpensshModule):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(Certificate, self).__init__(module)
self.ssh_keygen = KeygenCommand(self.module)
self.identifier = self.module.params["identifier"] or ""
self.options = self.module.params["options"] or []
self.path = self.module.params["path"]
self.pkcs11_provider = self.module.params["pkcs11_provider"]
self.principals = self.module.params["principals"] or []
self.public_key = self.module.params["public_key"]
self.regenerate = (
self.identifier: str = self.module.params["identifier"] or ""
self.options: list[str] = self.module.params["options"] or []
self.path: str = self.module.params["path"]
self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
self.principals: list[str] = self.module.params["principals"] or []
self.public_key: str | None = self.module.params["public_key"]
self.regenerate: t.Literal[
"never",
"fail",
"partial_idempotence",
"full_idempotence",
"always",
] = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.serial_number = self.module.params["serial_number"]
self.signature_algorithm = self.module.params["signature_algorithm"]
self.signing_key = self.module.params["signing_key"]
self.state = self.module.params["state"]
self.type = self.module.params["type"]
self.use_agent = self.module.params["use_agent"]
self.valid_at = self.module.params["valid_at"]
self.ignore_timestamps = self.module.params["ignore_timestamps"]
self.serial_number: int | None = self.module.params["serial_number"]
self.signature_algorithm: (
t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
) = self.module.params["signature_algorithm"]
self.signing_key: str | None = self.module.params["signing_key"]
self.state: t.Literal["absent", "present"] = self.module.params["state"]
self.type: t.Literal["host", "user"] | None = self.module.params["type"]
self.use_agent: bool = self.module.params["use_agent"]
self.valid_at: str | None = self.module.params["valid_at"]
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
self._check_if_base_dir(self.path)
if self.state == "present":
self._validate_parameters()
self.data = None
self.original_data = None
self.data: OpensshCertificate | None = None
self.original_data: OpensshCertificate | None = None
if self._exists():
self._load_certificate()
self.time_parameters = None
self.time_parameters: OpensshCertificateTimeParameters | None = None
if self.state == "present":
self._set_time_parameters()
def _validate_parameters(self):
def _validate_parameters(self) -> None:
for path in (self.public_key, self.signing_key):
self._check_if_base_dir(path)
if (
path is not None
): # should never be None, but the type checker doesn't know
self._check_if_base_dir(path)
if self.options and self.type == "host":
self.module.fail_json(
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
if self.use_agent:
self._use_agent_available()
def _use_agent_available(self):
def _use_agent_available(self) -> None:
ssh_version = self._get_ssh_version()
if not ssh_version:
self.module.fail_json(msg="Failed to determine ssh version")
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
+ f" Your version is: {ssh_version}"
)
def _exists(self):
def _exists(self) -> bool:
return os.path.exists(self.path)
def _load_certificate(self):
def _load_certificate(self) -> None:
try:
self.original_data = OpensshCertificate.load(self.path)
except (TypeError, ValueError) as e:
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
self.module.warn(f"Unable to read existing certificate: {e}")
def _set_time_parameters(self):
def _set_time_parameters(self) -> None:
try:
self.time_parameters = OpensshCertificateTimeParameters(
valid_from=self.module.params["valid_from"],
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
except ValueError as e:
self.module.fail_json(msg=str(e))
def _execute(self):
def _execute(self) -> None:
if self.state == "present":
if self._should_generate():
self._generate()
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
if self._exists():
self._remove()
def _should_generate(self):
def _should_generate(self) -> bool:
if self.regenerate == "never":
return self.original_data is None
elif self.regenerate == "fail":
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
else:
return True
def _is_fully_valid(self):
def _is_fully_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
return self._is_partially_valid() and all(
[
self._compare_options() if self.original_data.type == "user" else True,
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
]
)
def _is_partially_valid(self):
def _is_partially_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
return all(
[
set(self.original_data.principals) == set(self.principals),
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
]
)
def _compare_time_parameters(self):
def _compare_time_parameters(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try:
original_time_parameters = OpensshCertificateTimeParameters(
valid_from=self.original_data.valid_after,
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
]
)
def _compare_options(self):
def _compare_options(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try:
critical_options, extensions = parse_option_list(self.options)
except ValueError as e:
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
]
)
def _get_key_fingerprint(self, path):
def _get_key_fingerprint(self, path: str) -> str:
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
return PrivateKey.from_string(private_key_content).fingerprint
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _generate(self):
def _generate(self) -> None:
try:
temp_certificate = self._generate_temp_certificate()
self._safe_secure_move([(temp_certificate, self.path)])
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
except (TypeError, ValueError) as e:
self.module.fail_json(msg=f"Unable to read new certificate: {e}")
def _generate_temp_certificate(self):
def _generate_temp_certificate(self) -> str:
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
if self.time_parameters is None:
raise AssertionError("Contract violation time_parameters not provided")
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
try:
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
@OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode
def _remove(self):
def _remove(self) -> None:
try:
os.remove(self.path)
except OSError as e:
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
@property
def _result(self):
def _result(self) -> dict[str, t.Any]:
if self.state != "present":
return {}
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
}
@property
def diff(self):
def diff(self) -> dict[str, t.Any]:
return {
"before": get_cert_dict(self.original_data),
"after": get_cert_dict(self.data),
}
def format_cert_info(cert_info):
def format_cert_info(cert_info: str) -> list[str]:
result = []
string = ""
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
return result
def get_cert_dict(data):
def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
if data is None:
return {}
@@ -590,7 +621,7 @@ def get_cert_dict(data):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
force=dict(type="bool", default=False),

View File

@@ -198,13 +198,15 @@ comment:
sample: test@comment
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import (
select_backend,
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(

View File

@@ -239,6 +239,7 @@ csr:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -256,9 +257,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
CertificateSigningRequestBackend,
)
class CertificateSigningRequestModule(OpenSSLObject):
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
) -> None:
super(CertificateSigningRequestModule, self).__init__(
module.params["path"],
module.params["state"],
@@ -269,11 +279,11 @@ class CertificateSigningRequestModule(OpenSSLObject):
self.return_content = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.backup_file: str | None = None
self.module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate the certificate signing request."""
if self.force or self.module_backend.needs_regeneration():
if not self.check_mode:
@@ -292,13 +302,13 @@ class CertificateSigningRequestModule(OpenSSLObject):
file_args, self.changed
)
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
self.module_backend.set_existing(None)
if self.backup and not self.check_mode:
self.backup_file = module.backup_local(self.path)
super(CertificateSigningRequestModule, self).remove(module)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(include_csr=self.return_content)
result.update(
@@ -312,7 +322,7 @@ class CertificateSigningRequestModule(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_csr_argument_spec()
argument_spec.argument_spec.update(
dict(

View File

@@ -308,6 +308,7 @@ authority_cert_serial_number:
sample: 12345
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -318,7 +319,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -335,11 +336,15 @@ def main():
supports_check_mode=True,
)
if module.params["content"] is not None:
data = module.params["content"].encode("utf-8")
content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is not None:
data = content.encode("utf-8")
else:
if path is None:
module.fail_json(msg="One of content and path must be provided")
try:
with open(module.params["path"], "rb") as f:
with open(path, "rb") as f:
data = f.read()
except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading CSR file from disk: {e}")

View File

@@ -127,6 +127,8 @@ csr:
type: str
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -136,8 +138,17 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
CertificateSigningRequestBackend,
)
class CertificateSigningRequestModule:
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
) -> None:
self.check_mode = module.check_mode
self.module = module
self.module_backend = module_backend
@@ -145,13 +156,13 @@ class CertificateSigningRequestModule:
if module.params["content"] is not None:
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate the certificate signing request."""
if self.module_backend.needs_regeneration():
self.module_backend.generate_csr()
self.changed = True
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(include_csr=True)
result.update(
@@ -162,7 +173,7 @@ class CertificateSigningRequestModule:
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_csr_argument_spec()
argument_spec.argument_spec.update(
dict(

View File

@@ -132,6 +132,7 @@ import abc
import os
import re
import tempfile
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_native
@@ -173,23 +174,23 @@ class DHParameterError(Exception):
class DHParameterBase:
def __init__(self, module):
self.state = module.params["state"]
self.path = module.params["path"]
self.size = module.params["size"]
self.force = module.params["force"]
def __init__(self, module: AnsibleModule) -> None:
self.state: t.Literal["absent", "present"] = module.params["state"]
self.path: str = module.params["path"]
self.size: int = module.params["size"]
self.force: bool = module.params["force"]
self.changed = False
self.return_content = module.params["return_content"]
self.return_content: bool = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
@abc.abstractmethod
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
pass
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate DH params."""
changed = False
@@ -206,7 +207,7 @@ class DHParameterBase:
self.changed = changed
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
try:
@@ -215,28 +216,27 @@ class DHParameterBase:
except OSError as exc:
module.fail_json(msg=str(exc))
def check(self, module):
def check(self, module: AnsibleModule) -> bool:
"""Ensure the resource is in its desired state."""
if self.force:
return False
return self._check_params_valid(module) and self._check_fs_attributes(module)
@abc.abstractmethod
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
pass
def _check_fs_attributes(self, module):
def _check_fs_attributes(self, module: AnsibleModule) -> bool:
"""Checks (and changes if not in check mode!) fs attributes"""
file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args["path"]):
return False
return not module.set_fs_attributes_if_different(file_args, False)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"size": self.size,
"filename": self.path,
"changed": self.changed,
@@ -252,25 +252,24 @@ class DHParameterBase:
class DHParameterAbsent(DHParameterBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(DHParameterAbsent, self).__init__(module)
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
pass
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
pass
return False
class DHParameterOpenSSL(DHParameterBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(DHParameterOpenSSL, self).__init__(module)
self.openssl_bin = module.get_bin_path("openssl", True)
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
# create a tempfile
fd, tmpsrc = tempfile.mkstemp()
@@ -288,7 +287,7 @@ class DHParameterOpenSSL(DHParameterBase):
except Exception as e:
module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}")
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
command = [
self.openssl_bin,
@@ -321,10 +320,10 @@ class DHParameterOpenSSL(DHParameterBase):
class DHParameterCryptography(DHParameterBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(DHParameterCryptography, self).__init__(module)
def _do_generate(self, module):
def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params."""
# Generate parameters
params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters(
@@ -341,7 +340,7 @@ class DHParameterCryptography(DHParameterBase):
self.backup_file = module.backup_local(self.path)
write_file(module, result)
def _check_params_valid(self, module):
def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state"""
# Load parameters
try:
@@ -357,7 +356,7 @@ class DHParameterCryptography(DHParameterBase):
return bits == self.size
def main():
def main() -> t.NoReturn:
"""Main function"""
module = AnsibleModule(
@@ -383,6 +382,7 @@ def main():
msg=f"The directory '{base_dir}' does not exist or the file is not a directory",
)
dhparam: DHParameterOpenSSL | DHParameterCryptography | DHParameterAbsent
if module.params["state"] == "present":
backend = module.params["select_crypto_backend"]
if backend == "auto":

View File

@@ -280,6 +280,7 @@ import itertools
import os
import stat
import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native
@@ -296,7 +297,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
OpenSSLObject,
load_certificate,
load_privatekey,
load_certificate_issuer_privatekey,
)
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -320,6 +321,7 @@ except ImportError:
CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None
try:
import cryptography.x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization.pkcs12 import PBES
@@ -333,8 +335,22 @@ except Exception:
else:
CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True
if t.TYPE_CHECKING:
from ..module_utils.crypto.cryptography_support import (
CertificateIssuerPrivateKeyTypes,
)
def load_certificate_set(filename):
PKCS12 = tuple[
t.Union[CertificateIssuerPrivateKeyTypes, None],
t.Union[cryptography.x509.Certificate, None],
list[cryptography.x509.Certificate],
t.Union[bytes, None],
]
def load_certificate_set(
filename: str | os.PathLike,
) -> list[cryptography.x509.Certificate]:
"""
Load list of concatenated PEM files, and return a list of parsed certificates.
"""
@@ -351,70 +367,80 @@ class PkcsError(OpenSSLObjectError):
class Pkcs(OpenSSLObject):
def __init__(self, module, iter_size_default=2048):
path: str
def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None:
super(Pkcs, self).__init__(
module.params["path"],
module.params["state"],
module.params["force"],
module.check_mode,
)
self.action = module.params["action"]
self.other_certificates = module.params["other_certificates"]
self.other_certificates_parse_all = module.params[
self.action: t.Literal["export", "parse"] = module.params["action"]
self.other_certificates: list[cryptography.x509.Certificate] = []
self.other_certificates_str: list[str] | None = module.params[
"other_certificates"
]
self.other_certificates_parse_all: bool = module.params[
"other_certificates_parse_all"
]
self.other_certificates_content = module.params["other_certificates_content"]
self.certificate_path = module.params["certificate_path"]
self.certificate_content = module.params["certificate_content"]
self.friendly_name = module.params["friendly_name"]
self.iter_size = module.params["iter_size"] or iter_size_default
self.maciter_size = module.params["maciter_size"] or 1
self.encryption_level = module.params["encryption_level"]
self.other_certificates_content: list[str] | None = module.params[
"other_certificates_content"
]
self.certificate_path: str | None = module.params["certificate_path"]
certificate_content: str | None = module.params["certificate_content"]
self.friendly_name: str | None = module.params["friendly_name"]
self.iter_size: int = module.params["iter_size"] or iter_size_default
self.maciter_size: int = module.params["maciter_size"] or 1
self.encryption_level: t.Literal["auto", "compatibility2022"] = module.params[
"encryption_level"
]
self.passphrase = module.params["passphrase"]
self.pkcs12 = None
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
self.pkcs12_bytes = None
self.return_content = module.params["return_content"]
self.src = module.params["src"]
self.pkcs12: PKCS12 | None = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
self.pkcs12_bytes: bytes | None = None
self.return_content: bool = module.params["return_content"]
self.src: str | None = module.params["src"]
if module.params["mode"] is None:
module.params["mode"] = "0400"
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
self.certificate_content: bytes | None = None
if self.certificate_path is not None:
try:
with open(self.certificate_path, "rb") as fh:
self.certificate_content = fh.read()
except (IOError, OSError) as exc:
raise PkcsError(exc)
elif self.certificate_content is not None:
self.certificate_content = to_bytes(self.certificate_content)
elif certificate_content is not None:
self.certificate_content = to_bytes(certificate_content)
self.privatekey_content: bytes | None = None
if self.privatekey_path is not None:
try:
with open(self.privatekey_path, "rb") as fh:
self.privatekey_content = fh.read()
except (IOError, OSError) as exc:
raise PkcsError(exc)
elif self.privatekey_content is not None:
self.privatekey_content = to_bytes(self.privatekey_content)
elif privatekey_content is not None:
self.privatekey_content = to_bytes(privatekey_content)
if self.other_certificates:
if self.other_certificates_str:
if self.other_certificates_parse_all:
filenames = list(self.other_certificates)
self.other_certificates = []
for other_cert_bundle in filenames:
for other_cert_bundle in self.other_certificates_str:
self.other_certificates.extend(
load_certificate_set(other_cert_bundle)
)
else:
self.other_certificates = [
load_certificate(other_cert)
for other_cert in self.other_certificates
for other_cert in self.other_certificates_str
]
elif self.other_certificates_content:
certs = self.other_certificates_content
@@ -430,40 +456,42 @@ class Pkcs(OpenSSLObject):
]
@abc.abstractmethod
def generate_bytes(self, module):
def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive."""
@abc.abstractmethod
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
pass
@abc.abstractmethod
def parse_bytes(self, pkcs12_content):
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
pass
@abc.abstractmethod
def _dump_privatekey(self, pkcs12):
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
pass
@abc.abstractmethod
def _dump_certificate(self, pkcs12):
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
pass
@abc.abstractmethod
def _dump_other_certificates(self, pkcs12):
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
pass
@abc.abstractmethod
def _get_friendly_name(self, pkcs12):
pass
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
state_and_perms = super(Pkcs, self).check(module, perms_required)
def _check_pkey_passphrase():
def _check_pkey_passphrase() -> bool:
if self.privatekey_passphrase:
try:
load_privatekey(
None,
load_certificate_issuer_privatekey(
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
@@ -476,6 +504,7 @@ class Pkcs(OpenSSLObject):
if os.path.exists(self.path) and module.params["action"] == "export":
self.generate_bytes(module) # ignore result
assert self.pkcs12 is not None
self.src = self.path
try:
(
@@ -524,7 +553,7 @@ class Pkcs(OpenSSLObject):
return False
elif (
module.params["action"] == "parse"
and os.path.exists(self.src)
and os.path.exists(self.src or "")
and os.path.exists(self.path)
):
try:
@@ -548,10 +577,10 @@ class Pkcs(OpenSSLObject):
return _check_pkey_passphrase()
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"filename": self.path,
}
if self.privatekey_path:
@@ -567,13 +596,20 @@ class Pkcs(OpenSSLObject):
return result
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
super(Pkcs, self).remove(module)
def parse(self):
def parse(self) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
"""Read PKCS#12 file."""
if self.src is None:
raise AssertionError("Contract violation: src is None")
try:
with open(self.src, "rb") as pkcs12_fh:
@@ -582,10 +618,13 @@ class Pkcs(OpenSSLObject):
except IOError as exc:
raise PkcsError(exc)
def generate(self):
def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass
def write(self, module, content, mode=None):
def write(
self, module: AnsibleModule, content: bytes, mode: int | str | None = None
) -> None:
"""Write the PKCS#12 file."""
if self.backup:
self.backup_file = module.backup_local(self.path)
@@ -595,7 +634,7 @@ class Pkcs(OpenSSLObject):
class PkcsCryptography(Pkcs):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(PkcsCryptography, self).__init__(module, iter_size_default=50000)
if (
self.encryption_level == "compatibility2022"
@@ -607,13 +646,12 @@ class PkcsCryptography(Pkcs):
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
)
def generate_bytes(self, module):
def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive."""
pkey = None
if self.privatekey_content:
try:
pkey = load_privatekey(
None,
pkey = load_certificate_issuer_privatekey(
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
@@ -631,6 +669,7 @@ class PkcsCryptography(Pkcs):
# Store fake object which can be used to retrieve the components back
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
encryption: serialization.KeySerializationEncryption
if not self.passphrase:
encryption = serialization.NoEncryption()
elif self.encryption_level == "compatibility2022":
@@ -654,7 +693,12 @@ class PkcsCryptography(Pkcs):
encryption,
)
def parse_bytes(self, pkcs12_content):
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
try:
private_key, certificate, additional_certificates, friendly_name = (
parse_pkcs12(pkcs12_content, self.passphrase)
@@ -683,11 +727,7 @@ class PkcsCryptography(Pkcs):
except ValueError as exc:
raise PkcsError(exc)
# The following methods will get self.pkcs12 passed, which is computed as:
#
# self.pkcs12 = (pkey, cert, self.other_certificates, self.friendly_name)
def _dump_privatekey(self, pkcs12):
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
return (
pkcs12[0].private_bytes(
encoding=serialization.Encoding.PEM,
@@ -698,27 +738,27 @@ class PkcsCryptography(Pkcs):
else None
)
def _dump_certificate(self, pkcs12):
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
def _dump_other_certificates(self, pkcs12):
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
return [
other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in pkcs12[2]
]
def _get_friendly_name(self, pkcs12):
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[3]
def select_backend(module):
def select_backend(module: AnsibleModule) -> Pkcs:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PkcsCryptography(module)
def main():
def main() -> t.NoReturn:
argument_spec = dict(
action=dict(type="str", default="export", choices=["export", "parse"]),
other_certificates=dict(

View File

@@ -155,6 +155,7 @@ privatekey:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -172,9 +173,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
PrivateKeyBackend,
)
class PrivateKeyModule(OpenSSLObject):
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: PrivateKeyBackend
) -> None:
super(PrivateKeyModule, self).__init__(
module.params["path"],
module.params["state"],
@@ -182,19 +192,19 @@ class PrivateKeyModule(OpenSSLObject):
module.check_mode,
)
self.module_backend = module_backend
self.return_content = module.params["return_content"]
self.return_content: bool = module.params["return_content"]
if self.force:
module_backend.regenerate = "always"
self.backup = module.params["backup"]
self.backup_file = None
self.backup: str | None = module.params["backup"]
self.backup_file: str | None = None
if module.params["mode"] is None:
module.params["mode"] = "0600"
module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate a keypair."""
if self.module_backend.needs_regeneration():
@@ -228,13 +238,13 @@ class PrivateKeyModule(OpenSSLObject):
file_args, self.changed
)
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
self.module_backend.set_existing(None)
if self.backup and not self.check_mode:
self.backup_file = module.backup_local(self.path)
super(PrivateKeyModule, self).remove(module)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump(include_key=self.return_content)
@@ -246,7 +256,7 @@ class PrivateKeyModule(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update(

View File

@@ -60,6 +60,7 @@ backup_file:
"""
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -77,8 +78,17 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey_convert import (
PrivateKeyConvertBackend,
)
class PrivateKeyConvertModule(OpenSSLObject):
def __init__(self, module, module_backend):
def __init__(
self, module: AnsibleModule, module_backend: PrivateKeyConvertBackend
) -> None:
super(PrivateKeyConvertModule, self).__init__(
module.params["dest_path"],
"present",
@@ -87,8 +97,8 @@ class PrivateKeyConvertModule(OpenSSLObject):
)
self.module_backend = module_backend
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
module.params["path"] = module.params["dest_path"]
if module.params["mode"] is None:
@@ -96,12 +106,14 @@ class PrivateKeyConvertModule(OpenSSLObject):
module_backend.set_existing_destination(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Do conversion."""
if self.module_backend.needs_conversion():
# Convert
privatekey_data = self.module_backend.get_private_key_data()
if privatekey_data is None:
raise AssertionError("Contract violation: privatekey_data is None")
if not self.check_mode:
if self.backup:
self.backup_file = module.backup_local(self.path)
@@ -116,7 +128,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
file_args, self.changed
)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = self.module_backend.dump()
@@ -127,7 +139,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update(

View File

@@ -200,6 +200,7 @@ private_data:
type: dict
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -212,7 +213,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -243,7 +244,7 @@ def main():
data = f.read()
except (IOError, OSError) as e:
module.fail_json(
msg=f"Error while reading private key file from disk: {e}", **result
msg=f"Error while reading private key file from disk: {e}", **result # type: ignore
)
result["can_load_key"] = True
@@ -261,10 +262,10 @@ def main():
module.exit_json(**result)
except PrivateKeyParseError as exc:
result.update(exc.result)
module.fail_json(msg=exc.error_message, **result)
module.fail_json(msg=exc.error_message, **result) # type: ignore
except PrivateKeyConsistencyError as exc:
result.update(exc.result)
module.fail_json(msg=exc.error_message, **result)
module.fail_json(msg=exc.error_message, **result) # type: ignore
except OpenSSLObjectError as exc:
module.fail_json(msg=str(exc))

View File

@@ -186,6 +186,7 @@ publickey:
"""
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -218,6 +219,12 @@ try:
except ImportError:
pass
if t.TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)
class PublicKeyError(OpenSSLObjectError):
pass
@@ -225,7 +232,7 @@ class PublicKeyError(OpenSSLObjectError):
class PublicKey(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(PublicKey, self).__init__(
module.params["path"],
module.params["state"],
@@ -233,27 +240,29 @@ class PublicKey(OpenSSLObject):
module.check_mode,
)
self.module = module
self.format = module.params["format"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.privatekey = None
self.publickey_bytes = None
self.return_content = module.params["return_content"]
self.fingerprint = {}
self.format: t.Literal["OpenSSH", "PEM"] = module.params["format"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.privatekey: PrivateKeyTypes | None = None
self.publickey_bytes: bytes | None = None
self.return_content: bool = module.params["return_content"]
self.fingerprint: dict[str, str] = {}
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
result = dict(can_parse_key=False)
return {}
result = {"can_parse_key": False}
try:
result.update(
get_publickey_info(
@@ -267,7 +276,7 @@ class PublicKey(OpenSSLObject):
pass
return result
def _create_publickey(self, module):
def _create_publickey(self, module: AnsibleModule) -> bytes:
self.privatekey = load_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
@@ -284,10 +293,12 @@ class PublicKey(OpenSSLObject):
crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
)
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Generate the public key."""
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
raise PublicKeyError(
f"The private key {self.privatekey_path} does not exist"
)
@@ -320,17 +331,18 @@ class PublicKey(OpenSSLObject):
elif module.set_fs_attributes_if_different(file_args, False):
self.changed = True
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
state_and_perms = super(PublicKey, self).check(module, perms_required)
def _check_privatekey():
if self.privatekey_content is None and not os.path.exists(
def _check_privatekey() -> bool:
if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
return False
current_publickey: PublicKeyTypes
try:
with open(self.path, "rb") as public_key_fh:
publickey_content = public_key_fh.read()
@@ -369,15 +381,15 @@ class PublicKey(OpenSSLObject):
return _check_privatekey()
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
super(PublicKey, self).remove(module)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = {
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"filename": self.path,
"format": self.format,
@@ -403,7 +415,7 @@ class PublicKey(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(

View File

@@ -152,6 +152,7 @@ public_data:
returned: When RV(type=DSA) or RV(type=ECC)
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -163,7 +164,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -191,7 +192,7 @@ def main():
data = f.read()
except (IOError, OSError) as e:
module.fail_json(
msg=f"Error while reading public key file from disk: {e}", **result
msg=f"Error while reading public key file from disk: {e}", **result # type: ignore
)
module_backend = select_backend(module, data)
@@ -201,7 +202,7 @@ def main():
module.exit_json(**result)
except PublicKeyParseError as exc:
result.update(exc.result)
module.fail_json(msg=exc.error_message, **result)
module.fail_json(msg=exc.error_message, **result) # type: ignore
except OpenSSLObjectError as exc:
module.fail_json(msg=str(exc))

View File

@@ -99,6 +99,7 @@ signature:
import base64
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -132,7 +133,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
class SignatureBase(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureBase, self).__init__(
path=module.params["path"],
state="present",
@@ -140,32 +141,35 @@ class SignatureBase(OpenSSLObject):
check_mode=module.check_mode,
)
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.module = module
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
def generate(self):
def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass
def dump(self):
def dump(self) -> dict[str, t.Any]:
# Empty method because OpenSSLObject wants this
pass
return {}
# Implementation with using cryptography
class SignatureCryptography(SignatureBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureCryptography, self).__init__(module)
def run(self):
def run(self) -> dict[str, t.Any]:
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
_hash = cryptography.hazmat.primitives.hashes.SHA256()
result = dict()
result: dict[str, t.Any] = {}
try:
with open(self.path, "rb") as f:
@@ -223,7 +227,7 @@ class SignatureCryptography(SignatureBase):
raise OpenSSLObjectError(e)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
privatekey_path=dict(type="path"),

View File

@@ -88,6 +88,7 @@ valid:
import base64
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -121,7 +122,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
class SignatureInfoBase(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureInfoBase, self).__init__(
path=module.params["path"],
state="present",
@@ -129,32 +130,35 @@ class SignatureInfoBase(OpenSSLObject):
check_mode=module.check_mode,
)
self.signature = module.params["signature"]
self.certificate_path = module.params["certificate_path"]
self.certificate_content = module.params["certificate_content"]
if self.certificate_content is not None:
self.certificate_content = self.certificate_content.encode("utf-8")
self.module = module
self.signature: str = module.params["signature"]
self.certificate_path: str | None = module.params["certificate_path"]
certificate_content: str | None = module.params["certificate_content"]
if certificate_content is not None:
self.certificate_content: bytes | None = certificate_content.encode("utf-8")
else:
self.certificate_content = None
def generate(self):
def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass
def dump(self):
def dump(self) -> dict[str, t.Any]:
# Empty method because OpenSSLObject wants this
pass
return {}
# Implementation with using cryptography
class SignatureInfoCryptography(SignatureInfoBase):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(SignatureInfoCryptography, self).__init__(module)
def run(self):
def run(self) -> dict[str, t.Any]:
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
_hash = cryptography.hazmat.primitives.hashes.SHA256()
result = dict()
result: dict[str, t.Any] = {}
try:
with open(self.path, "rb") as f:
@@ -228,7 +232,7 @@ class SignatureInfoCryptography(SignatureInfoBase):
raise OpenSSLObjectError(e)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
certificate_path=dict(type="path"),

View File

@@ -224,6 +224,7 @@ certificate:
import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
@@ -257,8 +258,15 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
CertificateBackend,
)
class CertificateAbsent(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateAbsent, self).__init__(
module.params["path"],
module.params["state"],
@@ -266,19 +274,19 @@ class CertificateAbsent(OpenSSLObject):
module.check_mode,
)
self.module = module
self.return_content = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.return_content: bool = module.params["return_content"]
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
pass
def remove(self, module):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = module.backup_local(self.path)
super(CertificateAbsent, self).remove(module)
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = {
"changed": self.changed,
"filename": self.path,
@@ -296,7 +304,7 @@ class CertificateAbsent(OpenSSLObject):
class GenericCertificate(OpenSSLObject):
"""Retrieve a certificate using the given module backend."""
def __init__(self, module, module_backend):
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
super(GenericCertificate, self).__init__(
module.params["path"],
module.params["state"],
@@ -311,7 +319,7 @@ class GenericCertificate(OpenSSLObject):
self.module_backend = module_backend
self.module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
if self.module_backend.needs_regeneration():
if not self.check_mode:
self.module_backend.generate_certificate()
@@ -329,14 +337,14 @@ class GenericCertificate(OpenSSLObject):
file_args, self.changed
)
def check(self, module, perms_required=True):
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
return (
super(GenericCertificate, self).check(module, perms_required)
and not self.module_backend.needs_regeneration()
)
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = self.module_backend.dump(include_certificate=self.return_content)
result.update(
{
@@ -349,7 +357,7 @@ class GenericCertificate(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_certificate_argument_spec()
add_acme_provider_to_argument_spec(argument_spec)
add_entrust_provider_to_argument_spec(argument_spec)
@@ -363,13 +371,14 @@ def main():
return_content=dict(type="bool", default=False),
)
)
argument_spec.required_if.append(["state", "present", ["provider"]])
argument_spec.required_if.append(("state", "present", ["provider"]))
module = argument_spec.create_ansible_module(
add_file_common_args=True,
supports_check_mode=True,
)
try:
certificate: GenericCertificate | CertificateAbsent
if module.params["state"] == "absent":
certificate = CertificateAbsent(module)
@@ -389,7 +398,13 @@ def main():
)
provider = module.params["provider"]
provider_map = {
provider_map: dict[
str,
type[AcmeCertificateProvider]
| type[EntrustCertificateProvider]
| type[OwnCACertificateProvider]
| type[SelfSignedCertificateProvider],
] = {
"acme": AcmeCertificateProvider,
"entrust": EntrustCertificateProvider,
"ownca": OwnCACertificateProvider,

View File

@@ -106,6 +106,7 @@ backup_file:
import base64
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -142,8 +143,12 @@ except ImportError:
pass
def parse_certificate(input, strict=False):
input_format = "pem" if identify_pem_format(input) else "der"
def parse_certificate(
input: bytes, strict: bool = False
) -> tuple[bytes, t.Literal["pem", "der"], str | None]:
input_format: t.Literal["pem", "der"] = (
"pem" if identify_pem_format(input) else "der"
)
if input_format == "pem":
pems = split_pem_list(to_text(input))
if len(pems) > 1 and strict:
@@ -162,7 +167,7 @@ def parse_certificate(input, strict=False):
class X509CertificateConvertModule(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(X509CertificateConvertModule, self).__init__(
module.params["dest_path"],
"present",
@@ -170,9 +175,9 @@ class X509CertificateConvertModule(OpenSSLObject):
module.check_mode,
)
self.src_path = module.params["src_path"]
self.src_content = module.params["src_content"]
self.src_content_base64 = module.params["src_content_base64"]
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
self.src_content_base64: bool = module.params["src_content_base64"]
if self.src_content is not None:
self.input = to_bytes(self.src_content)
if self.src_content_base64:
@@ -181,6 +186,8 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception as exc:
module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}")
else:
if self.src_path is None:
module.fail_json(msg="One of src_path and src_content must be provided")
try:
with open(self.src_path, "rb") as f:
self.input = f.read()
@@ -189,8 +196,8 @@ class X509CertificateConvertModule(OpenSSLObject):
msg=f"Failure while reading file {self.src_path}: {exc}"
)
self.format = module.params["format"]
self.strict = module.params["strict"]
self.format: t.Literal["pem", "der"] = module.params["format"]
self.strict: bool = module.params["strict"]
self.wanted_pem_type = "CERTIFICATE"
try:
@@ -203,8 +210,8 @@ class X509CertificateConvertModule(OpenSSLObject):
if module.params["verify_cert_parsable"]:
self.verify_cert_parsable(module)
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
module.params["path"] = self.path
@@ -221,7 +228,7 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception:
pass
def verify_cert_parsable(self, module):
def verify_cert_parsable(self, module: AnsibleModule) -> None:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
@@ -230,7 +237,7 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception as exc:
module.fail_json(msg=f"Error while parsing certificate: {exc}")
def needs_conversion(self):
def needs_conversion(self) -> bool:
if self.dest_content is None or self.dest_content_format is None:
return True
if self.dest_content_format != self.format:
@@ -241,7 +248,7 @@ class X509CertificateConvertModule(OpenSSLObject):
return True
return False
def get_dest_certificate(self):
def get_dest_certificate(self) -> bytes:
if self.format == "der":
return self.input
data = to_bytes(base64.b64encode(self.input))
@@ -250,7 +257,7 @@ class X509CertificateConvertModule(OpenSSLObject):
lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n"))
return b"\n".join(lines)
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
"""Do conversion."""
if self.needs_conversion():
# Convert
@@ -269,18 +276,18 @@ class X509CertificateConvertModule(OpenSSLObject):
file_args, self.changed
)
def dump(self):
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result = dict(
changed=self.changed,
)
result: dict[str, t.Any] = {
"changed": self.changed,
}
if self.backup_file:
result["backup_file"] = self.backup_file
return result
def main():
def main() -> t.NoReturn:
argument_spec = dict(
src_path=dict(type="path"),
src_content=dict(type="str"),

View File

@@ -390,8 +390,10 @@ issuer_uri:
version_added: 2.9.0
"""
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -406,7 +408,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -424,18 +426,22 @@ def main():
supports_check_mode=True,
)
if module.params["content"] is not None:
data = module.params["content"].encode("utf-8")
content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is not None:
data = content.encode("utf-8")
else:
if path is None:
module.fail_json(msg="One of path and content must be provided")
try:
with open(module.params["path"], "rb") as f:
with open(path, "rb") as f:
data = f.read()
except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading certificate file from disk: {e}")
module_backend = select_backend(module, data)
valid_at = module.params["valid_at"]
valid_at: dict[str, t.Any] = module.params["valid_at"]
if valid_at:
for k, v in valid_at.items():
if not isinstance(v, (str, bytes)):
@@ -443,13 +449,11 @@ def main():
msg=f"The value for valid_at.{k} must be of type string (got {type(v)})"
)
valid_at[k] = get_relative_time_option(
v, f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
to_text(v), f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
)
try:
result = module_backend.get_info(
der_support_enabled=module.params["content"] is None
)
result = module_backend.get_info(der_support_enabled=content is None)
not_before = module_backend.get_not_before()
not_after = module_backend.get_not_after()

View File

@@ -118,6 +118,8 @@ certificate:
type: str
"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError,
)
@@ -139,23 +141,31 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
CertificateBackend,
)
class GenericCertificate:
"""Retrieve a certificate using the given module backend."""
def __init__(self, module, module_backend):
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
self.check_mode = module.check_mode
self.module = module
self.module_backend = module_backend
self.changed = False
if module.params["content"] is not None:
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
content: str | None = module.params["content"]
if content is not None:
self.module_backend.set_existing(content.encode("utf-8"))
def generate(self, module):
def generate(self, module: AnsibleModule) -> None:
if self.module_backend.needs_regeneration():
self.module_backend.generate_certificate()
self.changed = True
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = self.module_backend.dump(include_certificate=True)
result.update(
{
@@ -165,7 +175,7 @@ class GenericCertificate:
return result
def main():
def main() -> t.NoReturn:
argument_spec = get_certificate_argument_spec()
argument_spec.argument_spec["provider"]["required"] = True
add_entrust_provider_to_argument_spec(argument_spec)
@@ -182,7 +192,12 @@ def main():
try:
provider = module.params["provider"]
provider_map = {
provider_map: dict[
str,
type[EntrustCertificateProvider]
| type[OwnCACertificateProvider]
| type[SelfSignedCertificateProvider],
] = {
"entrust": EntrustCertificateProvider,
"ownca": OwnCACertificateProvider,
"selfsigned": SelfSignedCertificateProvider,

View File

@@ -425,6 +425,7 @@ crl:
import base64
import os
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text
@@ -463,7 +464,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
OpenSSLObject,
load_certificate,
load_privatekey,
load_certificate_issuer_privatekey,
parse_name_field,
parse_ordered_name_field,
select_message_digest,
@@ -495,6 +496,9 @@ try:
except ImportError:
pass
if t.TYPE_CHECKING:
import datetime
class CRLError(OpenSSLObjectError):
pass
@@ -502,7 +506,7 @@ class CRLError(OpenSSLObjectError):
class CRL(OpenSSLObject):
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
super(CRL, self).__init__(
module.params["path"],
module.params["state"],
@@ -510,53 +514,69 @@ class CRL(OpenSSLObject):
module.check_mode,
)
self.format = module.params["format"]
self.format: t.Literal["pem", "der"] = module.params["format"]
self.update = module.params["crl_mode"] == "update"
self.ignore_timestamps = module.params["ignore_timestamps"]
self.return_content = module.params["return_content"]
self.name_encoding = module.params["name_encoding"]
self.serial_numbers_format = module.params["serial_numbers"]
self.crl_content = None
self.update: bool = module.params["crl_mode"] == "update"
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
self.return_content: bool = module.params["return_content"]
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params[
"name_encoding"
]
self.serial_numbers_format: t.Literal["integer", "hex-octets"] = module.params[
"serial_numbers"
]
self.crl_content: bytes | None = None
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
try:
if module.params["issuer_ordered"]:
issuer_ordered: list[dict[str, t.Any]] | None = module.params[
"issuer_ordered"
]
issuer: dict[str, list[str | bytes] | bytes | str] | None = module.params[
"issuer"
]
if issuer_ordered:
self.issuer_ordered = True
self.issuer = parse_ordered_name_field(
module.params["issuer_ordered"], "issuer_ordered"
)
self.issuer = parse_ordered_name_field(issuer_ordered, "issuer_ordered")
else:
self.issuer_ordered = False
self.issuer = parse_name_field(module.params["issuer"], "issuer")
self.issuer = (
parse_name_field(issuer, "issuer") if issuer is not None else []
)
except (TypeError, ValueError) as exc:
module.fail_json(msg=str(exc))
self.last_update = get_relative_time_option(
self.last_update: datetime.datetime = get_relative_time_option(
module.params["last_update"],
"last_update",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.next_update = get_relative_time_option(
self.next_update: datetime.datetime | None = get_relative_time_option(
module.params["next_update"],
"next_update",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["digest"])
if self.digest is None:
digest = select_message_digest(module.params["digest"])
if digest is None:
raise CRLError(f'The digest "{module.params["digest"]}" is not supported')
self.digest = digest
self.module = module
self.revoked_certificates = []
for i, rc in enumerate(module.params["revoked_certificates"]):
result = {
revoked_certificates: list[dict[str, t.Any]] = module.params[
"revoked_certificates"
]
for i, rc in enumerate(revoked_certificates):
result: dict[str, t.Any] = {
"serial_number": None,
"revocation_date": None,
"issuer": None,
@@ -567,21 +587,25 @@ class CRL(OpenSSLObject):
"invalidity_date_critical": False,
}
path_prefix = f"revoked_certificates[{i}]."
if rc["path"] is not None or rc["content"] is not None:
path: str | None = rc["path"]
content_str: str | None = rc["content"]
if path is not None or content_str is not None:
# Load certificate from file or content
try:
if rc["content"] is not None:
rc["content"] = rc["content"].encode("utf-8")
cert = load_certificate(rc["path"], content=rc["content"])
content: bytes | None = None
if content_str is not None:
content = content_str.encode("utf-8")
rc["content"] = content
cert = load_certificate(path, content=content)
result["serial_number"] = cert.serial_number
except OpenSSLObjectError as e:
if rc["content"] is not None:
if content_str is not None:
module.fail_json(
msg=f"Cannot parse certificate from {path_prefix}content: {e}"
)
else:
module.fail_json(
msg=f'Cannot read certificate "{rc["path"]}" from {path_prefix}path: {e}'
msg=f'Cannot read certificate "{path}" from {path_prefix}path: {e}'
)
else:
# Specify serial_number (and potentially issuer) directly
@@ -611,11 +635,11 @@ class CRL(OpenSSLObject):
result["invalidity_date_critical"] = rc["invalidity_date_critical"]
self.revoked_certificates.append(result)
self.backup = module.params["backup"]
self.backup_file = None
self.backup: bool = module.params["backup"]
self.backup_file: str | None = None
try:
self.privatekey = load_privatekey(
self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
@@ -643,7 +667,7 @@ class CRL(OpenSSLObject):
self.diff_after = self.diff_before = self._get_info(data)
def _parse_serial_number(self, value, index):
def _parse_serial_number(self, value: t.Any, index: int) -> int:
if self.serial_numbers_format == "integer":
try:
return check_type_int(value)
@@ -662,22 +686,42 @@ class CRL(OpenSSLObject):
f"Unexpected value {self.serial_numbers_format} of serial_numbers"
)
def _get_info(self, data):
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return dict()
return {}
try:
result = get_crl_info(self.module, data)
result["can_parse_crl"] = True
return result
except Exception:
return dict(can_parse_crl=False)
return {"can_parse_crl": False}
def remove(self):
def remove(self, module: AnsibleModule) -> None:
if self.backup:
self.backup_file = self.module.backup_local(self.path)
super(CRL, self).remove(self.module)
def _compress_entry(self, entry):
def _compress_entry(self, entry: dict[str, t.Any]) -> (
tuple[
int | None,
tuple[str, ...] | None,
bool,
int | None,
bool,
datetime.datetime | None,
bool,
]
| tuple[
int | None,
datetime.datetime | None,
tuple[str, ...] | None,
bool,
int | None,
bool,
datetime.datetime | None,
bool,
]
):
issuer = None
if entry["issuer"] is not None:
# Normalize to IDNA. If this is used-provided, it was already converted to
@@ -713,7 +757,12 @@ class CRL(OpenSSLObject):
entry["invalidity_date_critical"],
)
def check(self, module, perms_required=True, ignore_conversion=True):
def check(
self,
module: AnsibleModule,
perms_required: bool = True,
ignore_conversion: bool = True,
) -> bool:
"""Ensure the resource is in its desired state."""
state_and_perms = super(CRL, self).check(self.module, perms_required)
@@ -743,10 +792,13 @@ class CRL(OpenSSLObject):
]
is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer]
if not self.issuer_ordered:
want_issuer = set(want_issuer)
is_issuer = set(is_issuer)
if want_issuer != is_issuer:
return False
want_issuer_set = set(want_issuer)
is_issuer_set = set(is_issuer)
if want_issuer_set != is_issuer_set:
return False
else:
if want_issuer != is_issuer:
return False
old_entries = [
self._compress_entry(cryptography_decode_revoked_certificate(cert))
@@ -769,7 +821,7 @@ class CRL(OpenSSLObject):
return True
def _generate_crl(self):
def _generate_crl(self) -> bytes:
crl = CertificateRevocationListBuilder()
try:
@@ -787,7 +839,8 @@ class CRL(OpenSSLObject):
raise CRLError(e)
crl = set_last_update(crl, self.last_update)
crl = set_next_update(crl, self.next_update)
if self.next_update is not None:
crl = set_next_update(crl, self.next_update)
if self.update and self.crl:
new_entries = set(
@@ -799,22 +852,26 @@ class CRL(OpenSSLObject):
)
if decoded_entry not in new_entries:
crl = crl.add_revoked_certificate(entry)
for entry in self.revoked_certificates:
for revoked_entry in self.revoked_certificates:
revoked_cert = RevokedCertificateBuilder()
revoked_cert = revoked_cert.serial_number(entry["serial_number"])
revoked_cert = set_revocation_date(revoked_cert, entry["revocation_date"])
if entry["issuer"] is not None:
revoked_cert = revoked_cert.serial_number(revoked_entry["serial_number"])
revoked_cert = set_revocation_date(
revoked_cert, revoked_entry["revocation_date"]
)
if revoked_entry["issuer"] is not None:
revoked_cert = revoked_cert.add_extension(
x509.CertificateIssuer(entry["issuer"]), entry["issuer_critical"]
x509.CertificateIssuer(revoked_entry["issuer"]),
revoked_entry["issuer_critical"],
)
if entry["reason"] is not None:
if revoked_entry["reason"] is not None:
revoked_cert = revoked_cert.add_extension(
x509.CRLReason(entry["reason"]), entry["reason_critical"]
x509.CRLReason(revoked_entry["reason"]),
revoked_entry["reason_critical"],
)
if entry["invalidity_date"] is not None:
if revoked_entry["invalidity_date"] is not None:
revoked_cert = revoked_cert.add_extension(
x509.InvalidityDate(entry["invalidity_date"]),
entry["invalidity_date_critical"],
x509.InvalidityDate(revoked_entry["invalidity_date"]),
revoked_entry["invalidity_date_critical"],
)
crl = crl.add_revoked_certificate(revoked_cert.build())
@@ -827,7 +884,7 @@ class CRL(OpenSSLObject):
else:
return self.crl.public_bytes(Encoding.DER)
def generate(self):
def generate(self, module: AnsibleModule) -> None:
result = None
if (
not self.check(self.module, perms_required=False, ignore_conversion=True)
@@ -861,7 +918,7 @@ class CRL(OpenSSLObject):
elif self.module.set_fs_attributes_if_different(file_args, False):
self.changed = True
def dump(self, check_mode=False):
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = {
"changed": self.changed,
"filename": self.path,
@@ -879,37 +936,53 @@ class CRL(OpenSSLObject):
if check_mode:
result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = self.next_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = (
self.next_update.strftime(TIMESTAMP_FORMAT)
if self.next_update is not None
else None
)
# result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid)
result["digest"] = self.module.params["digest"]
result["issuer_ordered"] = self.issuer
result["issuer"] = {}
issuer: dict[str, str | bytes] = {}
result["issuer"] = issuer
for k, v in self.issuer:
result["issuer"][k] = v
result["revoked_certificates"] = []
issuer[k] = v
revoked_certificates: list[dict[str, t.Any]] = []
result["revoked_certificates"] = revoked_certificates
for entry in self.revoked_certificates:
result["revoked_certificates"].append(
revoked_certificates.append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
)
elif self.crl:
result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT)
result["next_update"] = get_next_update(self.crl).strftime(TIMESTAMP_FORMAT)
next_update = get_next_update(self.crl)
result["next_update"] = (
next_update.strftime(TIMESTAMP_FORMAT)
if next_update is not None
else None
)
result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
)
issuer = []
issuer_list: list[list[str]] = []
for attribute in self.crl.issuer:
issuer.append(
[cryptography_oid_to_name(attribute.oid), attribute.value]
issuer_list.append(
[
cryptography_oid_to_name(attribute.oid),
to_text(attribute.value),
]
)
result["issuer_ordered"] = issuer
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
result["revoked_certificates"] = []
result["issuer_ordered"] = issuer_list
issuer = {}
result["issuer"] = issuer
for k, v in issuer_list:
issuer[k] = v
revoked_certificates = []
result["revoked_certificates"] = revoked_certificates
for cert in self.crl:
entry = cryptography_decode_revoked_certificate(cert)
result["revoked_certificates"].append(
revoked_certificates.append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
)
@@ -923,7 +996,7 @@ class CRL(OpenSSLObject):
return result
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
state=dict(type="str", default="present", choices=["present", "absent"]),
@@ -1015,14 +1088,14 @@ def main():
)
module.exit_json(**result)
crl.generate()
crl.generate(module)
else:
if module.check_mode:
result = crl.dump(check_mode=True)
result["changed"] = os.path.exists(module.params["path"])
module.exit_json(**result)
crl.remove()
crl.remove(module)
result = crl.dump()
module.exit_json(**result)

View File

@@ -100,7 +100,9 @@ last_update:
type: str
sample: '20190413202428Z'
next_update:
description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
description:
- The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
- Will be C(none) if no such timestamp is present.
returned: success
type: str
sample: '20190413202428Z'
@@ -172,6 +174,7 @@ revoked_certificates:
import base64
import binascii
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -185,7 +188,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
)
def main():
def main() -> t.NoReturn:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path"),
@@ -200,25 +203,30 @@ def main():
supports_check_mode=True,
)
if module.params["content"] is None:
content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is None:
if path is None:
module.fail_json(msg="One of content and path must be provided")
try:
with open(module.params["path"], "rb") as f:
with open(path, "rb") as f:
data = f.read()
except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading CRL file from disk: {e}")
else:
data = module.params["content"].encode("utf-8")
data = content.encode("utf-8")
if not identify_pem_format(data):
try:
data = base64.b64decode(module.params["content"])
data = base64.b64decode(content)
except (binascii.Error, TypeError) as e:
module.fail_json(msg=f"Error while Base64 decoding content: {e}")
list_revoked_certificates: bool = module.params["list_revoked_certificates"]
try:
result = get_crl_info(
module,
data,
list_revoked_certificates=module.params["list_revoked_certificates"],
list_revoked_certificates=list_revoked_certificates,
)
module.exit_json(**result)
except OpenSSLObjectError as e:

View File

@@ -15,20 +15,24 @@ from __future__ import annotations
import abc
import copy
import traceback
import typing as t
from ansible.errors import AnsibleError
from ansible.module_utils.basic import SEQUENCETYPE, remove_values
from ansible.module_utils.common._collections_compat import Mapping
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.common.validation import (
safe_eval,
)
from ansible.module_utils.errors import UnsupportedError
from ansible.plugins.action import ActionBase
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
ArgumentSpec,
)
class _ModuleExitException(Exception):
def __init__(self, result):
def __init__(self, result: dict[str, t.Any]) -> None:
super(_ModuleExitException, self).__init__()
self.result = result
@@ -36,20 +40,21 @@ class _ModuleExitException(Exception):
class AnsibleActionModule:
def __init__(
self,
action_plugin,
argument_spec,
bypass_checks=False,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
supports_check_mode=False,
required_if=None,
required_by=None,
):
action_plugin: ActionModuleBase,
argument_spec: dict[str, t.Any],
*,
bypass_checks: bool = False,
supports_check_mode: bool = False,
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
required_if: list[tuple[str, t.Any, list[str] | tuple[str, ...]]] | None = None,
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
) -> None:
# Internal data
self.__action_plugin = action_plugin
self.__warnings = []
self.__deprecations = []
self.__warnings: list[str] = []
self.__deprecations: list[dict[str, str | None]] = []
# AnsibleModule data
self._name = self.__action_plugin._task.action
@@ -67,10 +72,6 @@ class AnsibleActionModule:
self._diff = self.__action_plugin._play_context.diff
self._verbosity = self.__action_plugin._display.verbosity
self.aliases = {}
self._legal_inputs = []
self._options_context = list()
self.params = copy.deepcopy(self.__action_plugin._task.args)
self.no_log_values = set()
self._validator = ArgumentSpecValidator(
@@ -122,38 +123,41 @@ class AnsibleActionModule:
self.fail_json(msg=msg)
def safe_eval(self, value, locals=None, include_exceptions=False):
return safe_eval(value, locals, include_exceptions)
def warn(self, warning):
def warn(self, warning: str) -> None:
# Copied from ansible.module_utils.common.warnings:
if isinstance(warning, (str, bytes)):
if isinstance(warning, str):
self.__warnings.append(warning)
else:
raise TypeError(f"warn requires a string not a {type(warning)}")
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
if version is not None and date is not None:
raise AssertionError(
"implementation error -- version and date must not both be set"
)
# Copied from ansible.module_utils.common.warnings:
if isinstance(msg, (str, bytes)):
# For compatibility, we accept that neither version nor date is set,
# and treat that the same as if version would haven been set
if date is not None:
self.__deprecations.append(
{"msg": msg, "date": date, "collection_name": collection_name}
)
else:
self.__deprecations.append(
{"msg": msg, "version": version, "collection_name": collection_name}
)
else:
if not isinstance(msg, str):
raise TypeError(f"deprecate requires a string not a {type(msg)}")
def _return_formatted(self, kwargs):
# For compatibility, we accept that neither version nor date is set,
# and treat that the same as if version would haven been set
if date is not None:
self.__deprecations.append(
{"msg": msg, "date": date, "collection_name": collection_name}
)
else:
self.__deprecations.append(
{"msg": msg, "version": version, "collection_name": collection_name}
)
def _return_formatted(self, kwargs: dict[str, t.Any]) -> t.NoReturn:
if "invocation" not in kwargs:
kwargs["invocation"] = {"module_args": self.params}
@@ -194,13 +198,13 @@ class AnsibleActionModule:
kwargs = remove_values(kwargs, self.no_log_values)
raise _ModuleExitException(kwargs)
def exit_json(self, **kwargs):
def exit_json(self, **kwargs) -> t.NoReturn:
result = dict(kwargs)
if "failed" not in result:
result["failed"] = False
self._return_formatted(result)
def fail_json(self, msg, **kwargs):
def fail_json(self, msg: str, **kwargs) -> t.NoReturn:
result = dict(kwargs)
result["failed"] = True
result["msg"] = msg
@@ -209,16 +213,15 @@ class AnsibleActionModule:
class ActionModuleBase(ActionBase, metaclass=abc.ABCMeta):
@abc.abstractmethod
def setup_module(self):
def setup_module(self) -> tuple[ArgumentSpec, dict[str, t.Any]]:
"""Return pair (ArgumentSpec, kwargs)."""
pass
@abc.abstractmethod
def run_module(self, module):
def run_module(self, module: AnsibleActionModule) -> None:
"""Run module code"""
module.fail_json(msg="Not implemented.")
def run(self, tmp=None, task_vars=None):
def run(self, tmp=None, task_vars=None) -> dict[str, t.Any]:
if task_vars is None:
task_vars = dict()

View File

@@ -6,14 +6,23 @@
from __future__ import annotations
import typing as t
from ansible.errors import AnsibleFilterError
from ansible.utils.display import Display
_display = Display()
class FilterModuleMock:
def __init__(self, params):
def __init__(self, params: dict[str, t.Any]) -> None:
self.check_mode = True
self.params = params
self._diff = False
def fail_json(self, msg, **kwargs):
def fail_json(self, msg: str, **kwargs) -> t.NoReturn:
raise AnsibleFilterError(msg)
def warn(self, warning: str) -> None:
_display.warning(warning)

View File

@@ -4,6 +4,7 @@
from __future__ import annotations
import typing as t
from subprocess import PIPE, Popen
from ansible.module_utils.common.process import get_bin_path
@@ -15,7 +16,7 @@ from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import
class PluginGPGRunner(GPGRunner):
def __init__(self, executable=None, cwd=None):
def __init__(self, executable: str | None = None, cwd: str | None = None) -> None:
if executable is None:
try:
executable = get_bin_path("gpg")
@@ -24,7 +25,9 @@ class PluginGPGRunner(GPGRunner):
self.executable = executable
self.cwd = cwd
def run_command(self, command, check_rc=True, data=None):
def run_command(
self, command: list[str], check_rc: bool = True, data: bytes | None = None
) -> tuple[int, str, str]:
"""
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
@@ -41,12 +44,10 @@ class PluginGPGRunner(GPGRunner):
command, shell=False, cwd=self.cwd, stdin=PIPE, stdout=PIPE, stderr=PIPE
)
stdout, stderr = p.communicate(input=data)
stdout = to_native(stdout, errors="surrogate_or_replace")
stderr = to_native(stderr, errors="surrogate_or_replace")
stdout_n = to_native(stdout, errors="surrogate_or_replace")
stderr_n = to_native(stderr, errors="surrogate_or_replace")
if check_rc and p.returncode != 0:
stdout_n = (to_native(stdout, errors="surrogate_or_replace"),)
stderr_n = (to_native(stderr, errors="surrogate_or_replace"),)
raise GPGError(
f'Running {" ".join(command)} yielded return code {p.returncode} with stdout: "{stdout_n}" and stderr: "{stderr_n}")'
)
return p.returncode, stdout, stderr
return t.cast(int, p.returncode), stdout_n, stderr_n