Add type hints and type checking (#885)

* Enable basic type checking.

* Fix first errors.

* Add changelog fragment.

* Add types to module_utils and plugin_utils (without module backends).

* Add typing hints for acme_* modules.

* Add typing to X.509 certificate modules, and add more helpers.

* Add typing to remaining module backends.

* Add typing for action, filter, and lookup plugins.

* Bump ansible-core 2.19 beta requirement for typing.

* Add more typing definitions.

* Add typing to some unit tests.
This commit is contained in:
Felix Fontein
2025-05-11 18:00:11 +02:00
committed by GitHub
parent 82f0176773
commit f758d94fba
124 changed files with 4986 additions and 2662 deletions

View File

@@ -18,7 +18,14 @@ run_yamllint = true
yamllint_config = ".yamllint" yamllint_config = ".yamllint"
yamllint_config_plugins = ".yamllint-docs" yamllint_config_plugins = ".yamllint-docs"
yamllint_config_plugins_examples = ".yamllint-examples" yamllint_config_plugins_examples = ".yamllint-examples"
run_mypy = false run_mypy = true
mypy_ansible_core_package = "ansible-core>=2.19.0b3"
mypy_config = "tests/nox-config-mypy.ini"
mypy_extra_deps = [
"cryptography",
"types-mock",
"types-PyYAML",
]
[sessions.docs_check] [sessions.docs_check]
validate_collection_refs="all" validate_collection_refs="all"

View File

@@ -5,3 +5,4 @@ minor_changes:
- "Python code modernization: remove Python 3 specific code (https://github.com/ansible-collections/community.crypto/pull/877)." - "Python code modernization: remove Python 3 specific code (https://github.com/ansible-collections/community.crypto/pull/877)."
- "Python code modernization: avoid unnecessary string conversion (https://github.com/ansible-collections/community.crypto/pull/880)." - "Python code modernization: avoid unnecessary string conversion (https://github.com/ansible-collections/community.crypto/pull/880)."
- "Python code modernization: avoid using ``six`` (https://github.com/ansible-collections/community.crypto/pull/884)." - "Python code modernization: avoid using ``six`` (https://github.com/ansible-collections/community.crypto/pull/884)."
- "Python code modernization: add type hints and type checking (https://github.com/ansible-collections/community.crypto/pull/885)."

View File

@@ -0,0 +1,5 @@
breaking_changes:
- "The validation for relative timestamps is now more strict. A string starting with ``+`` or ``-`` must be valid,
otherwise validation will fail. In the past such strings were often silently ignored, and in many cases the code
which triggered the validation was not able to handle no result
(https://github.com/ansible-collections/community.crypto/pull/885)."

View File

@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import typing as t
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -19,25 +20,41 @@ from ansible_collections.community.crypto.plugins.plugin_utils.action_module imp
) )
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
PrivateKeyBackend,
)
from ansible_collections.community.crypto.plugins.plugin_utils.action_module import (
AnsibleActionModule,
)
class PrivateKeyModule: class PrivateKeyModule:
def __init__(self, module, module_backend): def __init__(
self, module: AnsibleActionModule, module_backend: PrivateKeyBackend
) -> None:
self.module = module self.module = module
self.module_backend = module_backend self.module_backend = module_backend
self.check_mode = module.check_mode self.check_mode = module.check_mode
self.changed = False self.changed = False
self.return_current_key = module.params["return_current_key"] self.return_current_key: bool = module.params["return_current_key"]
if module.params["content"] is not None: content: str | None = module.params["content"]
if module.params["content_base64"]: content_base64: bool = module.params["content_base64"]
if content is not None:
if content_base64:
try: try:
data = base64.b64decode(module.params["content"]) data = base64.b64decode(content)
except Exception as e: except Exception as e:
module.fail_json(msg=f"Cannot decode Base64 encoded data: {e}") module.fail_json(msg=f"Cannot decode Base64 encoded data: {e}")
else: else:
data = to_bytes(module.params["content"]) data = to_bytes(content)
module_backend.set_existing(data) module_backend.set_existing(data)
def generate(self, module): def generate(self, module: AnsibleActionModule) -> None:
"""Generate a keypair.""" """Generate a keypair."""
if self.module_backend.needs_regeneration(): if self.module_backend.needs_regeneration():
@@ -53,7 +70,7 @@ class PrivateKeyModule:
self.privatekey_bytes = privatekey_data self.privatekey_bytes = privatekey_data
self.changed = True self.changed = True
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = self.module_backend.dump( result = self.module_backend.dump(
include_key=self.changed or self.return_current_key include_key=self.changed or self.return_current_key
@@ -64,7 +81,7 @@ class PrivateKeyModule:
class ActionModule(ActionModuleBase): class ActionModule(ActionModuleBase):
@staticmethod @staticmethod
def setup_module(): def setup_module() -> tuple[ArgumentSpec, dict[str, t.Any]]:
argument_spec = get_privatekey_argument_spec() argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(
@@ -78,7 +95,7 @@ class ActionModule(ActionModuleBase):
) )
@staticmethod @staticmethod
def run_module(module): def run_module(module: AnsibleActionModule) -> None:
module_backend = select_backend(module=module) module_backend = select_backend(module=module)
try: try:

View File

@@ -39,6 +39,8 @@ _value:
type: string type: string
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import ( from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import (
@@ -50,7 +52,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import (
) )
def gpg_fingerprint(input): def gpg_fingerprint(input: str | bytes) -> str:
if not isinstance(input, (str, bytes)): if not isinstance(input, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
f"The input for the community.crypto.gpg_fingerprint filter must be a string; got {type(input)} instead" f"The input for the community.crypto.gpg_fingerprint filter must be a string; got {type(input)} instead"
@@ -65,7 +67,7 @@ def gpg_fingerprint(input):
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"gpg_fingerprint": gpg_fingerprint, "gpg_fingerprint": gpg_fingerprint,
} }

View File

@@ -274,6 +274,8 @@ _value:
sample: 12345 sample: 12345
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -287,7 +289,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
) )
def openssl_csr_info_filter(data, name_encoding="ignore"): def openssl_csr_info_filter(
data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore"
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate.""" """Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)): if not isinstance(data, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
@@ -313,7 +317,7 @@ def openssl_csr_info_filter(data, name_encoding="ignore"):
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"openssl_csr_info": openssl_csr_info_filter, "openssl_csr_info": openssl_csr_info_filter,
} }

View File

@@ -146,8 +146,10 @@ _value:
type: dict type: dict
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
) )
@@ -161,8 +163,10 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
def openssl_privatekey_info_filter( def openssl_privatekey_info_filter(
data, passphrase=None, return_private_key_data=False data: str | bytes,
): passphrase: str | bytes | None = None,
return_private_key_data: bool = False,
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate.""" """Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)): if not isinstance(data, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
@@ -182,7 +186,7 @@ def openssl_privatekey_info_filter(
result = get_privatekey_info( result = get_privatekey_info(
module, module,
content=to_bytes(data), content=to_bytes(data),
passphrase=passphrase, passphrase=to_text(passphrase) if passphrase is not None else None,
return_private_key_data=return_private_key_data, return_private_key_data=return_private_key_data,
) )
result.pop("can_parse_key", None) result.pop("can_parse_key", None)
@@ -197,7 +201,7 @@ def openssl_privatekey_info_filter(
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"openssl_privatekey_info": openssl_privatekey_info_filter, "openssl_privatekey_info": openssl_privatekey_info_filter,
} }

View File

@@ -123,6 +123,8 @@ _value:
returned: When RV(_value.type=DSA) or RV(_value.type=ECC) returned: When RV(_value.type=DSA) or RV(_value.type=ECC)
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -137,7 +139,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
) )
def openssl_publickey_info_filter(data): def openssl_publickey_info_filter(data: str | bytes) -> dict[str, t.Any]:
"""Extract information from OpenSSL PEM public key.""" """Extract information from OpenSSL PEM public key."""
if not isinstance(data, (str, bytes)): if not isinstance(data, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
@@ -156,7 +158,7 @@ def openssl_publickey_info_filter(data):
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"openssl_publickey_info": openssl_publickey_info_filter, "openssl_publickey_info": openssl_publickey_info_filter,
} }

View File

@@ -39,6 +39,8 @@ _value:
type: int type: int
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils.serial import ( from ansible_collections.community.crypto.plugins.module_utils.serial import (
@@ -46,7 +48,7 @@ from ansible_collections.community.crypto.plugins.module_utils.serial import (
) )
def parse_serial_filter(input): def parse_serial_filter(input: str | bytes) -> int:
if not isinstance(input, (str, bytes)): if not isinstance(input, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
f"The input for the community.crypto.parse_serial filter must be a string; got {type(input)} instead" f"The input for the community.crypto.parse_serial filter must be a string; got {type(input)} instead"
@@ -60,7 +62,7 @@ def parse_serial_filter(input):
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"parse_serial": parse_serial_filter, "parse_serial": parse_serial_filter,
} }

View File

@@ -38,6 +38,8 @@ _value:
elements: string elements: string
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
@@ -45,21 +47,20 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
) )
def split_pem_filter(data): def split_pem_filter(data: str | bytes) -> list[str]:
"""Split PEM file.""" """Split PEM file."""
if not isinstance(data, (str, bytes)): if not isinstance(data, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
f"The community.crypto.split_pem input must be a text type, not {type(data)}" f"The community.crypto.split_pem input must be a text type, not {type(data)}"
) )
data = to_text(data) return split_pem_list(to_text(data))
return split_pem_list(data)
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"split_pem": split_pem_filter, "split_pem": split_pem_filter,
} }

View File

@@ -39,11 +39,13 @@ _value:
type: string type: string
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible_collections.community.crypto.plugins.module_utils.serial import to_serial from ansible_collections.community.crypto.plugins.module_utils.serial import to_serial
def to_serial_filter(input): def to_serial_filter(input: int) -> str:
if not isinstance(input, int): if not isinstance(input, int):
raise AnsibleFilterError( raise AnsibleFilterError(
f"The input for the community.crypto.to_serial filter must be an integer; got {type(input)} instead" f"The input for the community.crypto.to_serial filter must be an integer; got {type(input)} instead"
@@ -61,7 +63,7 @@ def to_serial_filter(input):
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"to_serial": to_serial_filter, "to_serial": to_serial_filter,
} }

View File

@@ -308,6 +308,8 @@ _value:
type: str type: str
""" """
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -321,7 +323,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
) )
def x509_certificate_info_filter(data, name_encoding="ignore"): def x509_certificate_info_filter(
data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore"
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate.""" """Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)): if not isinstance(data, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
@@ -347,7 +351,7 @@ def x509_certificate_info_filter(data, name_encoding="ignore"):
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"x509_certificate_info": x509_certificate_info_filter, "x509_certificate_info": x509_certificate_info_filter,
} }

View File

@@ -155,6 +155,7 @@ _value:
import base64 import base64
import binascii import binascii
import typing as t
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
@@ -172,7 +173,11 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
) )
def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates=True): def x509_crl_info_filter(
data: str | bytes,
name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore",
list_revoked_certificates: bool = True,
) -> dict[str, t.Any]:
"""Extract information from X.509 PEM certificate.""" """Extract information from X.509 PEM certificate."""
if not isinstance(data, (str, bytes)): if not isinstance(data, (str, bytes)):
raise AnsibleFilterError( raise AnsibleFilterError(
@@ -192,17 +197,19 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates
f'The name_encoding option must be one of the values "ignore", "idna", or "unicode", not "{name_encoding}"' f'The name_encoding option must be one of the values "ignore", "idna", or "unicode", not "{name_encoding}"'
) )
data = to_bytes(data) data_bytes = to_bytes(data)
if not identify_pem_format(data): if not identify_pem_format(data_bytes):
try: try:
data = base64.b64decode(to_native(data)) data_bytes = base64.b64decode(to_native(data_bytes))
except (binascii.Error, TypeError, ValueError, UnicodeEncodeError): except (binascii.Error, TypeError, ValueError, UnicodeEncodeError):
pass pass
module = FilterModuleMock({"name_encoding": name_encoding}) module = FilterModuleMock({"name_encoding": name_encoding})
try: try:
return get_crl_info( return get_crl_info(
module, content=data, list_revoked_certificates=list_revoked_certificates module,
content=data_bytes,
list_revoked_certificates=list_revoked_certificates,
) )
except OpenSSLObjectError as exc: except OpenSSLObjectError as exc:
raise AnsibleFilterError(str(exc)) raise AnsibleFilterError(str(exc))
@@ -211,7 +218,7 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates
class FilterModule: class FilterModule:
"""Ansible jinja2 filters""" """Ansible jinja2 filters"""
def filters(self): def filters(self) -> dict[str, t.Callable]:
return { return {
"x509_crl_info": x509_crl_info_filter, "x509_crl_info": x509_crl_info_filter,
} }

View File

@@ -42,7 +42,11 @@ _value:
elements: string elements: string
""" """
import os
import typing as t
from ansible.errors import AnsibleLookupError from ansible.errors import AnsibleLookupError
from ansible.module_utils.common.text.converters import to_native
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import ( from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import (
GPGError, GPGError,
@@ -54,14 +58,20 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import (
class LookupModule(LookupBase): class LookupModule(LookupBase):
def run(self, terms, variables=None, **kwargs): def run(self, terms: list[t.Any], variables=None, **kwargs) -> list[str]:
self.set_options(direct=kwargs) self.set_options(direct=kwargs)
if self._loader is None:
raise AssertionError("Contract violation: self._loader is None")
try: try:
gpg = PluginGPGRunner(cwd=self._loader.get_basedir()) gpg = PluginGPGRunner(cwd=self._loader.get_basedir())
result = [] result = []
for path in terms: for i, path in enumerate(terms):
result.append(get_fingerprint_from_file(gpg, path)) if not isinstance(path, (str, bytes, os.PathLike)):
raise AnsibleLookupError(
f"Lookup parameter #{i} should be string or a path object, but got {type(path)}"
)
result.append(get_fingerprint_from_file(gpg, to_native(path)))
return result return result
except GPGError as exc: except GPGError as exc:
raise AnsibleLookupError(str(exc)) raise AnsibleLookupError(str(exc))

View File

@@ -5,6 +5,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.module_utils.common._collections_compat import Mapping from ansible.module_utils.common._collections_compat import Mapping
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ACMEProtocolException, ACMEProtocolException,
@@ -12,26 +14,29 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
if t.TYPE_CHECKING:
from .acme import ACMEClient
class ACMEAccount: class ACMEAccount:
""" """
ACME account object. Allows to create new accounts, check for existence of accounts, ACME account object. Allows to create new accounts, check for existence of accounts,
retrieve account data. retrieve account data.
""" """
def __init__(self, client): def __init__(self, client: ACMEClient) -> None:
# Set to true to enable logging of all signed requests # Set to true to enable logging of all signed requests
self._debug = False self._debug: bool = False
self.client = client self.client = client
def _new_reg( def _new_reg(
self, self,
contact=None, contact: list[str] | None = None,
agreement=None, terms_agreed: bool = False,
terms_agreed=False, allow_creation: bool = True,
allow_creation=True, external_account_binding: dict[str, t.Any] | None = None,
external_account_binding=None, ) -> tuple[bool, dict[str, t.Any] | None]:
):
""" """
Registers a new ACME account. Returns a pair ``(created, data)``. Registers a new ACME account. Returns a pair ``(created, data)``.
Here, ``created`` is ``True`` if the account was created and Here, ``created`` is ``True`` if the account was created and
@@ -63,7 +68,7 @@ class ACMEAccount:
return created, data return created, data
# An account does not yet exist. Try to create one next. # An account does not yet exist. Try to create one next.
new_reg = {"contact": contact} new_reg: dict[str, t.Any] = {"contact": contact}
if not allow_creation: if not allow_creation:
# https://tools.ietf.org/html/rfc8555#section-7.3.1 # https://tools.ietf.org/html/rfc8555#section-7.3.1
new_reg["onlyReturnExisting"] = True new_reg["onlyReturnExisting"] = True
@@ -99,7 +104,7 @@ class ACMEAccount:
self.client.module, self.client.module,
msg="Invalid account creation reply from ACME server", msg="Invalid account creation reply from ACME server",
info=info, info=info,
content=result, content_json=result,
) )
if info["status"] == 201: if info["status"] == 201:
@@ -152,7 +157,7 @@ class ACMEAccount:
content_json=result, content_json=result,
) )
def get_account_data(self): def get_account_data(self) -> dict[str, t.Any] | None:
""" """
Retrieve account information. Can only be called when the account Retrieve account information. Can only be called when the account
URI is already known (such as after calling setup_account). URI is already known (such as after calling setup_account).
@@ -161,7 +166,7 @@ class ACMEAccount:
if self.client.account_uri is None: if self.client.account_uri is None:
raise ModuleFailException("Account URI unknown") raise ModuleFailException("Account URI unknown")
# try POST-as-GET first (draft-15 or newer) # try POST-as-GET first (draft-15 or newer)
data = None data: dict[str, t.Any] | None = None
result, info = self.client.send_signed_request( result, info = self.client.send_signed_request(
self.client.account_uri, data, fail_on_error=False self.client.account_uri, data, fail_on_error=False
) )
@@ -180,7 +185,7 @@ class ACMEAccount:
self.client.module, self.client.module,
msg="Invalid account data retrieved from ACME server", msg="Invalid account data retrieved from ACME server",
info=info, info=info,
content=result, content_json=result,
) )
if ( if (
info["status"] in (400, 403) info["status"] in (400, 403)
@@ -203,15 +208,34 @@ class ACMEAccount:
) )
return result return result
@t.overload
def setup_account( def setup_account(
self, self,
contact=None, contact: list[str] | None = None,
agreement=None, terms_agreed: bool = False,
terms_agreed=False, allow_creation: t.Literal[True] = True,
allow_creation=True, remove_account_uri_if_not_exists: bool = False,
remove_account_uri_if_not_exists=False, external_account_binding: dict[str, t.Any] | None = None,
external_account_binding=None, ) -> tuple[bool, dict[str, t.Any]]: ...
):
@t.overload
def setup_account(
self,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
remove_account_uri_if_not_exists: bool = False,
external_account_binding: dict[str, t.Any] | None = None,
) -> tuple[bool, dict[str, t.Any] | None]: ...
def setup_account(
self,
contact: list[str] | None = None,
terms_agreed: bool = False,
allow_creation: bool = True,
remove_account_uri_if_not_exists: bool = False,
external_account_binding: dict[str, t.Any] | None = None,
) -> tuple[bool, dict[str, t.Any] | None]:
""" """
Detect or create an account on the ACME server. For ACME v1, Detect or create an account on the ACME server. For ACME v1,
as the only way (without knowing an account URI) to test if an as the only way (without knowing an account URI) to test if an
@@ -253,7 +277,6 @@ class ACMEAccount:
else: else:
created, account_data = self._new_reg( created, account_data = self._new_reg(
contact, contact,
agreement=agreement,
terms_agreed=terms_agreed, terms_agreed=terms_agreed,
allow_creation=allow_creation and not self.client.module.check_mode, allow_creation=allow_creation and not self.client.module.check_mode,
external_account_binding=external_account_binding, external_account_binding=external_account_binding,
@@ -267,7 +290,9 @@ class ACMEAccount:
account_data = {"contact": contact or []} account_data = {"contact": contact or []}
return created, account_data return created, account_data
def update_account(self, account_data, contact=None): def update_account(
self, account_data: dict[str, t.Any], contact: list[str] | None = None
) -> tuple[bool, dict[str, t.Any]]:
""" """
Update an account on the ACME server. Check mode is fully respected. Update an account on the ACME server. Check mode is fully respected.
@@ -280,8 +305,11 @@ class ACMEAccount:
https://tools.ietf.org/html/rfc8555#section-7.3.2 https://tools.ietf.org/html/rfc8555#section-7.3.2
""" """
if self.client.account_uri is None:
raise ModuleFailException("Cannot update account without account URI")
# Create request # Create request
update_request = {} update_request: dict[str, t.Any] = {}
if contact is not None and account_data.get("contact", []) != contact: if contact is not None and account_data.get("contact", []) != contact:
update_request["contact"] = list(contact) update_request["contact"] = list(contact)
@@ -302,7 +330,7 @@ class ACMEAccount:
self.client.module, self.client.module,
msg="Invalid account updating reply from ACME server", msg="Invalid account updating reply from ACME server",
info=info, info=info,
content=account_data, content_json=account_data,
) )
return True, account_data return True, account_data

View File

@@ -10,6 +10,7 @@ import datetime
import json import json
import locale import locale
import time import time
import typing as t
from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
@@ -41,13 +42,24 @@ from ansible_collections.community.crypto.plugins.module_utils.argspec import (
) )
if t.TYPE_CHECKING:
import os
from ansible.module_utils.basic import AnsibleModule
from .account import ACMEAccount
from .backends import CertificateInformation, CryptoBackend
# -1 usually means connection problems # -1 usually means connection problems
RETRY_STATUS_CODES = (-1, 408, 429, 503) RETRY_STATUS_CODES = (-1, 408, 429, 503)
RETRY_COUNT = 10 RETRY_COUNT = 10
def _decode_retry(module, response, info, retry_count): def _decode_retry(
module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
) -> bool:
if info["status"] not in RETRY_STATUS_CODES: if info["status"] not in RETRY_STATUS_CODES:
return False return False
@@ -61,7 +73,8 @@ def _decode_retry(module, response, info, retry_count):
# 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After) # 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
try: try:
retry_after = min(max(1, int(info.get("retry-after"))), 60) # TODO: use utils.parse_retry_after()
retry_after = min(max(1, int(info.get("retry-after", "10"))), 60)
except (TypeError, ValueError): except (TypeError, ValueError):
retry_after = 10 retry_after = 10
module.log( module.log(
@@ -73,13 +86,13 @@ def _decode_retry(module, response, info, retry_count):
def _assert_fetch_url_success( def _assert_fetch_url_success(
module, module: AnsibleModule,
response, response: t.Any,
info, info: dict[str, t.Any],
allow_redirect=False, allow_redirect: bool = False,
allow_client_error=True, allow_client_error: bool = True,
allow_server_error=True, allow_server_error: bool = True,
): ) -> None:
if info["status"] < 0: if info["status"] < 0:
raise NetworkException(msg=f"Failure downloading {info['url']}, {info['msg']}") raise NetworkException(msg=f"Failure downloading {info['url']}, {info['msg']}")
@@ -91,7 +104,9 @@ def _assert_fetch_url_success(
raise ACMEProtocolException(module, info=info, response=response) raise ACMEProtocolException(module, info=info, response=response)
def _is_failed(info, expected_status_codes=None): def _is_failed(
info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
) -> bool:
if info["status"] < 200 or info["status"] >= 400: if info["status"] < 200 or info["status"] >= 400:
return True return True
if ( if (
@@ -111,12 +126,12 @@ class ACMEDirectory:
https://tools.ietf.org/html/rfc8555#section-7.1.1 https://tools.ietf.org/html/rfc8555#section-7.1.1
""" """
def __init__(self, module, account): def __init__(self, module: AnsibleModule, client: ACMEClient) -> None:
self.module = module self.module = module
self.directory_root = module.params["acme_directory"] self.directory_root = module.params["acme_directory"]
self.version = module.params["acme_version"] self.version = module.params["acme_version"]
self.directory, dummy = account.get_request(self.directory_root, get_only=True) self.directory, dummy = client.get_request(self.directory_root, get_only=True)
self.request_timeout = module.params["request_timeout"] self.request_timeout = module.params["request_timeout"]
@@ -131,16 +146,16 @@ class ACMEDirectory:
if "meta" not in self.directory: if "meta" not in self.directory:
self.directory["meta"] = {} self.directory["meta"] = {}
def __getitem__(self, key): def __getitem__(self, key: str) -> t.Any:
return self.directory[key] return self.directory[key]
def __contains__(self, key): def __contains__(self, key: str) -> bool:
return key in self.directory return key in self.directory
def get(self, key, default_value=None): def get(self, key: str, default_value: t.Any = None) -> t.Any:
return self.directory.get(key, default_value) return self.directory.get(key, default_value)
def get_nonce(self, resource=None): def get_nonce(self, resource: str | None = None) -> str:
url = self.directory["newNonce"] url = self.directory["newNonce"]
if resource is not None: if resource is not None:
url = resource url = resource
@@ -170,7 +185,7 @@ class ACMEDirectory:
) )
retry_count += 1 retry_count += 1
def has_renewal_info_endpoint(self): def has_renewal_info_endpoint(self) -> bool:
return "renewalInfo" in self.directory return "renewalInfo" in self.directory
@@ -180,7 +195,7 @@ class ACMEClient:
ACME server. ACME server.
""" """
def __init__(self, module, backend): def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None:
# Set to true to enable logging of all signed requests # Set to true to enable logging of all signed requests
self._debug = False self._debug = False
@@ -221,16 +236,22 @@ class ACMEClient:
self.directory = ACMEDirectory(module, self) self.directory = ACMEDirectory(module, self)
def set_account_uri(self, uri): def set_account_uri(self, uri: str) -> None:
""" """
Set account URI. For ACME v2, it needs to be used to sending signed Set account URI. For ACME v2, it needs to be used to sending signed
requests. requests.
""" """
self.account_uri = uri self.account_uri = uri
self.account_jws_header.pop("jwk") if self.account_jws_header:
self.account_jws_header["kid"] = self.account_uri self.account_jws_header.pop("jwk", None)
self.account_jws_header["kid"] = self.account_uri
def parse_key(self, key_file=None, key_content=None, passphrase=None): def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
""" """
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
In case of an error, raises KeyParsingError. In case of an error, raises KeyParsingError.
@@ -239,7 +260,13 @@ class ACMEClient:
raise AssertionError("One of key_file and key_content must be specified!") raise AssertionError("One of key_file and key_content must be specified!")
return self.backend.parse_key(key_file, key_content, passphrase=passphrase) return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
def sign_request(self, protected, payload, key_data, encode_payload=True): def sign_request(
self,
protected: dict[str, t.Any],
payload: str | dict[str, t.Any] | None,
key_data: dict[str, t.Any],
encode_payload: bool = True,
) -> dict[str, t.Any]:
""" """
Signs an ACME request. Signs an ACME request.
""" """
@@ -260,7 +287,7 @@ class ACMEClient:
return self.backend.sign(payload64, protected64, key_data) return self.backend.sign(payload64, protected64, key_data)
def _log(self, msg, data=None): def _log(self, msg: str, data: t.Any = None) -> None:
""" """
Write arguments to acme.log when logging is enabled. Write arguments to acme.log when logging is enabled.
""" """
@@ -275,18 +302,49 @@ class ACMEClient:
) )
) )
@t.overload
def send_signed_request( def send_signed_request(
self, self,
url, url: str,
payload, payload: str | dict[str, t.Any] | None,
key_data=None, *,
jws_header=None, key_data: dict[str, t.Any] | None = None,
parse_json_result=True, jws_header: dict[str, t.Any] | None = None,
encode_payload=True, parse_json_result: t.Literal[True] = True,
fail_on_error=True, encode_payload: bool = True,
error_msg=None, fail_on_error: bool = True,
expected_status_codes=None, error_msg: str | None = None,
): expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
@t.overload
def send_signed_request(
self,
url: str,
payload: str | dict[str, t.Any] | None,
*,
key_data: dict[str, t.Any] | None = None,
jws_header: dict[str, t.Any] | None = None,
parse_json_result: t.Literal[False],
encode_payload: bool = True,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[bytes, dict[str, t.Any]]: ...
def send_signed_request(
self,
url: str,
payload: str | dict[str, t.Any] | None,
*,
key_data: dict[str, t.Any] | None = None,
jws_header: dict[str, t.Any] | None = None,
parse_json_result: bool = True,
encode_payload: bool = True,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
""" """
Sends a JWS signed HTTP POST request to the ACME server and returns Sends a JWS signed HTTP POST request to the ACME server and returns
the response as dictionary (if parse_json_result is True) or in raw form the response as dictionary (if parse_json_result is True) or in raw form
@@ -297,7 +355,11 @@ class ACMEClient:
(https://tools.ietf.org/html/rfc8555#section-6.3) (https://tools.ietf.org/html/rfc8555#section-6.3)
""" """
key_data = key_data or self.account_key_data key_data = key_data or self.account_key_data
if key_data is None:
raise ModuleFailException("Missing key data")
jws_header = jws_header or self.account_jws_header jws_header = jws_header or self.account_jws_header
if jws_header is None:
raise ModuleFailException("Missing JWS header")
failed_tries = 0 failed_tries = 0
while True: while True:
protected = copy.deepcopy(jws_header) protected = copy.deepcopy(jws_header)
@@ -382,16 +444,43 @@ class ACMEClient:
) )
return result, info return result, info
@t.overload
def get_request( def get_request(
self, self,
uri, uri: str,
parse_json_result=True, *,
headers=None, parse_json_result: t.Literal[True] = True,
get_only=False, headers: dict[str, str] | None = None,
fail_on_error=True, get_only: bool = False,
error_msg=None, fail_on_error: bool = True,
expected_status_codes=None, error_msg: str | None = None,
): expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
@t.overload
def get_request(
self,
uri: str,
*,
parse_json_result: t.Literal[False],
headers: dict[str, str] | None = None,
get_only: bool = False,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[bytes, dict[str, t.Any]]: ...
def get_request(
self,
uri: str,
*,
parse_json_result: bool = True,
headers: dict[str, str] | None = None,
get_only: bool = False,
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
""" """
Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback
to GET if server replies with a status code of 405. to GET if server replies with a status code of 405.
@@ -436,6 +525,7 @@ class ACMEClient:
# Process result # Process result
parsed_json_result = False parsed_json_result = False
result: dict[str, t.Any] | bytes
if parse_json_result: if parse_json_result:
result = {} result = {}
if content: if content:
@@ -445,7 +535,7 @@ class ACMEClient:
parsed_json_result = True parsed_json_result = True
except ValueError: except ValueError:
raise NetworkException( raise NetworkException(
f"Failed to parse the ACME response: {uri} {content}" f"Failed to parse the ACME response: {uri} {content!r}"
) )
else: else:
result = content result = content
@@ -460,19 +550,21 @@ class ACMEClient:
msg=error_msg, msg=error_msg,
info=info, info=info,
content=content, content=content,
content_json=result if parsed_json_result else None, content_json=(
t.cast(dict[str, t.Any], result) if parsed_json_result else None
),
) )
return result, info return result, info
def get_renewal_info( def get_renewal_info(
self, self,
cert_id=None, cert_id: str | None = None,
cert_info=None, cert_info: CertificateInformation | None = None,
cert_filename=None, cert_filename: str | os.PathLike | None = None,
cert_content=None, cert_content: str | bytes | None = None,
include_retry_after=False, include_retry_after: bool = False,
retry_after_relative_with_timezone=True, retry_after_relative_with_timezone: bool = True,
): ) -> dict[str, t.Any]:
if not self.directory.has_renewal_info_endpoint(): if not self.directory.has_renewal_info_endpoint():
raise ModuleFailException( raise ModuleFailException(
"The ACME endpoint does not support ACME Renewal Information retrieval" "The ACME endpoint does not support ACME Renewal Information retrieval"
@@ -504,10 +596,10 @@ class ACMEClient:
def create_default_argspec( def create_default_argspec(
with_account=True, with_account: bool = True,
require_account_key=True, require_account_key: bool = True,
with_certificate=False, with_certificate: bool = False,
): ) -> ArgumentSpec:
""" """
Provides default argument spec for the options documented in the acme doc fragment. Provides default argument spec for the options documented in the acme doc fragment.
""" """
@@ -544,7 +636,7 @@ def create_default_argspec(
return result return result
def create_backend(module, needs_acme_v2=True): def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend:
backend = module.params["select_crypto_backend"] backend = module.params["select_crypto_backend"]
# Backend autodetect # Backend autodetect
@@ -552,6 +644,7 @@ def create_backend(module, needs_acme_v2=True):
backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl" backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl"
# Create backend object # Create backend object
module_backend: CryptoBackend
if backend == "cryptography": if backend == "cryptography":
if CRYPTOGRAPHY_ERROR is not None: if CRYPTOGRAPHY_ERROR is not None:
# Either we could not import cryptography at all, or there was an unexpected error # Either we could not import cryptography at all, or there was an unexpected error

View File

@@ -9,6 +9,7 @@ import base64
import binascii import binascii
import os import os
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
@@ -75,10 +76,19 @@ else:
CRYPTOGRAPHY_MINIMAL_VERSION CRYPTOGRAPHY_MINIMAL_VERSION
) )
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from .certificates import CertificateChain, Criterium
class CryptographyChainMatcher(ChainMatcher): class CryptographyChainMatcher(ChainMatcher):
@staticmethod @staticmethod
def _parse_key_identifier(key_identifier, name, criterium_idx, module): def _parse_key_identifier(
key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule
) -> bytes | None:
if key_identifier: if key_identifier:
try: try:
return binascii.unhexlify(key_identifier.replace(":", "")) return binascii.unhexlify(key_identifier.replace(":", ""))
@@ -94,11 +104,11 @@ class CryptographyChainMatcher(ChainMatcher):
) )
return None return None
def __init__(self, criterium, module): def __init__(self, criterium: Criterium, module: AnsibleModule) -> None:
self.criterium = criterium self.criterium = criterium
self.test_certificates = criterium.test_certificates self.test_certificates = criterium.test_certificates
self.subject = [] self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
self.issuer = [] self.issuer: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
if criterium.subject: if criterium.subject:
self.subject = [ self.subject = [
(cryptography_name_to_oid(k), to_native(v)) (cryptography_name_to_oid(k), to_native(v))
@@ -121,8 +131,13 @@ class CryptographyChainMatcher(ChainMatcher):
criterium.index, criterium.index,
module, module,
) )
self.module = module
def _match_subject(self, x509_subject, match_subject): def _match_subject(
self,
x509_subject: cryptography.x509.Name,
match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]],
) -> bool:
for oid, value in match_subject: for oid, value in match_subject:
found = False found = False
for attribute in x509_subject: for attribute in x509_subject:
@@ -133,7 +148,7 @@ class CryptographyChainMatcher(ChainMatcher):
return False return False
return True return True
def match(self, certificate): def match(self, certificate: CertificateChain) -> bool:
""" """
Check whether an alternate chain matches the specified criterium. Check whether an alternate chain matches the specified criterium.
""" """
@@ -152,19 +167,22 @@ class CryptographyChainMatcher(ChainMatcher):
matches = False matches = False
if self.subject_key_identifier: if self.subject_key_identifier:
try: try:
ext = x509.extensions.get_extension_for_class( ext_ski = x509.extensions.get_extension_for_class(
cryptography.x509.SubjectKeyIdentifier cryptography.x509.SubjectKeyIdentifier
) )
if self.subject_key_identifier != ext.value.digest: if self.subject_key_identifier != ext_ski.value.digest:
matches = False matches = False
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
matches = False matches = False
if self.authority_key_identifier: if self.authority_key_identifier:
try: try:
ext = x509.extensions.get_extension_for_class( ext_aki = x509.extensions.get_extension_for_class(
cryptography.x509.AuthorityKeyIdentifier cryptography.x509.AuthorityKeyIdentifier
) )
if self.authority_key_identifier != ext.value.key_identifier: if (
self.authority_key_identifier
!= ext_aki.value.key_identifier
):
matches = False matches = False
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
matches = False matches = False
@@ -176,59 +194,68 @@ class CryptographyChainMatcher(ChainMatcher):
class CryptographyBackend(CryptoBackend): class CryptographyBackend(CryptoBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(CryptographyBackend, self).__init__( super(CryptographyBackend, self).__init__(
module, with_timezone=CRYPTOGRAPHY_TIMEZONE module, with_timezone=CRYPTOGRAPHY_TIMEZONE
) )
def parse_key(self, key_file=None, key_content=None, passphrase=None): def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
""" """
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors. Raises KeyParsingError in case of errors.
""" """
# If key_content is not given, read key_file # If key_content is not given, read key_file
if key_content is None: if key_content is None:
key_content = read_file(key_file) if key_file is None:
raise KeyParsingError(
"one of key_file and key_content must be specified"
)
b_key_content = read_file(key_file)
else: else:
key_content = to_bytes(key_content) b_key_content = to_bytes(key_content)
# Parse key # Parse key
try: try:
key = cryptography.hazmat.primitives.serialization.load_pem_private_key( key = cryptography.hazmat.primitives.serialization.load_pem_private_key(
key_content, b_key_content,
password=to_bytes(passphrase) if passphrase is not None else None, password=to_bytes(passphrase) if passphrase is not None else None,
) )
except Exception as e: except Exception as e:
raise KeyParsingError(f"error while loading key: {e}") raise KeyParsingError(f"error while loading key: {e}")
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey): if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
pk = key.public_key().public_numbers() rsa_pk = key.public_key().public_numbers()
return { return {
"key_obj": key, "key_obj": key,
"type": "rsa", "type": "rsa",
"alg": "RS256", "alg": "RS256",
"jwk": { "jwk": {
"kty": "RSA", "kty": "RSA",
"e": nopad_b64(convert_int_to_bytes(pk.e)), "e": nopad_b64(convert_int_to_bytes(rsa_pk.e)),
"n": nopad_b64(convert_int_to_bytes(pk.n)), "n": nopad_b64(convert_int_to_bytes(rsa_pk.n)),
}, },
"hash": "sha256", "hash": "sha256",
} }
elif isinstance( elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
): ):
pk = key.public_key().public_numbers() ec_pk = key.public_key().public_numbers()
if pk.curve.name == "secp256r1": if ec_pk.curve.name == "secp256r1":
bits = 256 bits = 256
alg = "ES256" alg = "ES256"
hashalg = "sha256" hashalg = "sha256"
point_size = 32 point_size = 32
curve = "P-256" curve = "P-256"
elif pk.curve.name == "secp384r1": elif ec_pk.curve.name == "secp384r1":
bits = 384 bits = 384
alg = "ES384" alg = "ES384"
hashalg = "sha384" hashalg = "sha384"
point_size = 48 point_size = 48
curve = "P-384" curve = "P-384"
elif pk.curve.name == "secp521r1": elif ec_pk.curve.name == "secp521r1":
# Not yet supported on Let's Encrypt side, see # Not yet supported on Let's Encrypt side, see
# https://github.com/letsencrypt/boulder/issues/2217 # https://github.com/letsencrypt/boulder/issues/2217
bits = 521 bits = 521
@@ -237,7 +264,7 @@ class CryptographyBackend(CryptoBackend):
point_size = 66 point_size = 66
curve = "P-521" curve = "P-521"
else: else:
raise KeyParsingError(f"unknown elliptic curve: {pk.curve.name}") raise KeyParsingError(f"unknown elliptic curve: {ec_pk.curve.name}")
num_bytes = (bits + 7) // 8 num_bytes = (bits + 7) // 8
return { return {
"key_obj": key, "key_obj": key,
@@ -246,8 +273,8 @@ class CryptographyBackend(CryptoBackend):
"jwk": { "jwk": {
"kty": "EC", "kty": "EC",
"crv": curve, "crv": curve,
"x": nopad_b64(convert_int_to_bytes(pk.x, count=num_bytes)), "x": nopad_b64(convert_int_to_bytes(ec_pk.x, count=num_bytes)),
"y": nopad_b64(convert_int_to_bytes(pk.y, count=num_bytes)), "y": nopad_b64(convert_int_to_bytes(ec_pk.y, count=num_bytes)),
}, },
"hash": hashalg, "hash": hashalg,
"point_size": point_size, "point_size": point_size,
@@ -255,8 +282,11 @@ class CryptographyBackend(CryptoBackend):
else: else:
raise KeyParsingError(f'unknown key type "{type(key)}"') raise KeyParsingError(f'unknown key type "{type(key)}"')
def sign(self, payload64, protected64, key_data): def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
sign_payload = f"{protected64}.{payload64}".encode("utf8") sign_payload = f"{protected64}.{payload64}".encode("utf8")
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
if "mac_obj" in key_data: if "mac_obj" in key_data:
mac = key_data["mac_obj"]() mac = key_data["mac_obj"]()
mac.update(sign_payload) mac.update(sign_payload)
@@ -292,8 +322,9 @@ class CryptographyBackend(CryptoBackend):
"signature": nopad_b64(signature), "signature": nopad_b64(signature),
} }
def create_mac_key(self, alg, key): def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key.""" """Create a MAC key."""
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
if alg == "HS256": if alg == "HS256":
hashalg = cryptography.hazmat.primitives.hashes.SHA256 hashalg = cryptography.hazmat.primitives.hashes.SHA256
hashbytes = 32 hashbytes = 32
@@ -324,7 +355,11 @@ class CryptographyBackend(CryptoBackend):
}, },
} }
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
""" """
Return a list of requested identifiers (CN and SANs) for the CSR. Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either Each identifier is a pair (type, identifier), where type is either
@@ -334,15 +369,19 @@ class CryptographyBackend(CryptoBackend):
as the first element in the result. as the first element in the result.
""" """
if csr_content is None: if csr_content is None:
csr_content = read_file(csr_filename) if csr_filename is None:
raise BackendException(
"One of csr_content and csr_filename has to be provided"
)
b_csr_content = read_file(csr_filename)
else: else:
csr_content = to_bytes(csr_content) b_csr_content = to_bytes(csr_content)
csr = cryptography.x509.load_pem_x509_csr(csr_content) csr = cryptography.x509.load_pem_x509_csr(b_csr_content)
identifiers = set() identifiers = set()
result = [] result = []
def add_identifier(identifier): def add_identifier(identifier: tuple[str, str]) -> None:
if identifier in identifiers: if identifier in identifiers:
return return
identifiers.add(identifier) identifiers.add(identifier)
@@ -350,7 +389,7 @@ class CryptographyBackend(CryptoBackend):
for sub in csr.subject: for sub in csr.subject:
if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME: if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME:
add_identifier(("dns", sub.value)) add_identifier(("dns", t.cast(str, sub.value)))
for extension in csr.extensions: for extension in csr.extensions:
if ( if (
extension.oid extension.oid
@@ -367,7 +406,11 @@ class CryptographyBackend(CryptoBackend):
) )
return result return result
def get_csr_identifiers(self, csr_filename=None, csr_content=None): def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | bytes | None = None,
) -> set[tuple[str, str]]:
""" """
Return a set of requested identifiers (CN and SANs) for the CSR. Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either Each identifier is a pair (type, identifier), where type is either
@@ -379,7 +422,12 @@ class CryptographyBackend(CryptoBackend):
) )
) )
def get_cert_days(self, cert_filename=None, cert_content=None, now=None): def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> int:
""" """
Return the days the certificate in cert_filename remains valid and -1 Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one if the file was not found. If cert_filename contains more than one
@@ -398,10 +446,10 @@ class CryptographyBackend(CryptoBackend):
return -1 return -1
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf. # Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "") b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
try: try:
cert = cryptography.x509.load_pem_x509_certificate(cert_content) cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
except Exception as e: except Exception as e:
if cert_filename is None: if cert_filename is None:
raise BackendException(f"Cannot parse certificate: {e}") raise BackendException(f"Cannot parse certificate: {e}")
@@ -413,13 +461,17 @@ class CryptographyBackend(CryptoBackend):
now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE) now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE)
return (get_not_valid_after(cert) - now).days return (get_not_valid_after(cert) - now).days
def create_chain_matcher(self, criterium): def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
""" """
Given a Criterium object, creates a ChainMatcher object. Given a Criterium object, creates a ChainMatcher object.
""" """
return CryptographyChainMatcher(criterium, self.module) return CryptographyChainMatcher(criterium, self.module)
def get_cert_information(self, cert_filename=None, cert_content=None): def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
""" """
Return some information on a X.509 certificate as a CertificateInformation object. Return some information on a X.509 certificate as a CertificateInformation object.
""" """
@@ -429,10 +481,10 @@ class CryptographyBackend(CryptoBackend):
cert_content = to_bytes(cert_content) cert_content = to_bytes(cert_content)
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf. # Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "") b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
try: try:
cert = cryptography.x509.load_pem_x509_certificate(cert_content) cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
except Exception as e: except Exception as e:
if cert_filename is None: if cert_filename is None:
raise BackendException(f"Cannot parse certificate: {e}") raise BackendException(f"Cannot parse certificate: {e}")
@@ -440,19 +492,19 @@ class CryptographyBackend(CryptoBackend):
ski = None ski = None
try: try:
ext = cert.extensions.get_extension_for_class( ext_ski = cert.extensions.get_extension_for_class(
cryptography.x509.SubjectKeyIdentifier cryptography.x509.SubjectKeyIdentifier
) )
ski = ext.value.digest ski = ext_ski.value.digest
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
pass pass
aki = None aki = None
try: try:
ext = cert.extensions.get_extension_for_class( ext_aki = cert.extensions.get_extension_for_class(
cryptography.x509.AuthorityKeyIdentifier cryptography.x509.AuthorityKeyIdentifier
) )
aki = ext.value.key_identifier aki = ext_aki.value.key_identifier
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
pass pass

View File

@@ -13,6 +13,7 @@ import os
import re import re
import tempfile import tempfile
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import ( from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
@@ -34,12 +35,23 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from .certificates import Criterium
_OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C") _OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C")
def _extract_date(out_text, name, cert_filename_suffix=""): def _extract_date(
out_text: str, name: str, cert_filename_suffix: str = ""
) -> datetime.datetime:
matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text)
if matcher is None:
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
date_str = matcher.group(1)
try: try:
date_str = re.search(rf"\s+{name}\s*:\s+(.*)", out_text).group(1)
# For some reason Python's strptime() does not return any timezone information, # For some reason Python's strptime() does not return any timezone information,
# even though the information is there and a supported timezone for all supported # even though the information is there and a supported timezone for all supported
# Python implementations (GMT). So we have to modify the datetime object by # Python implementations (GMT). So we have to modify the datetime object by
@@ -47,19 +59,40 @@ def _extract_date(out_text, name, cert_filename_suffix=""):
return ensure_utc_timezone( return ensure_utc_timezone(
datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z") datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
) )
except AttributeError:
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
except ValueError as exc: except ValueError as exc:
raise BackendException( raise BackendException(
f"Failed to parse '{name}' date{cert_filename_suffix}: {exc}" f"Failed to parse '{name}' date{cert_filename_suffix}: {exc}"
) )
def _decode_octets(octets_text): def _decode_octets(octets_text: str) -> bytes:
return binascii.unhexlify(re.sub(r"(\s|:)", "", octets_text).encode("utf-8")) return binascii.unhexlify(re.sub(r"(\s|:)", "", octets_text).encode("utf-8"))
def _extract_octets(out_text, name, required=True, potential_prefixes=None): @t.overload
def _extract_octets(
out_text: str,
name: str,
required: t.Literal[False],
potential_prefixes: t.Iterable[str] | None = None,
) -> bytes | None: ...
@t.overload
def _extract_octets(
out_text: str,
name: str,
required: t.Literal[True],
potential_prefixes: t.Iterable[str] | None = None,
) -> bytes: ...
def _extract_octets(
out_text: str,
name: str,
required: bool = True,
potential_prefixes: t.Iterable[str] | None = None,
) -> bytes | None:
part = ( part = (
f"(?:{'|'.join(re.escape(pp) for pp in potential_prefixes)})" f"(?:{'|'.join(re.escape(pp) for pp in potential_prefixes)})"
if potential_prefixes if potential_prefixes
@@ -75,13 +108,20 @@ def _extract_octets(out_text, name, required=True, potential_prefixes=None):
class OpenSSLCLIBackend(CryptoBackend): class OpenSSLCLIBackend(CryptoBackend):
def __init__(self, module, openssl_binary=None): def __init__(
self, module: AnsibleModule, openssl_binary: str | None = None
) -> None:
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True) super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
if openssl_binary is None: if openssl_binary is None:
openssl_binary = module.get_bin_path("openssl", True) openssl_binary = module.get_bin_path("openssl", True)
self.openssl_binary = openssl_binary self.openssl_binary = openssl_binary
def parse_key(self, key_file=None, key_content=None, passphrase=None): def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
""" """
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors. Raises KeyParsingError in case of errors.
@@ -90,6 +130,10 @@ class OpenSSLCLIBackend(CryptoBackend):
raise KeyParsingError("openssl backend does not support key passphrases") raise KeyParsingError("openssl backend does not support key passphrases")
# If key_file is not given, but key_content, write that to a temporary file # If key_file is not given, but key_content, write that to a temporary file
if key_file is None: if key_file is None:
if key_content is None:
raise KeyParsingError(
"one of key_file and key_content must be specified"
)
fd, tmpsrc = tempfile.mkstemp() fd, tmpsrc = tempfile.mkstemp()
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
f = os.fdopen(fd, "wb") f = os.fdopen(fd, "wb")
@@ -108,8 +152,8 @@ class OpenSSLCLIBackend(CryptoBackend):
f.close() f.close()
# Parse key # Parse key
account_key_type = None account_key_type = None
with open(key_file, "rt") as f: with open(key_file, "rt") as fi:
for line in f: for line in fi:
m = re.match( m = re.match(
r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line
) )
@@ -129,38 +173,44 @@ class OpenSSLCLIBackend(CryptoBackend):
self.openssl_binary, self.openssl_binary,
account_key_type, account_key_type,
"-in", "-in",
key_file, str(key_file),
"-noout", "-noout",
"-text", "-text",
] ]
rc, out, err = self.module.run_command( rc, out, stderr = self.module.run_command(
openssl_keydump_cmd, openssl_keydump_cmd,
check_rc=False, check_rc=False,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE, environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
) )
if rc != 0: if rc != 0:
raise BackendException( raise BackendException(
f"Error while running {' '.join(openssl_keydump_cmd)}: {err}" f"Error while running {' '.join(openssl_keydump_cmd)}: {stderr}"
) )
out_text = to_text(out, errors="surrogate_or_strict") out_text = to_text(out, errors="surrogate_or_strict")
if account_key_type == "rsa": if account_key_type == "rsa":
pub_hex = re.search( matcher = re.search(
r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent", r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent",
out_text, out_text,
re.MULTILINE | re.DOTALL, re.MULTILINE | re.DOTALL,
).group(1) )
if matcher is None:
raise KeyParsingError("cannot parse RSA key: modulus not found")
pub_hex = matcher.group(1)
pub_exp = re.search( matcher = re.search(
r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL
).group(1) )
if matcher is None:
raise KeyParsingError("cannot parse RSA key: public exponent not found")
pub_exp = matcher.group(1)
pub_exp = f"{int(pub_exp):x}" pub_exp = f"{int(pub_exp):x}"
if len(pub_exp) % 2: if len(pub_exp) % 2:
pub_exp = f"0{pub_exp}" pub_exp = f"0{pub_exp}"
return { return {
"key_file": key_file, "key_file": str(key_file),
"type": "rsa", "type": "rsa",
"alg": "RS256", "alg": "RS256",
"jwk": { "jwk": {
@@ -223,8 +273,13 @@ class OpenSSLCLIBackend(CryptoBackend):
"hash": hashalg, "hash": hashalg,
"point_size": point_size, "point_size": point_size,
} }
raise KeyParsingError(
f"Internal error: unexpected account_key_type = {account_key_type!r}"
)
def sign(self, payload64, protected64, key_data): def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
sign_payload = f"{protected64}.{payload64}".encode("utf8") sign_payload = f"{protected64}.{payload64}".encode("utf8")
if key_data["type"] == "hmac": if key_data["type"] == "hmac":
hex_key = ( hex_key = (
@@ -284,7 +339,7 @@ class OpenSSLCLIBackend(CryptoBackend):
"signature": nopad_b64(to_bytes(out)), "signature": nopad_b64(to_bytes(out)),
} }
def create_mac_key(self, alg, key): def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key.""" """Create a MAC key."""
if alg == "HS256": if alg == "HS256":
hashalg = "sha256" hashalg = "sha256"
@@ -315,14 +370,18 @@ class OpenSSLCLIBackend(CryptoBackend):
} }
@staticmethod @staticmethod
def _normalize_ip(ip): def _normalize_ip(ip: str) -> str:
try: try:
return ipaddress.ip_address(to_text(ip)).compressed return ipaddress.ip_address(ip).compressed
except ValueError: except ValueError:
# We do not want to error out on something IPAddress() cannot parse # We do not want to error out on something IPAddress() cannot parse
return ip return ip
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
""" """
Return a list of requested identifiers (CN and SANs) for the CSR. Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either Each identifier is a pair (type, identifier), where type is either
@@ -335,13 +394,13 @@ class OpenSSLCLIBackend(CryptoBackend):
data = None data = None
if csr_content is not None: if csr_content is not None:
filename = "/dev/stdin" filename = "/dev/stdin"
data = csr_content.encode("utf-8") data = to_bytes(csr_content)
openssl_csr_cmd = [ openssl_csr_cmd = [
self.openssl_binary, self.openssl_binary,
"req", "req",
"-in", "-in",
filename, str(filename),
"-noout", "-noout",
"-text", "-text",
] ]
@@ -360,7 +419,7 @@ class OpenSSLCLIBackend(CryptoBackend):
identifiers = set() identifiers = set()
result = [] result = []
def add_identifier(identifier): def add_identifier(identifier: tuple[str, str]) -> None:
if identifier in identifiers: if identifier in identifiers:
return return
identifiers.add(identifier) identifiers.add(identifier)
@@ -389,7 +448,11 @@ class OpenSSLCLIBackend(CryptoBackend):
raise BackendException(f'Found unsupported SAN identifier "{san}"') raise BackendException(f'Found unsupported SAN identifier "{san}"')
return result return result
def get_csr_identifiers(self, csr_filename=None, csr_content=None): def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> set[tuple[str, str]]:
""" """
Return a set of requested identifiers (CN and SANs) for the CSR. Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either Each identifier is a pair (type, identifier), where type is either
@@ -401,7 +464,12 @@ class OpenSSLCLIBackend(CryptoBackend):
) )
) )
def get_cert_days(self, cert_filename=None, cert_content=None, now=None): def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> int:
""" """
Return the days the certificate in cert_filename remains valid and -1 Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one if the file was not found. If cert_filename contains more than one
@@ -413,7 +481,7 @@ class OpenSSLCLIBackend(CryptoBackend):
data = None data = None
if cert_content is not None: if cert_content is not None:
filename = "/dev/stdin" filename = "/dev/stdin"
data = cert_content.encode("utf-8") data = to_bytes(cert_content)
cert_filename_suffix = "" cert_filename_suffix = ""
elif cert_filename is not None: elif cert_filename is not None:
if not os.path.exists(cert_filename): if not os.path.exists(cert_filename):
@@ -426,7 +494,7 @@ class OpenSSLCLIBackend(CryptoBackend):
self.openssl_binary, self.openssl_binary,
"x509", "x509",
"-in", "-in",
filename, str(filename),
"-noout", "-noout",
"-text", "-text",
] ]
@@ -452,7 +520,7 @@ class OpenSSLCLIBackend(CryptoBackend):
now = ensure_utc_timezone(now) now = ensure_utc_timezone(now)
return (not_after - now).days return (not_after - now).days
def create_chain_matcher(self, criterium): def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
""" """
Given a Criterium object, creates a ChainMatcher object. Given a Criterium object, creates a ChainMatcher object.
""" """
@@ -460,7 +528,11 @@ class OpenSSLCLIBackend(CryptoBackend):
'Alternate chain matching can only be used with the "cryptography" backend.' 'Alternate chain matching can only be used with the "cryptography" backend.'
) )
def get_cert_information(self, cert_filename=None, cert_content=None): def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
""" """
Return some information on a X.509 certificate as a CertificateInformation object. Return some information on a X.509 certificate as a CertificateInformation object.
""" """
@@ -477,7 +549,7 @@ class OpenSSLCLIBackend(CryptoBackend):
self.openssl_binary, self.openssl_binary,
"x509", "x509",
"-in", "-in",
filename, str(filename),
"-noout", "-noout",
"-text", "-text",
] ]

View File

@@ -8,7 +8,7 @@ from __future__ import annotations
import abc import abc
import datetime import datetime
import re import re
from collections import namedtuple import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
BackendException, BackendException,
@@ -27,16 +27,20 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
CertificateInformation = namedtuple( if t.TYPE_CHECKING:
"CertificateInformation", import os
(
"not_valid_after", from ansible.module_utils.basic import AnsibleModule
"not_valid_before",
"serial_number", from .certificates import ChainMatcher, Criterium
"subject_key_identifier",
"authority_key_identifier",
), class CertificateInformation(t.NamedTuple):
) not_valid_after: datetime.datetime
not_valid_before: datetime.datetime
serial_number: int
subject_key_identifier: bytes | None
authority_key_identifier: bytes | None
_FRACTIONAL_MATCHER = re.compile( _FRACTIONAL_MATCHER = re.compile(
@@ -44,7 +48,7 @@ _FRACTIONAL_MATCHER = re.compile(
) )
def _reduce_fractional_digits(timestamp_str): def _reduce_fractional_digits(timestamp_str: str) -> str:
""" """
Given a RFC 3339 timestamp that includes too many digits for the fractional seconds part, reduces these to at most 6. Given a RFC 3339 timestamp that includes too many digits for the fractional seconds part, reduces these to at most 6.
""" """
@@ -60,7 +64,7 @@ def _reduce_fractional_digits(timestamp_str):
return f"{timestamp}{fractional}{timezone}" return f"{timestamp}{fractional}{timezone}"
def _parse_acme_timestamp(timestamp_str, with_timezone): def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime:
""" """
Parses a RFC 3339 timestamp. Parses a RFC 3339 timestamp.
""" """
@@ -86,34 +90,42 @@ def _parse_acme_timestamp(timestamp_str, with_timezone):
class CryptoBackend(metaclass=abc.ABCMeta): class CryptoBackend(metaclass=abc.ABCMeta):
def __init__(self, module, with_timezone=False): def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None:
self.module = module self.module = module
self._with_timezone = with_timezone self._with_timezone = with_timezone
def get_now(self): def get_now(self) -> datetime.datetime:
return get_now_datetime(with_timezone=self._with_timezone) return get_now_datetime(with_timezone=self._with_timezone)
def parse_acme_timestamp(self, timestamp_str): def parse_acme_timestamp(self, timestamp_str: str) -> datetime.datetime:
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339) # RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone) return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone)
def parse_module_parameter(self, value, name): def parse_module_parameter(self, value: str, name: str) -> datetime.datetime:
try: try:
return get_relative_time_option( result = get_relative_time_option(
value, name, with_timezone=self._with_timezone value, name, with_timezone=self._with_timezone
) )
if result is None:
raise BackendException(f"Invalid value for {name}: {value!r}")
return result
except OpenSSLObjectError as exc: except OpenSSLObjectError as exc:
raise BackendException(str(exc)) raise BackendException(str(exc))
def interpolate_timestamp(self, timestamp_start, timestamp_end, percentage): def interpolate_timestamp(
self,
timestamp_start: datetime.datetime,
timestamp_end: datetime.datetime,
percentage: float,
) -> datetime.datetime:
start = get_epoch_seconds(timestamp_start) start = get_epoch_seconds(timestamp_start)
end = get_epoch_seconds(timestamp_end) end = get_epoch_seconds(timestamp_end)
return from_epoch_seconds( return from_epoch_seconds(
start + percentage * (end - start), with_timezone=self._with_timezone start + percentage * (end - start), with_timezone=self._with_timezone
) )
def get_utc_datetime(self, *args, **kwargs): def get_utc_datetime(self, *args, **kwargs) -> datetime.datetime:
kwargs_ext = dict(kwargs) kwargs_ext: dict[str, t.Any] = dict(kwargs)
if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8): if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8):
kwargs_ext["tzinfo"] = UTC kwargs_ext["tzinfo"] = UTC
result = datetime.datetime(*args, **kwargs_ext) result = datetime.datetime(*args, **kwargs_ext)
@@ -122,22 +134,33 @@ class CryptoBackend(metaclass=abc.ABCMeta):
return result return result
@abc.abstractmethod @abc.abstractmethod
def parse_key(self, key_file=None, key_content=None, passphrase=None): def parse_key(
self,
key_file: str | os.PathLike | None = None,
key_content: str | None = None,
passphrase: str | None = None,
) -> dict[str, t.Any]:
""" """
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data. Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors. Raises KeyParsingError in case of errors.
""" """
@abc.abstractmethod @abc.abstractmethod
def sign(self, payload64, protected64, key_data): def sign(
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
) -> dict[str, t.Any]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def create_mac_key(self, alg, key): def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
"""Create a MAC key.""" """Create a MAC key."""
@abc.abstractmethod @abc.abstractmethod
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None): def get_ordered_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> list[tuple[str, str]]:
""" """
Return a list of requested identifiers (CN and SANs) for the CSR. Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either Each identifier is a pair (type, identifier), where type is either
@@ -148,7 +171,11 @@ class CryptoBackend(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_csr_identifiers(self, csr_filename=None, csr_content=None): def get_csr_identifiers(
self,
csr_filename: str | os.PathLike | None = None,
csr_content: str | bytes | None = None,
) -> set[tuple[str, str]]:
""" """
Return a set of requested identifiers (CN and SANs) for the CSR. Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either Each identifier is a pair (type, identifier), where type is either
@@ -156,7 +183,12 @@ class CryptoBackend(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_cert_days(self, cert_filename=None, cert_content=None, now=None): def get_cert_days(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
now: datetime.datetime | None = None,
) -> int:
""" """
Return the days the certificate in cert_filename remains valid and -1 Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one if the file was not found. If cert_filename contains more than one
@@ -166,13 +198,17 @@ class CryptoBackend(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def create_chain_matcher(self, criterium): def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
""" """
Given a Criterium object, creates a ChainMatcher object. Given a Criterium object, creates a ChainMatcher object.
""" """
@abc.abstractmethod @abc.abstractmethod
def get_cert_information(self, cert_filename=None, cert_content=None): def get_cert_information(
self,
cert_filename: str | os.PathLike | None = None,
cert_content: str | bytes | None = None,
) -> CertificateInformation:
""" """
Return some information on a X.509 certificate as a CertificateInformation object. Return some information on a X.509 certificate as a CertificateInformation object.
""" """

View File

@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount, ACMEAccount,
@@ -30,6 +31,14 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from .backends import CryptoBackend
from .certificates import ChainMatcher
from .challenges import Challenge
class ACMECertificateClient: class ACMECertificateClient:
""" """
ACME v2 client class. Uses an ACME account object and a CSR to ACME v2 client class. Uses an ACME account object and a CSR to
@@ -37,7 +46,13 @@ class ACMECertificateClient:
certificates. certificates.
""" """
def __init__(self, module, backend, client=None, account=None): def __init__(
self,
module: AnsibleModule,
backend: CryptoBackend,
client: ACMEClient | None = None,
account: ACMEAccount | None = None,
) -> None:
self.module = module self.module = module
self.version = module.params["acme_version"] self.version = module.params["acme_version"]
self.csr = module.params.get("csr") self.csr = module.params.get("csr")
@@ -66,13 +81,17 @@ class ACMECertificateClient:
# Extract list of identifiers from CSR # Extract list of identifiers from CSR
if self.csr is not None or self.csr_content is not None: if self.csr is not None or self.csr_content is not None:
self.identifiers = self.client.backend.get_ordered_csr_identifiers( self.identifiers: list[tuple[str, str]] | None = (
csr_filename=self.csr, csr_content=self.csr_content self.client.backend.get_ordered_csr_identifiers(
csr_filename=self.csr, csr_content=self.csr_content
)
) )
else: else:
self.identifiers = None self.identifiers = None
def parse_select_chain(self, select_chain): def parse_select_chain(
self, select_chain: list[dict[str, t.Any]] | None
) -> list[ChainMatcher]:
select_chain_matcher = [] select_chain_matcher = []
if select_chain: if select_chain:
for criterium_idx, criterium in enumerate(select_chain): for criterium_idx, criterium in enumerate(select_chain):
@@ -88,14 +107,16 @@ class ACMECertificateClient:
) )
return select_chain_matcher return select_chain_matcher
def load_order(self): def load_order(self) -> Order:
if not self.order_uri: if not self.order_uri:
raise ModuleFailException("The order URI has not been provided") raise ModuleFailException("The order URI has not been provided")
order = Order.from_url(self.client, self.order_uri) order = Order.from_url(self.client, self.order_uri)
order.load_authorizations(self.client) order.load_authorizations(self.client)
return order return order
def create_order(self, replaces_cert_id=None, profile=None): def create_order(
self, replaces_cert_id: str | None = None, profile: str | None = None
) -> Order:
""" """
Create a new order. Create a new order.
""" """
@@ -114,31 +135,31 @@ class ACMECertificateClient:
order.load_authorizations(self.client) order.load_authorizations(self.client)
return order return order
def get_challenges_data(self, order): def get_challenges_data(
self, order: Order
) -> tuple[list[dict[str, t.Any]], dict[str, list[str]]]:
""" """
Get challenge details. Get challenge details.
Return a tuple of generic challenge details, and specialized DNS challenge details. Return a tuple of generic challenge details, and specialized DNS challenge details.
""" """
# Get general challenge data data: list[dict[str, t.Any]] = []
data = [] data_dns: dict[str, list[str]] = {}
dns_challenge_type = "dns-01"
for authz in order.authorizations.values(): for authz in order.authorizations.values():
# Skip valid authentications: their challenges are already valid # Skip valid authentications: their challenges are already valid
# and do not need to be returned # and do not need to be returned
if authz.status == "valid": if authz.status == "valid":
continue continue
challenge_data = authz.get_challenge_data(self.client)
data.append( data.append(
dict( dict(
identifier=authz.identifier, identifier=authz.identifier,
identifier_type=authz.identifier_type, identifier_type=authz.identifier_type,
challenges=authz.get_challenge_data(self.client), challenges=challenge_data,
) )
) )
# Get DNS challenge data dns_challenge = challenge_data.get(dns_challenge_type)
data_dns = {}
dns_challenge_type = "dns-01"
for entry in data:
dns_challenge = entry["challenges"].get(dns_challenge_type)
if dns_challenge: if dns_challenge:
values = data_dns.get(dns_challenge["record"]) values = data_dns.get(dns_challenge["record"])
if values is None: if values is None:
@@ -147,7 +168,7 @@ class ACMECertificateClient:
values.append(dns_challenge["resource_value"]) values.append(dns_challenge["resource_value"])
return data, data_dns return data, data_dns
def check_that_authorizations_can_be_used(self, order): def check_that_authorizations_can_be_used(self, order: Order) -> None:
bad_authzs = [] bad_authzs = []
for authz in order.authorizations.values(): for authz in order.authorizations.values():
if authz.status not in ("valid", "pending"): if authz.status not in ("valid", "pending"):
@@ -155,27 +176,32 @@ class ACMECertificateClient:
f"{authz.combined_identifier} (status={authz.status!r})" f"{authz.combined_identifier} (status={authz.status!r})"
) )
if bad_authzs: if bad_authzs:
bad_authzs = ", ".join(sorted(bad_authzs)) bad_authzs_str = ", ".join(sorted(bad_authzs))
raise ModuleFailException( raise ModuleFailException(
"Some of the authorizations for the order are in a bad state, so the order" "Some of the authorizations for the order are in a bad state, so the order"
f" can no longer be satisfied: {bad_authzs}", f" can no longer be satisfied: {bad_authzs_str}",
) )
def collect_invalid_authzs(self, order): def collect_invalid_authzs(self, order: Order) -> list[Authorization]:
return [ return [
authz authz
for authz in order.authorizations.values() for authz in order.authorizations.values()
if authz.status == "invalid" if authz.status == "invalid"
] ]
def collect_pending_authzs(self, order): def collect_pending_authzs(self, order: Order) -> list[Authorization]:
return [ return [
authz authz
for authz in order.authorizations.values() for authz in order.authorizations.values()
if authz.status == "pending" if authz.status == "pending"
] ]
def call_validate(self, pending_authzs, get_challenge, wait=True): def call_validate(
self,
pending_authzs: list[Authorization],
get_challenge: t.Callable[[Authorization], str],
wait: bool = True,
) -> list[tuple[Authorization, str, Challenge | None]]:
authzs_with_challenges_to_wait_for = [] authzs_with_challenges_to_wait_for = []
for authz in pending_authzs: for authz in pending_authzs:
challenge_type = get_challenge(authz) challenge_type = get_challenge(authz)
@@ -185,10 +211,12 @@ class ACMECertificateClient:
) )
return authzs_with_challenges_to_wait_for return authzs_with_challenges_to_wait_for
def wait_for_validation(self, authzs_to_wait_for): def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None:
wait_for_validation(authzs_to_wait_for, self.client) wait_for_validation(authzs_to_wait_for, self.client)
def _download_alternate_chains(self, cert): def _download_alternate_chains(
self, cert: CertificateChain
) -> list[CertificateChain]:
alternate_chains = [] alternate_chains = []
for alternate in cert.alternates: for alternate in cert.alternates:
try: try:
@@ -206,13 +234,30 @@ class ACMECertificateClient:
) )
return alternate_chains return alternate_chains
def download_certificate(self, order, download_all_chains=True): @t.overload
def download_certificate(
self, order: Order, *, download_all_chains: t.Literal[True] = True
) -> tuple[CertificateChain, list[CertificateChain]]: ...
@t.overload
def download_certificate(
self, order: Order, *, download_all_chains: t.Literal[False]
) -> tuple[CertificateChain, None]: ...
@t.overload
def download_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
def download_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]:
""" """
Download certificate from a valid oder. Download certificate from a valid oder.
""" """
if order.status != "valid": if order.status != "valid":
raise ModuleFailException( raise ModuleFailException(
f"The order must be valid, but has state {order.state!r}!" f"The order must be valid, but has state {order.status!r}!"
) )
if not order.certificate_uri: if not order.certificate_uri:
@@ -232,7 +277,24 @@ class ACMECertificateClient:
return cert, alternate_chains return cert, alternate_chains
def get_certificate(self, order, download_all_chains=True): @t.overload
def get_certificate(
self, order: Order, *, download_all_chains: t.Literal[True] = True
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
@t.overload
def get_certificate(
self, order: Order, *, download_all_chains: t.Literal[False]
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
@t.overload
def get_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
def get_certificate(
self, order: Order, *, download_all_chains: bool = True
) -> tuple[CertificateChain, list[CertificateChain] | None]:
""" """
Request a new certificate and downloads it, and optionally all certificate chains. Request a new certificate and downloads it, and optionally all certificate chains.
First verifies whether all authorizations are valid; if not, aborts with an error. First verifies whether all authorizations are valid; if not, aborts with an error.
@@ -250,7 +312,11 @@ class ACMECertificateClient:
return self.download_certificate(order, download_all_chains=download_all_chains) return self.download_certificate(order, download_all_chains=download_all_chains)
def find_matching_chain(self, chains, select_chain_matcher): def find_matching_chain(
self,
chains: list[CertificateChain],
select_chain_matcher: t.Iterable[ChainMatcher],
) -> CertificateChain | None:
for criterium_idx, matcher in enumerate(select_chain_matcher): for criterium_idx, matcher in enumerate(select_chain_matcher):
for chain in chains: for chain in chains:
if matcher.match(chain): if matcher.match(chain):
@@ -261,9 +327,15 @@ class ACMECertificateClient:
return None return None
def write_cert_chain( def write_cert_chain(
self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None self,
): cert: CertificateChain,
cert_dest: str | os.PathLike | None = None,
fullchain_dest: str | os.PathLike | None = None,
chain_dest: str | os.PathLike | None = None,
) -> bool:
changed = False changed = False
if cert.cert is None:
raise ValueError("Certificate is not present")
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")): if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
changed = True changed = True
@@ -282,7 +354,7 @@ class ACMECertificateClient:
return changed return changed
def deactivate_authzs(self, order): def deactivate_authzs(self, order: Order) -> None:
""" """
Deactivates all valid authz's. Does not raise exceptions. Deactivates all valid authz's. Does not raise exceptions.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://community.letsencrypt.org/t/authorization-deactivation/19860/2

View File

@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ModuleFailException, ModuleFailException,
@@ -19,20 +20,29 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
) )
if t.TYPE_CHECKING:
from .acme import ACMEClient
_CertificateChain = t.TypeVar("_CertificateChain", bound="CertificateChain")
class CertificateChain: class CertificateChain:
""" """
Download and parse the certificate chain. Download and parse the certificate chain.
https://tools.ietf.org/html/rfc8555#section-7.4.2 https://tools.ietf.org/html/rfc8555#section-7.4.2
""" """
def __init__(self, url): def __init__(self, url: str):
self.url = url self.url = url
self.cert = None self.cert: str | None = None
self.chain = [] self.chain: list[str] = []
self.alternates = [] self.alternates: list[str] = []
@classmethod @classmethod
def download(cls, client, url): def download(
cls: t.Type[_CertificateChain], client: ACMEClient, url: str
) -> _CertificateChain:
content, info = client.get_request( content, info = client.get_request(
url, url,
parse_json_result=False, parse_json_result=False,
@@ -43,7 +53,7 @@ class CertificateChain:
"application/pem-certificate-chain" "application/pem-certificate-chain"
): ):
raise ModuleFailException( raise ModuleFailException(
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content} (headers: {info})" f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content!r} (headers: {info})"
) )
result = cls(url) result = cls(url)
@@ -60,12 +70,12 @@ class CertificateChain:
if result.cert is None: if result.cert is None:
raise ModuleFailException( raise ModuleFailException(
f"Failed to parse certificate chain download from {url}: {content} (headers: {info})" f"Failed to parse certificate chain download from {url}: {content!r} (headers: {info})"
) )
return result return result
def _process_links(self, client, link, relation): def _process_links(self, client: ACMEClient, link: str, relation: str) -> None:
if relation == "up": if relation == "up":
# Process link-up headers if there was no chain in reply # Process link-up headers if there was no chain in reply
if not self.chain: if not self.chain:
@@ -77,7 +87,9 @@ class CertificateChain:
elif relation == "alternate": elif relation == "alternate":
self.alternates.append(link) self.alternates.append(link)
def to_json(self): def to_json(self) -> dict[str, bytes]:
if self.cert is None:
raise ValueError("Has no certificate")
cert = self.cert.encode("utf8") cert = self.cert.encode("utf8")
chain = ("\n".join(self.chain)).encode("utf8") chain = ("\n".join(self.chain)).encode("utf8")
return { return {
@@ -88,18 +100,22 @@ class CertificateChain:
class Criterium: class Criterium:
def __init__(self, criterium, index=None): def __init__(self, criterium: dict[str, t.Any], index: int):
self.index = index self.index = index
self.test_certificates = criterium["test_certificates"] self.test_certificates: t.Literal["first", "last", "all"] = criterium[
self.subject = criterium["subject"] "test_certificates"
self.issuer = criterium["issuer"] ]
self.subject_key_identifier = criterium["subject_key_identifier"] self.subject: dict[str, t.Any] | None = criterium["subject"]
self.authority_key_identifier = criterium["authority_key_identifier"] self.issuer: dict[str, t.Any] | None = criterium["issuer"]
self.subject_key_identifier: str | None = criterium["subject_key_identifier"]
self.authority_key_identifier: str | None = criterium[
"authority_key_identifier"
]
class ChainMatcher(metaclass=abc.ABCMeta): class ChainMatcher(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def match(self, certificate): def match(self, certificate: CertificateChain) -> bool:
""" """
Check whether a certificate chain (CertificateChain instance) matches. Check whether a certificate chain (CertificateChain instance) matches.
""" """

View File

@@ -11,6 +11,7 @@ import ipaddress
import json import json
import re import re
import time import time
import typing as t
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
@@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
) )
def create_key_authorization(client, token): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from .acme import ACMEClient
def create_key_authorization(client: ACMEClient, token: str) -> str:
""" """
Returns the key authorization for the given token Returns the key authorization for the given token
https://tools.ietf.org/html/rfc8555#section-8.1 https://tools.ietf.org/html/rfc8555#section-8.1
@@ -35,41 +42,49 @@ def create_key_authorization(client, token):
return f"{token}.{thumbprint}" return f"{token}.{thumbprint}"
def combine_identifier(identifier_type, identifier): def combine_identifier(identifier_type: str, identifier: str) -> str:
return f"{identifier_type}:{identifier}" return f"{identifier_type}:{identifier}"
def normalize_combined_identifier(identifier): def normalize_combined_identifier(identifier: str) -> str:
identifier_type, identifier = split_identifier(identifier) identifier_type, identifier = split_identifier(identifier)
# Normalize DNS names and IPs # Normalize DNS names and IPs
identifier = identifier.lower() identifier = identifier.lower()
return combine_identifier(identifier_type, identifier) return combine_identifier(identifier_type, identifier)
def split_identifier(identifier): def split_identifier(identifier: str) -> tuple[str, str]:
parts = identifier.split(":", 1) parts = identifier.split(":", 1)
if len(parts) != 2: if len(parts) != 2:
raise ModuleFailException( raise ModuleFailException(
f'Identifier "{identifier}" is not of the form <type>:<identifier>' f'Identifier "{identifier}" is not of the form <type>:<identifier>'
) )
return parts return parts[0], parts[1]
_Challenge = t.TypeVar("_Challenge", bound="Challenge")
class Challenge: class Challenge:
def __init__(self, data, url): def __init__(self, data: dict[str, t.Any], url: str) -> None:
self.data = data self.data = data
self.type = data["type"] self.type: str = data["type"]
self.url = url self.url = url
self.status = data["status"] self.status: str = data["status"]
self.token = data.get("token") self.token: str | None = data.get("token")
@classmethod @classmethod
def from_json(cls, client, data, url=None): def from_json(
cls: t.Type[_Challenge],
client: ACMEClient,
data: dict[str, t.Any],
url: str | None = None,
) -> _Challenge:
return cls(data, url or data["url"]) return cls(data, url or data["url"])
def call_validate(self, client): def call_validate(self, client: ACMEClient) -> None:
challenge_response = {} challenge_response: dict[str, t.Any] = {}
client.send_signed_request( client.send_signed_request(
self.url, self.url,
challenge_response, challenge_response,
@@ -77,10 +92,15 @@ class Challenge:
expected_status_codes=[200, 202], expected_status_codes=[200, 202],
) )
def to_json(self): def to_json(self) -> dict[str, t.Any]:
return self.data.copy() return self.data.copy()
def get_validation_data(self, client, identifier_type, identifier): def get_validation_data(
self, client: ACMEClient, identifier_type: str, identifier: str
) -> dict[str, t.Any] | None:
if self.token is None:
return None
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token) token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
key_authorization = create_key_authorization(client, token) key_authorization = create_key_authorization(client, token)
@@ -113,21 +133,33 @@ class Challenge:
resource += "." resource += "."
else: else:
resource = identifier resource = identifier
value = base64.b64encode( b_value = base64.b64encode(
hashlib.sha256(to_bytes(key_authorization)).digest() hashlib.sha256(to_bytes(key_authorization)).digest()
) )
return { return {
"resource": resource, "resource": resource,
"resource_original": combine_identifier(identifier_type, identifier), "resource_original": combine_identifier(identifier_type, identifier),
"resource_value": value, "resource_value": b_value,
} }
# Unknown challenge type: ignore # Unknown challenge type: ignore
return None return None
_Authorization = t.TypeVar("_Authorization", bound="Authorization")
class Authorization: class Authorization:
def _setup(self, client, data): def __init__(self, url: str) -> None:
self.url = url
self.data: dict[str, t.Any] | None = None
self.challenges: list[Challenge] = []
self.status: str | None = None
self.identifier_type: str | None = None
self.identifier: str | None = None
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
data["uri"] = self.url data["uri"] = self.url
self.data = data self.data = data
# While 'challenges' is a required field, apparently not every CA cares # While 'challenges' is a required field, apparently not every CA cares
@@ -145,29 +177,32 @@ class Authorization:
if data.get("wildcard", False): if data.get("wildcard", False):
self.identifier = f"*.{self.identifier}" self.identifier = f"*.{self.identifier}"
def __init__(self, url):
self.url = url
self.data = None
self.challenges = []
self.status = None
self.identifier_type = None
self.identifier = None
@classmethod @classmethod
def from_json(cls, client, data, url): def from_json(
cls: t.Type[_Authorization],
client: ACMEClient,
data: dict[str, t.Any],
url: str,
) -> _Authorization:
result = cls(url) result = cls(url)
result._setup(client, data) result._setup(client, data)
return result return result
@classmethod @classmethod
def from_url(cls, client, url): def from_url(
cls: t.Type[_Authorization], client: ACMEClient, url: str
) -> _Authorization:
result = cls(url) result = cls(url)
result.refresh(client) result.refresh(client)
return result return result
@classmethod @classmethod
def create(cls, client, identifier_type, identifier): def create(
cls: t.Type[_Authorization],
client: ACMEClient,
identifier_type: str,
identifier: str,
) -> _Authorization:
""" """
Create a new authorization for the given identifier. Create a new authorization for the given identifier.
Return the authorization object of the new authorization Return the authorization object of the new authorization
@@ -194,23 +229,29 @@ class Authorization:
return cls.from_json(client, result, info["location"]) return cls.from_json(client, result, info["location"])
@property @property
def combined_identifier(self): def combined_identifier(self) -> str:
if self.identifier_type is None or self.identifier is None:
raise ValueError("Data not present")
return combine_identifier(self.identifier_type, self.identifier) return combine_identifier(self.identifier_type, self.identifier)
def to_json(self): def to_json(self) -> dict[str, t.Any]:
if self.data is None:
raise ValueError("Data not present")
return self.data.copy() return self.data.copy()
def refresh(self, client): def refresh(self, client: ACMEClient) -> bool:
result, dummy = client.get_request(self.url) result, dummy = client.get_request(self.url)
changed = self.data != result changed = self.data != result
self._setup(client, result) self._setup(client, result)
return changed return changed
def get_challenge_data(self, client): def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]:
""" """
Returns a dict with the data for all proposed (and supported) challenges Returns a dict with the data for all proposed (and supported) challenges
of the given authorization. of the given authorization.
""" """
if self.identifier_type is None or self.identifier is None:
raise ValueError("Data not present")
data = {} data = {}
for challenge in self.challenges: for challenge in self.challenges:
validation_data = challenge.get_validation_data( validation_data = challenge.get_validation_data(
@@ -220,7 +261,7 @@ class Authorization:
data[challenge.type] = validation_data data[challenge.type] = validation_data
return data return data
def raise_error(self, error_msg, module=None): def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn:
""" """
Aborts with a specific error for a challenge. Aborts with a specific error for a challenge.
""" """
@@ -246,13 +287,13 @@ class Authorization:
), ),
) )
def find_challenge(self, challenge_type): def find_challenge(self, challenge_type: str) -> Challenge | None:
for challenge in self.challenges: for challenge in self.challenges:
if challenge_type == challenge.type: if challenge_type == challenge.type:
return challenge return challenge
return None return None
def wait_for_validation(self, client, callenge_type): def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool:
while True: while True:
self.refresh(client) self.refresh(client)
if self.status in ["valid", "invalid", "revoked"]: if self.status in ["valid", "invalid", "revoked"]:
@@ -264,7 +305,9 @@ class Authorization:
return self.status == "valid" return self.status == "valid"
def call_validate(self, client, challenge_type, wait=True): def call_validate(
self, client: ACMEClient, challenge_type: str, wait: bool = True
) -> bool:
""" """
Validate the authorization provided in the auth dict. Returns True Validate the authorization provided in the auth dict. Returns True
when the validation was successful and False when it was not. when the validation was successful and False when it was not.
@@ -281,7 +324,7 @@ class Authorization:
return self.status == "valid" return self.status == "valid"
return self.wait_for_validation(client, challenge_type) return self.wait_for_validation(client, challenge_type)
def can_deactivate(self): def can_deactivate(self) -> bool:
""" """
Deactivates this authorization. Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://community.letsencrypt.org/t/authorization-deactivation/19860/2
@@ -289,14 +332,14 @@ class Authorization:
""" """
return self.status in ("valid", "pending") return self.status in ("valid", "pending")
def deactivate(self, client): def deactivate(self, client: ACMEClient) -> bool | None:
""" """
Deactivates this authorization. Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2 https://tools.ietf.org/html/rfc8555#section-7.5.2
""" """
if not self.can_deactivate(): if not self.can_deactivate():
return return None
authz_deactivate = {"status": "deactivated"} authz_deactivate = {"status": "deactivated"}
result, info = client.send_signed_request( result, info = client.send_signed_request(
self.url, authz_deactivate, fail_on_error=False self.url, authz_deactivate, fail_on_error=False
@@ -307,7 +350,9 @@ class Authorization:
return False return False
@classmethod @classmethod
def deactivate_url(cls, client, url): def deactivate_url(
cls: t.Type[_Authorization], client: ACMEClient, url: str
) -> _Authorization:
""" """
Deactivates this authorization. Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://community.letsencrypt.org/t/authorization-deactivation/19860/2
@@ -322,7 +367,7 @@ class Authorization:
return authz return authz
def wait_for_validation(authzs, client): def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None:
""" """
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked. Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
""" """

View File

@@ -5,19 +5,24 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from http.client import responses as http_responses from http.client import responses as http_responses
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
def format_http_status(status_code): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def format_http_status(status_code: int) -> str:
expl = http_responses.get(status_code) expl = http_responses.get(status_code)
if not expl: if not expl:
return str(status_code) return str(status_code)
return f"{status_code} {expl}" return f"{status_code} {expl}"
def format_error_problem(problem, subproblem_prefix=""): def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str:
error_type = problem.get( error_type = problem.get(
"type", "about:blank" "type", "about:blank"
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1 ) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
@@ -32,8 +37,10 @@ def format_error_problem(problem, subproblem_prefix=""):
msg = f"{msg} Subproblems:" msg = f"{msg} Subproblems:"
for index, problem in enumerate(subproblems): for index, problem in enumerate(subproblems):
index_str = f"{subproblem_prefix}{index}" index_str = f"{subproblem_prefix}{index}"
problem = format_error_problem(problem, subproblem_prefix=f"{index_str}.") problem_str = format_error_problem(
msg = f"{msg}\n({index_str}) {problem}" problem, subproblem_prefix=f"{index_str}."
)
msg = f"{msg}\n({index_str}) {problem_str}"
return msg return msg
@@ -42,25 +49,25 @@ class ModuleFailException(Exception):
If raised, module.fail_json() will be called with the given parameters after cleanup. If raised, module.fail_json() will be called with the given parameters after cleanup.
""" """
def __init__(self, msg, **args): def __init__(self, msg: str, **args: t.Any) -> None:
super(ModuleFailException, self).__init__(self, msg) super(ModuleFailException, self).__init__(self, msg)
self.msg = msg self.msg = msg
self.module_fail_args = args self.module_fail_args = args
def do_fail(self, module, **arguments): def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn:
module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments) module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments)
class ACMEProtocolException(ModuleFailException): class ACMEProtocolException(ModuleFailException):
def __init__( def __init__(
self, self,
module, module: AnsibleModule,
msg=None, msg: str | None = None,
info=None, info: dict[str, t.Any] | None = None,
response=None, response=None,
content=None, content: bytes | None = None,
content_json=None, content_json: dict[str, t.Any] | None = None,
extras=None, extras: dict[str, t.Any] | None = None,
): ):
# Try to get hold of content, if response is given and content is not provided # Try to get hold of content, if response is given and content is not provided
if content is None and content_json is None and response is not None: if content is None and content_json is None and response is not None:
@@ -71,7 +78,8 @@ class ACMEProtocolException(ModuleFailException):
raise TypeError raise TypeError
content = response.read() content = response.read()
except (AttributeError, TypeError): except (AttributeError, TypeError):
content = info.pop("body", None) if info is not None:
content = info.pop("body", None)
# Make sure that content_json is None or a dictionary # Make sure that content_json is None or a dictionary
if content_json is not None and not isinstance(content_json, dict): if content_json is not None and not isinstance(content_json, dict):
@@ -139,8 +147,8 @@ class ACMEProtocolException(ModuleFailException):
add_msg = f" The raw result: {to_text(content)}" add_msg = f" The raw result: {to_text(content)}"
super(ACMEProtocolException, self).__init__(f"{msg}.{add_msg}", **extras) super(ACMEProtocolException, self).__init__(f"{msg}.{add_msg}", **extras)
self.problem = {} self.problem: dict[str, t.Any] = {}
self.subproblems = [] self.subproblems: list[dict[str, t.Any]] = []
self.error_code = error_code self.error_code = error_code
self.error_type = error_type self.error_type = error_type
for k, v in extras.items(): for k, v in extras.items():

View File

@@ -10,22 +10,27 @@ import os
import shutil import shutil
import tempfile import tempfile
import traceback import traceback
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ModuleFailException, ModuleFailException,
) )
def read_file(fn, mode="b"): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def read_file(fn: str | os.PathLike) -> bytes:
try: try:
with open(fn, "r" + mode) as f: with open(fn, "rb") as f:
return f.read() return f.read()
except Exception as e: except Exception as e:
raise ModuleFailException(f'Error while reading file "{fn}": {e}') raise ModuleFailException(f'Error while reading file "{fn}": {e}')
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py # This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
def write_file(module, dest, content): def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool:
""" """
Write content to destination file dest, only if the content Write content to destination file dest, only if the content
has changed. has changed.

View File

@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import time import time
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import ( from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization, Authorization,
@@ -13,14 +14,35 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.challenges i
) )
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
ACMEProtocolException, ACMEProtocolException,
ModuleFailException,
) )
from ansible_collections.community.crypto.plugins.module_utils.acme.utils import ( from ansible_collections.community.crypto.plugins.module_utils.acme.utils import (
nopad_b64, nopad_b64,
) )
if t.TYPE_CHECKING:
from .acme import ACMEClient
_Order = t.TypeVar("_Order", bound="Order")
class Order: class Order:
def _setup(self, client, data): def __init__(self, url: str) -> None:
self.url = url
self.data: dict[str, t.Any] | None = None
self.status = None
self.identifiers: list[tuple[str, str]] = []
self.replaces_cert_id = None
self.finalize_uri = None
self.certificate_uri = None
self.authorization_uris: list[str] = []
self.authorizations: dict[str, Authorization] = {}
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
self.data = data self.data = data
self.status = data["status"] self.status = data["status"]
@@ -33,33 +55,28 @@ class Order:
self.authorization_uris = data["authorizations"] self.authorization_uris = data["authorizations"]
self.authorizations = {} self.authorizations = {}
def __init__(self, url):
self.url = url
self.data = None
self.status = None
self.identifiers = []
self.replaces_cert_id = None
self.finalize_uri = None
self.certificate_uri = None
self.authorization_uris = []
self.authorizations = {}
@classmethod @classmethod
def from_json(cls, client, data, url): def from_json(
cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str
) -> _Order:
result = cls(url) result = cls(url)
result._setup(client, data) result._setup(client, data)
return result return result
@classmethod @classmethod
def from_url(cls, client, url): def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order:
result = cls(url) result = cls(url)
result.refresh(client) result.refresh(client)
return result return result
@classmethod @classmethod
def create(cls, client, identifiers, replaces_cert_id=None, profile=None): def create(
cls: t.Type[_Order],
client: ACMEClient,
identifiers: list[tuple[str, str]],
replaces_cert_id: str | None = None,
profile: str | None = None,
) -> _Order:
""" """
Start a new certificate order (ACME v2 protocol). Start a new certificate order (ACME v2 protocol).
https://tools.ietf.org/html/rfc8555#section-7.4 https://tools.ietf.org/html/rfc8555#section-7.4
@@ -72,7 +89,7 @@ class Order:
"value": identifier, "value": identifier,
} }
) )
new_order = {"identifiers": acme_identifiers} new_order: dict[str, t.Any] = {"identifiers": acme_identifiers}
if replaces_cert_id is not None: if replaces_cert_id is not None:
new_order["replaces"] = replaces_cert_id new_order["replaces"] = replaces_cert_id
if profile is not None: if profile is not None:
@@ -87,15 +104,17 @@ class Order:
@classmethod @classmethod
def create_with_error_handling( def create_with_error_handling(
cls, cls: t.Type[_Order],
client, client: ACMEClient,
identifiers, identifiers: list[tuple[str, str]],
error_strategy="auto", error_strategy: t.Literal[
error_max_retries=3, "auto", "fail", "always", "retry_without_replaces_cert_id"
replaces_cert_id=None, ] = "auto",
profile=None, error_max_retries: int = 3,
message_callback=None, replaces_cert_id: str | None = None,
): profile: str | None = None,
message_callback: t.Callable[[str], None] | None = None,
) -> _Order:
""" """
error_strategy can be one of the following strings: error_strategy can be one of the following strings:
@@ -140,20 +159,20 @@ class Order:
raise raise
def refresh(self, client): def refresh(self, client: ACMEClient) -> bool:
result, dummy = client.get_request(self.url) result, dummy = client.get_request(self.url)
changed = self.data != result changed = self.data != result
self._setup(client, result) self._setup(client, result)
return changed return changed
def load_authorizations(self, client): def load_authorizations(self, client: ACMEClient) -> None:
for auth_uri in self.authorization_uris: for auth_uri in self.authorization_uris:
authz = Authorization.from_url(client, auth_uri) authz = Authorization.from_url(client, auth_uri)
self.authorizations[ self.authorizations[
normalize_combined_identifier(authz.combined_identifier) normalize_combined_identifier(authz.combined_identifier)
] = authz ] = authz
def wait_for_finalization(self, client): def wait_for_finalization(self, client: ACMEClient) -> None:
while True: while True:
self.refresh(client) self.refresh(client)
if self.status in ["valid", "invalid", "pending", "ready"]: if self.status in ["valid", "invalid", "pending", "ready"]:
@@ -167,12 +186,14 @@ class Order:
content_json=self.data, content_json=self.data,
) )
def finalize(self, client, csr_der, wait=True): def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None:
""" """
Create a new certificate based on the csr. Create a new certificate based on the csr.
Return the certificate object as dict Return the certificate object as dict
https://tools.ietf.org/html/rfc8555#section-7.4 https://tools.ietf.org/html/rfc8555#section-7.4
""" """
if self.finalize_uri is None:
raise ModuleFailException("finalize_uri must be set")
new_cert = { new_cert = {
"csr": nopad_b64(csr_der), "csr": nopad_b64(csr_der),
} }

View File

@@ -7,9 +7,11 @@ from __future__ import annotations
import base64 import base64
import datetime import datetime
import os
import re import re
import textwrap import textwrap
import traceback import traceback
import typing as t
from urllib.parse import unquote from urllib.parse import unquote
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import ( from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
@@ -23,11 +25,15 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
def nopad_b64(data): if t.TYPE_CHECKING:
from .backends import CertificateInformation, CryptoBackend
def nopad_b64(data: bytes) -> str:
return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "") return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "")
def der_to_pem(der_cert): def der_to_pem(der_cert: bytes) -> str:
""" """
Convert the DER format certificate in der_cert to a PEM format certificate and return it. Convert the DER format certificate in der_cert to a PEM format certificate and return it.
""" """
@@ -35,7 +41,9 @@ def der_to_pem(der_cert):
return f"-----BEGIN CERTIFICATE-----\n{content}\n-----END CERTIFICATE-----\n" return f"-----BEGIN CERTIFICATE-----\n{content}\n-----END CERTIFICATE-----\n"
def pem_to_der(pem_filename=None, pem_content=None): def pem_to_der(
pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
) -> bytes:
""" """
Load PEM file, or use PEM file's content, and convert to DER. Load PEM file, or use PEM file's content, and convert to DER.
@@ -70,7 +78,9 @@ def pem_to_der(pem_filename=None, pem_content=None):
return base64.b64decode("".join(certificate_lines)) return base64.b64decode("".join(certificate_lines))
def process_links(info, callback): def process_links(
info: dict[str, t.Any], callback: t.Callable[[str, str], None]
) -> None:
""" """
Process link header, calls callback for every link header with the URL and relation as options. Process link header, calls callback for every link header with the URL and relation as options.
@@ -82,7 +92,11 @@ def process_links(info, callback):
callback(unquote(url), relation) callback(unquote(url), relation)
def parse_retry_after(value, relative_with_timezone=True, now=None): def parse_retry_after(
value: str,
relative_with_timezone: bool = True,
now: datetime.datetime | None = None,
) -> datetime.datetime:
""" """
Parse the value of a Retry-After header and return a timestamp. Parse the value of a Retry-After header and return a timestamp.
@@ -106,12 +120,12 @@ def parse_retry_after(value, relative_with_timezone=True, now=None):
def compute_cert_id( def compute_cert_id(
backend, backend: CryptoBackend,
cert_info=None, cert_info: CertificateInformation | None = None,
cert_filename=None, cert_filename: str | os.PathLike | None = None,
cert_content=None, cert_content: str | bytes | None = None,
none_if_required_information_is_missing=False, none_if_required_information_is_missing: bool = False,
): ) -> str | None:
# Obtain certificate info if not provided # Obtain certificate info if not provided
if cert_info is None: if cert_info is None:
cert_info = backend.get_cert_information( cert_info = backend.get_cert_information(

View File

@@ -4,10 +4,15 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
def _ensure_list(value): _T = t.TypeVar("_T")
def _ensure_list(value: list[_T] | tuple[_T] | None) -> list[_T]:
if value is None: if value is None:
return [] return []
return list(value) return list(value)
@@ -16,13 +21,19 @@ def _ensure_list(value):
class ArgumentSpec: class ArgumentSpec:
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
mutually_exclusive=None, mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together=None, required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of=None, required_one_of: list[list[str] | tuple[str, ...]] | None = None,
required_if=None, required_if: (
required_by=None, list[
): tuple[str, t.Any, list[str] | tuple[str, ...]]
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
]
| None
) = None,
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
) -> None:
self.argument_spec = argument_spec or {} self.argument_spec = argument_spec or {}
self.mutually_exclusive = _ensure_list(mutually_exclusive) self.mutually_exclusive = _ensure_list(mutually_exclusive)
self.required_together = _ensure_list(required_together) self.required_together = _ensure_list(required_together)
@@ -30,17 +41,23 @@ class ArgumentSpec:
self.required_if = _ensure_list(required_if) self.required_if = _ensure_list(required_if)
self.required_by = required_by or {} self.required_by = required_by or {}
def update_argspec(self, **kwargs): def update_argspec(self, **kwargs) -> t.Self:
self.argument_spec.update(kwargs) self.argument_spec.update(kwargs)
return self return self
def update( def update(
self, self,
mutually_exclusive=None, mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
required_together=None, required_together: list[list[str] | tuple[str, ...]] | None = None,
required_one_of=None, required_one_of: list[list[str] | tuple[str, ...]] | None = None,
required_if=None, required_if: (
required_by=None, list[
tuple[str, t.Any, list[str] | tuple[str, ...]]
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
]
| None
) = None,
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
): ):
if mutually_exclusive: if mutually_exclusive:
self.mutually_exclusive.extend(mutually_exclusive) self.mutually_exclusive.extend(mutually_exclusive)
@@ -57,7 +74,7 @@ class ArgumentSpec:
self.required_by[k] = v self.required_by[k] = v
return self return self
def merge(self, other): def merge(self, other: t.Self) -> t.Self:
self.update_argspec(**other.argument_spec) self.update_argspec(**other.argument_spec)
self.update( self.update(
mutually_exclusive=other.mutually_exclusive, mutually_exclusive=other.mutually_exclusive,
@@ -68,8 +85,22 @@ class ArgumentSpec:
) )
return self return self
def create_ansible_module_helper(self, clazz, args, **kwargs): def create_ansible_module_helper(
return clazz( self, clazz: type[_T], args: tuple, **kwargs: t.Any
) -> _T:
for forbidden_name in (
"argument_spec",
"mutually_exclusive",
"required_together",
"required_one_of",
"required_if",
"required_by",
):
if forbidden_name in kwargs:
raise ValueError(
f"You must not provide a {forbidden_name} keyword parameter to create_ansible_module_helper()"
)
instance = clazz( # type: ignore
*args, *args,
argument_spec=self.argument_spec, argument_spec=self.argument_spec,
mutually_exclusive=self.mutually_exclusive, mutually_exclusive=self.mutually_exclusive,
@@ -79,8 +110,9 @@ class ArgumentSpec:
required_by=self.required_by, required_by=self.required_by,
**kwargs, **kwargs,
) )
return instance
def create_ansible_module(self, **kwargs): def create_ansible_module(self, **kwargs: t.Any) -> AnsibleModule:
return self.create_ansible_module_helper(AnsibleModule, (), **kwargs) return self.create_ansible_module_helper(AnsibleModule, (), **kwargs)

View File

@@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import enum
import re import re
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
@@ -32,7 +33,7 @@ ASN1_STRING_REGEX = re.compile(
) )
class TagClass: class TagClass(enum.Enum):
universal = 0 universal = 0
application = 1 application = 1
context_specific = 2 context_specific = 2
@@ -40,11 +41,11 @@ class TagClass:
# Universal tag numbers that can be encoded. # Universal tag numbers that can be encoded.
class TagNumber: class TagNumber(enum.Enum):
utf8_string = 12 utf8_string = 12
def _pack_octet_integer(value): def _pack_octet_integer(value: int) -> bytes:
"""Packs an integer value into 1 or multiple octets.""" """Packs an integer value into 1 or multiple octets."""
# NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value. # NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value.
octets = bytearray() octets = bytearray()
@@ -66,7 +67,7 @@ def _pack_octet_integer(value):
return bytes(octets) return bytes(octets)
def serialize_asn1_string_as_der(value): def serialize_asn1_string_as_der(value: str) -> bytes:
"""Deserializes an ASN.1 string to a DER encoded byte string.""" """Deserializes an ASN.1 string to a DER encoded byte string."""
asn1_match = ASN1_STRING_REGEX.match(value) asn1_match = ASN1_STRING_REGEX.match(value)
if not asn1_match: if not asn1_match:
@@ -92,7 +93,7 @@ def serialize_asn1_string_as_der(value):
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value) b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
if tag_type: if tag_type:
tag_class = { tag_class_enum = {
"U": TagClass.universal, "U": TagClass.universal,
"A": TagClass.application, "A": TagClass.application,
"P": TagClass.private, "P": TagClass.private,
@@ -100,13 +101,15 @@ def serialize_asn1_string_as_der(value):
}[tag_class] }[tag_class]
# When adding support for more types this should be looked into further. For now it works with UTF8Strings. # When adding support for more types this should be looked into further. For now it works with UTF8Strings.
constructed = tag_type == "EXPLICIT" and tag_class != TagClass.universal constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
b_value = pack_asn1(tag_class, constructed, int(tag_number), b_value) b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
return b_value return b_value
def pack_asn1(tag_class, constructed, tag_number, b_data): def pack_asn1(
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
) -> bytes:
"""Pack the value into an ASN.1 data structure. """Pack the value into an ASN.1 data structure.
The structure for an ASN.1 element is The structure for an ASN.1 element is
@@ -115,16 +118,15 @@ def pack_asn1(tag_class, constructed, tag_number, b_data):
""" """
b_asn1_data = bytearray() b_asn1_data = bytearray()
if tag_class < 0 or tag_class > 3:
raise ValueError(f"tag_class must be between 0 and 3 not {tag_class}")
# Bit 8 and 7 denotes the class. # Bit 8 and 7 denotes the class.
identifier_octets = tag_class << 6 identifier_octets = tag_class.value << 6
# Bit 6 denotes whether the value is primitive or constructed. # Bit 6 denotes whether the value is primitive or constructed.
identifier_octets |= (1 if constructed else 0) << 5 identifier_octets |= (1 if constructed else 0) << 5
# Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits # Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits
# then they are set and another octet(s) is used to denote the tag number. # then they are set and another octet(s) is used to denote the tag number.
if isinstance(tag_number, TagNumber):
tag_number = tag_number.value
if tag_number < 31: if tag_number < 31:
identifier_octets |= tag_number identifier_octets |= tag_number
b_asn1_data.append(identifier_octets) b_asn1_data.append(identifier_octets)

View File

@@ -34,7 +34,7 @@ from __future__ import annotations
# cryptography versions! # cryptography versions!
def obj2txt(openssl_lib, openssl_ffi, obj): def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
# Set to 80 on the recommendation of # Set to 80 on the recommendation of
# https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values # https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values
# #

View File

@@ -7,9 +7,9 @@ from __future__ import annotations
from ._objects_data import OID_MAP from ._objects_data import OID_MAP
OID_LOOKUP = dict() OID_LOOKUP: dict[str, str] = dict()
NORMALIZE_NAMES = dict() NORMALIZE_NAMES: dict[str, str] = dict()
NORMALIZE_NAMES_SHORT = dict() NORMALIZE_NAMES_SHORT: dict[str, str] = dict()
for dotted, names in OID_MAP.items(): for dotted, names in OID_MAP.items():
for name in names: for name in names:

View File

@@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.version import ( from ansible_collections.community.crypto.plugins.module_utils.version import (
LooseVersion as _LooseVersion, LooseVersion as _LooseVersion,
) )
@@ -21,6 +23,10 @@ from .basic import HAS_CRYPTOGRAPHY
from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_name from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_name
if t.TYPE_CHECKING:
import datetime
# TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this # TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this
# to True and adjust get_invalidity_date() accordingly. # to True and adjust get_invalidity_date() accordingly.
# (https://github.com/pyca/cryptography/issues/10818) # (https://github.com/pyca/cryptography/issues/10818)
@@ -55,7 +61,9 @@ else:
REVOCATION_REASON_MAP_INVERSE = dict() REVOCATION_REASON_MAP_INVERSE = dict()
def cryptography_decode_revoked_certificate(cert): def cryptography_decode_revoked_certificate(
cert: x509.RevokedCertificate,
) -> dict[str, t.Any]:
result = { result = {
"serial_number": cert.serial_number, "serial_number": cert.serial_number,
"revocation_date": get_revocation_date(cert), "revocation_date": get_revocation_date(cert),
@@ -67,27 +75,30 @@ def cryptography_decode_revoked_certificate(cert):
"invalidity_date_critical": False, "invalidity_date_critical": False,
} }
try: try:
ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer) ext_ci = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
result["issuer"] = list(ext.value) result["issuer"] = list(ext_ci.value)
result["issuer_critical"] = ext.critical result["issuer_critical"] = ext_ci.critical
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
pass pass
try: try:
ext = cert.extensions.get_extension_for_class(x509.CRLReason) ext_cr = cert.extensions.get_extension_for_class(x509.CRLReason)
result["reason"] = ext.value.reason result["reason"] = ext_cr.value.reason
result["reason_critical"] = ext.critical result["reason_critical"] = ext_cr.critical
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
pass pass
try: try:
ext = cert.extensions.get_extension_for_class(x509.InvalidityDate) ext_id = cert.extensions.get_extension_for_class(x509.InvalidityDate)
result["invalidity_date"] = get_invalidity_date(ext.value) result["invalidity_date"] = get_invalidity_date(ext_id.value)
result["invalidity_date_critical"] = ext.critical result["invalidity_date_critical"] = ext_id.critical
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
pass pass
return result return result
def cryptography_dump_revoked(entry, idn_rewrite="ignore"): def cryptography_dump_revoked(
entry: dict[str, t.Any],
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> dict[str, t.Any]:
return { return {
"serial_number": entry["serial_number"], "serial_number": entry["serial_number"],
"revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT), "revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT),
@@ -115,48 +126,56 @@ def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
} }
def cryptography_get_signature_algorithm_oid_from_crl(crl): def cryptography_get_signature_algorithm_oid_from_crl(
crl: x509.CertificateRevocationList,
) -> x509.oid.ObjectIdentifier:
try: try:
return crl.signature_algorithm_oid return crl.signature_algorithm_oid
except AttributeError: except AttributeError:
# Older cryptography versions do not have signature_algorithm_oid yet # Older cryptography versions do not have signature_algorithm_oid yet
dotted = obj2txt( dotted = obj2txt(
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm # type: ignore
) )
return x509.oid.ObjectIdentifier(dotted) return x509.oid.ObjectIdentifier(dotted)
def get_next_update(obj): def get_next_update(obj: x509.CertificateRevocationList) -> datetime.datetime | None:
if CRYPTOGRAPHY_TIMEZONE: if CRYPTOGRAPHY_TIMEZONE:
return obj.next_update_utc return obj.next_update_utc
return obj.next_update return obj.next_update
def get_last_update(obj): def get_last_update(obj: x509.CertificateRevocationList) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE: if CRYPTOGRAPHY_TIMEZONE:
return obj.last_update_utc return obj.last_update_utc
return obj.last_update return obj.last_update
def get_revocation_date(obj): def get_revocation_date(obj: x509.RevokedCertificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE: if CRYPTOGRAPHY_TIMEZONE:
return obj.revocation_date_utc return obj.revocation_date_utc
return obj.revocation_date return obj.revocation_date
def get_invalidity_date(obj): def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE: if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE:
return obj.invalidity_date_utc return obj.invalidity_date_utc
return obj.invalidity_date return obj.invalidity_date
def set_next_update(builder, value): def set_next_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.next_update(value) return builder.next_update(value)
def set_last_update(builder, value): def set_last_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.last_update(value) return builder.last_update(value)
def set_revocation_date(builder, value): def set_revocation_date(
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
) -> x509.RevokedCertificateBuilder:
return builder.revocation_date(value) return builder.revocation_date(value)

View File

@@ -9,6 +9,7 @@ import binascii
import ipaddress import ipaddress
import re import re
import traceback import traceback
import typing as t
from urllib.parse import ( from urllib.parse import (
ParseResult, ParseResult,
urlparse, urlparse,
@@ -40,6 +41,7 @@ except ImportError:
pass pass
try: try:
import cryptography.hazmat.primitives.asymmetric.dh
import cryptography.hazmat.primitives.asymmetric.ed448 import cryptography.hazmat.primitives.asymmetric.ed448
import cryptography.hazmat.primitives.asymmetric.ed25519 import cryptography.hazmat.primitives.asymmetric.ed25519
import cryptography.hazmat.primitives.asymmetric.rsa import cryptography.hazmat.primitives.asymmetric.rsa
@@ -55,7 +57,7 @@ try:
) )
except ImportError: except ImportError:
# Error handled in the calling module. # Error handled in the calling module.
_load_pkcs12 = None _load_pkcs12 = None # type: ignore
try: try:
import idna import idna
@@ -74,6 +76,50 @@ from .basic import (
) )
if t.TYPE_CHECKING:
import datetime
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateKey, DHPublicKey
from cryptography.hazmat.primitives.asymmetric.dsa import (
DSAPrivateKey,
DSAPublicKey,
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey,
EllipticCurvePublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPublicKey,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
CertificateIssuerPublicKeyTypes,
CertificatePublicKeyTypes,
PrivateKeyTypes,
PublicKeyTypes,
)
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PKCS12KeyAndCertificates,
)
CertificatePrivateKeyTypes = (
CertificateIssuerPrivateKeyTypes
| cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey
| cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
)
PublicKeyTypesWOEdwards = (
DHPublicKey | DSAPublicKey | EllipticCurvePublicKey | RSAPublicKey
)
PrivateKeyTypesWOEdwards = (
DHPrivateKey | DSAPrivateKey | EllipticCurvePrivateKey | RSAPrivateKey
)
else:
PublicKeyTypesWOEdwards = None
PrivateKeyTypesWOEdwards = None
CRYPTOGRAPHY_TIMEZONE = False CRYPTOGRAPHY_TIMEZONE = False
_CRYPTOGRAPHY_36_0_OR_NEWER = False _CRYPTOGRAPHY_36_0_OR_NEWER = False
if _HAS_CRYPTOGRAPHY: if _HAS_CRYPTOGRAPHY:
@@ -88,7 +134,9 @@ if _HAS_CRYPTOGRAPHY:
DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$") DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$")
def cryptography_get_extensions_from_cert(cert): def cryptography_get_extensions_from_cert(
cert: x509.Certificate,
) -> dict[str, dict[str, bool | str]]:
result = dict() result = dict()
if _CRYPTOGRAPHY_36_0_OR_NEWER: if _CRYPTOGRAPHY_36_0_OR_NEWER:
@@ -105,7 +153,7 @@ def cryptography_get_extensions_from_cert(cert):
backend = default_backend() backend = default_backend()
x509_obj = cert._x509 x509_obj = cert._x509 # type: ignore
# With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does # With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does
# not allow to get the raw value of an extension, so we have to use this ugly hack: # not allow to get the raw value of an extension, so we have to use this ugly hack:
exts = list(cert.extensions) exts = list(cert.extensions)
@@ -135,7 +183,9 @@ def cryptography_get_extensions_from_cert(cert):
return result return result
def cryptography_get_extensions_from_csr(csr): def cryptography_get_extensions_from_csr(
csr: x509.CertificateSigningRequest,
) -> dict[str, dict[str, bool | str]]:
result = dict() result = dict()
if _CRYPTOGRAPHY_36_0_OR_NEWER: if _CRYPTOGRAPHY_36_0_OR_NEWER:
@@ -153,7 +203,7 @@ def cryptography_get_extensions_from_csr(csr):
backend = default_backend() backend = default_backend()
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) # type: ignore
extensions = backend._ffi.gc( extensions = backend._ffi.gc(
extensions, extensions,
lambda ext: backend._lib.sk_X509_EXTENSION_pop_free( lambda ext: backend._lib.sk_X509_EXTENSION_pop_free(
@@ -175,7 +225,7 @@ def cryptography_get_extensions_from_csr(csr):
crit = backend._lib.X509_EXTENSION_get_critical(ext) crit = backend._lib.X509_EXTENSION_get_critical(ext)
data = backend._lib.X509_EXTENSION_get_data(ext) data = backend._lib.X509_EXTENSION_get_data(ext)
backend.openssl_assert(data != backend._ffi.NULL) backend.openssl_assert(data != backend._ffi.NULL)
der = backend._ffi.buffer(data.data, data.length)[:] der: bytes = backend._ffi.buffer(data.data, data.length)[:] # type: ignore
entry = dict( entry = dict(
critical=(crit == 1), critical=(crit == 1),
value=base64.b64encode(der).decode("ascii"), value=base64.b64encode(der).decode("ascii"),
@@ -193,7 +243,7 @@ def cryptography_get_extensions_from_csr(csr):
return result return result
def cryptography_name_to_oid(name): def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
dotted = OID_LOOKUP.get(name) dotted = OID_LOOKUP.get(name)
if dotted is None: if dotted is None:
if DOTTED_OID.match(name): if DOTTED_OID.match(name):
@@ -202,7 +252,9 @@ def cryptography_name_to_oid(name):
return x509.oid.ObjectIdentifier(dotted) return x509.oid.ObjectIdentifier(dotted)
def cryptography_oid_to_name(oid, short=False): def cryptography_oid_to_name(
oid: x509.oid.ObjectIdentifier, short: bool = False
) -> str:
dotted_string = oid.dotted_string dotted_string = oid.dotted_string
names = OID_MAP.get(dotted_string) names = OID_MAP.get(dotted_string)
if names: if names:
@@ -217,15 +269,22 @@ def cryptography_oid_to_name(oid, short=False):
return NORMALIZE_NAMES.get(name, name) return NORMALIZE_NAMES.get(name, name)
def _get_hex(bytesstr): def _get_hex(bytesstr: bytes) -> str:
if bytesstr is None: if bytesstr is None:
return bytesstr return bytesstr
data = binascii.hexlify(bytesstr) data = binascii.hexlify(bytesstr)
data = to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2))) return to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
return data
def _parse_hex(bytesstr): @t.overload
def _parse_hex(bytesstr: bytes | str) -> bytes: ...
@t.overload
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None: ...
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None:
if bytesstr is None: if bytesstr is None:
return bytesstr return bytesstr
data = "".join( data = "".join(
@@ -234,19 +293,20 @@ def _parse_hex(bytesstr):
for p in to_text(bytesstr).split(":") for p in to_text(bytesstr).split(":")
] ]
) )
data = binascii.unhexlify(data) return binascii.unhexlify(data)
return data
DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *") DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *")
DN_HEX_LETTER = b"0123456789abcdef" DN_HEX_LETTER = b"0123456789abcdef"
def _int_to_byte(value): def _int_to_byte(value: int) -> bytes:
return bytes((value,)) return bytes((value,))
def _parse_dn_component(name, sep=b",", decode_remainder=True): def _parse_dn_component(
name: bytes, sep: bytes = b",", decode_remainder: bool = True
) -> tuple[x509.NameAttribute, bytes]:
m = DN_COMPONENT_START_RE.match(name) m = DN_COMPONENT_START_RE.match(name)
if not m: if not m:
raise OpenSSLObjectError(f'cannot start part in "{to_text(name)}"') raise OpenSSLObjectError(f'cannot start part in "{to_text(name)}"')
@@ -305,7 +365,7 @@ def _parse_dn_component(name, sep=b",", decode_remainder=True):
return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:] return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:]
def _parse_dn(name): def _parse_dn(name: bytes) -> list[x509.NameAttribute]:
""" """
Parse a Distinguished Name. Parse a Distinguished Name.
@@ -323,31 +383,33 @@ def _parse_dn(name):
attribute, name = _parse_dn_component(name, sep=sep) attribute, name = _parse_dn_component(name, sep=sep)
except OpenSSLObjectError as e: except OpenSSLObjectError as e:
raise OpenSSLObjectError( raise OpenSSLObjectError(
f'Error while parsing distinguished name "{to_text(original_name)}": {e}' f"Error while parsing distinguished name {to_text(original_name)!r}: {e}"
) )
result.append(attribute) result.append(attribute)
if name: if name:
if name[0:1] != sep or len(name) < 2: if name[0:1] != sep or len(name) < 2:
raise OpenSSLObjectError( raise OpenSSLObjectError(
f'Error while parsing distinguished name "{to_text(original_name)}": unexpected end of string' f"Error while parsing distinguished name {to_text(original_name)!r}: unexpected end of string"
) )
name = name[1:] name = name[1:]
return result return result
def cryptography_parse_relative_distinguished_name(rdn): def cryptography_parse_relative_distinguished_name(
rdn: list[str | bytes],
) -> cryptography.x509.RelativeDistinguishedName:
names = [] names = []
for part in rdn: for part in rdn:
try: try:
names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0]) names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0])
except OpenSSLObjectError as e: except OpenSSLObjectError as e:
raise OpenSSLObjectError( raise OpenSSLObjectError(
f'Error while parsing relative distinguished name "{part}": {e}' f"Error while parsing relative distinguished name {to_text(part)!r}: {e}"
) )
return cryptography.x509.RelativeDistinguishedName(names) return cryptography.x509.RelativeDistinguishedName(names)
def _is_ascii(value): def _is_ascii(value: str) -> bool:
"""Check whether the Unicode string `value` contains only ASCII characters.""" """Check whether the Unicode string `value` contains only ASCII characters."""
try: try:
value.encode("ascii") value.encode("ascii")
@@ -356,7 +418,7 @@ def _is_ascii(value):
return False return False
def _adjust_idn(value, idn_rewrite): def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
if idn_rewrite == "ignore" or not value: if idn_rewrite == "ignore" or not value:
return value return value
if idn_rewrite == "idna" and _is_ascii(value): if idn_rewrite == "idna" and _is_ascii(value):
@@ -399,16 +461,20 @@ def _adjust_idn(value, idn_rewrite):
return ".".join(parts) return ".".join(parts)
def _adjust_idn_email(value, idn_rewrite): def _adjust_idn_email(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
idx = value.find("@") idx = value.find("@")
if idx < 0: if idx < 0:
return value return value
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}" return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
def _adjust_idn_url(value, idn_rewrite): def _adjust_idn_url(
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
) -> str:
url = urlparse(value) url = urlparse(value)
host = _adjust_idn(url.hostname, idn_rewrite) host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
if url.username is not None and url.password is not None: if url.username is not None and url.password is not None:
host = f"{url.username}:{url.password}@{host}" host = f"{url.username}:{url.password}@{host}"
elif url.username is not None: elif url.username is not None:
@@ -418,7 +484,7 @@ def _adjust_idn_url(value, idn_rewrite):
return urlunparse( return urlunparse(
ParseResult( ParseResult(
scheme=url.scheme, scheme=url.scheme,
netloc=host, netloc=host or "",
path=url.path, path=url.path,
params=url.params, params=url.params,
query=url.query, query=url.query,
@@ -427,7 +493,9 @@ def _adjust_idn_url(value, idn_rewrite):
) )
def cryptography_get_name(name, what="Subject Alternative Name"): def cryptography_get_name(
name: str, what: str = "Subject Alternative Name"
) -> x509.GeneralName:
""" """
Given a name string, returns a cryptography x509.GeneralName object. Given a name string, returns a cryptography x509.GeneralName object.
Raises an OpenSSLObjectError if the name is unknown or cannot be parsed. Raises an OpenSSLObjectError if the name is unknown or cannot be parsed.
@@ -490,7 +558,7 @@ def cryptography_get_name(name, what="Subject Alternative Name"):
) )
def _dn_escape_value(value): def _dn_escape_value(value: str) -> str:
""" """
Escape Distinguished Name's attribute value. Escape Distinguished Name's attribute value.
""" """
@@ -505,7 +573,10 @@ def _dn_escape_value(value):
return value return value
def cryptography_decode_name(name, idn_rewrite="ignore"): def cryptography_decode_name(
name: x509.GeneralName,
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> str:
""" """
Given a cryptography x509.GeneralName object, returns a string. Given a cryptography x509.GeneralName object, returns a string.
Raises an OpenSSLObjectError if the name is not supported. Raises an OpenSSLObjectError if the name is not supported.
@@ -529,7 +600,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
# list needs to be reversed, and joined by commas # list needs to be reversed, and joined by commas
return "dirName:" + ",".join( return "dirName:" + ",".join(
[ [
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(attribute.value)}" f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(to_text(attribute.value))}"
for attribute in reversed(list(name.value)) for attribute in reversed(list(name.value))
] ]
) )
@@ -540,7 +611,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
raise OpenSSLObjectError(f'Cannot decode name "{name}"') raise OpenSSLObjectError(f'Cannot decode name "{name}"')
def _cryptography_get_keyusage(usage): def _cryptography_get_keyusage(usage: str) -> str:
""" """
Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage(). Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage().
Raises an OpenSSLObjectError if the identifier is unknown. Raises an OpenSSLObjectError if the identifier is unknown.
@@ -566,7 +637,7 @@ def _cryptography_get_keyusage(usage):
raise OpenSSLObjectError(f'Unknown key usage "{usage}"') raise OpenSSLObjectError(f'Unknown key usage "{usage}"')
def cryptography_parse_key_usage_params(usages): def cryptography_parse_key_usage_params(usages: t.Iterable[str]) -> dict[str, bool]:
""" """
Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage(). Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage().
Raises an OpenSSLObjectError if an identifier is unknown. Raises an OpenSSLObjectError if an identifier is unknown.
@@ -587,13 +658,15 @@ def cryptography_parse_key_usage_params(usages):
return params return params
def cryptography_get_basic_constraints(constraints): def cryptography_get_basic_constraints(
constraints: t.Iterable[str] | None,
) -> tuple[bool, int | None]:
""" """
Given a list of constraints, returns a tuple (ca, path_length). Given a list of constraints, returns a tuple (ca, path_length).
Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed. Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed.
""" """
ca = False ca = False
path_length = None path_length: int | None = None
if constraints: if constraints:
for constraint in constraints: for constraint in constraints:
if constraint.startswith("CA:"): if constraint.startswith("CA:"):
@@ -618,7 +691,9 @@ def cryptography_get_basic_constraints(constraints):
return ca, path_length return ca, path_length
def cryptography_key_needs_digest_for_signing(key): def cryptography_key_needs_digest_for_signing(
key: CertificateIssuerPrivateKeyTypes,
) -> bool:
"""Tests whether the given private key requires a digest algorithm for signing. """Tests whether the given private key requires a digest algorithm for signing.
Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm. Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm.
@@ -632,19 +707,27 @@ def cryptography_key_needs_digest_for_signing(key):
return True return True
def _compare_public_keys(key1, key2, clazz): def _compare_public_keys(
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz) a = isinstance(key1, clazz)
b = isinstance(key2, clazz) b = isinstance(key2, clazz)
if not (a or b): if not (a or b):
return None return None
if not a or not b: if not a or not b:
return False return False
a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) a_bytes = key1.public_bytes(
b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) serialization.Encoding.Raw, serialization.PublicFormat.Raw
return a == b )
b_bytes = key2.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
return a_bytes == b_bytes
def cryptography_compare_public_keys(key1, key2): def cryptography_compare_public_keys(
key1: PublicKeyTypes, key2: PublicKeyTypes
) -> bool:
"""Tests whether two public keys are the same. """Tests whether two public keys are the same.
Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers(). Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers().
@@ -654,6 +737,13 @@ def cryptography_compare_public_keys(key1, key2):
key2, key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
) )
if res is not None:
return res
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
)
if res is not None: if res is not None:
return res return res
res = _compare_public_keys( res = _compare_public_keys(
@@ -661,10 +751,20 @@ def cryptography_compare_public_keys(key1, key2):
) )
if res is not None: if res is not None:
return res return res
return key1.public_numbers() == key2.public_numbers() res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
)
if res is not None:
return res
return (
t.cast(PublicKeyTypesWOEdwards, key1).public_numbers()
== t.cast(PublicKeyTypesWOEdwards, key2).public_numbers()
)
def _compare_private_keys(key1, key2, clazz): def _compare_private_keys(
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
) -> bool | None:
a = isinstance(key1, clazz) a = isinstance(key1, clazz)
b = isinstance(key2, clazz) b = isinstance(key2, clazz)
if not (a or b): if not (a or b):
@@ -672,20 +772,22 @@ def _compare_private_keys(key1, key2, clazz):
if not a or not b: if not a or not b:
return False return False
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption() encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
a = key1.private_bytes( a_bytes = key1.private_bytes(
serialization.Encoding.Raw, serialization.Encoding.Raw,
serialization.PrivateFormat.Raw, serialization.PrivateFormat.Raw,
encryption_algorithm=encryption_algorithm, encryption_algorithm=encryption_algorithm,
) )
b = key2.private_bytes( b_bytes = key2.private_bytes(
serialization.Encoding.Raw, serialization.Encoding.Raw,
serialization.PrivateFormat.Raw, serialization.PrivateFormat.Raw,
encryption_algorithm=encryption_algorithm, encryption_algorithm=encryption_algorithm,
) )
return a == b return a_bytes == b_bytes
def cryptography_compare_private_keys(key1, key2): def cryptography_compare_private_keys(
key1: PrivateKeyTypes, key2: PrivateKeyTypes
) -> bool:
"""Tests whether two private keys are the same. """Tests whether two private keys are the same.
Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers(). Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers().
@@ -714,25 +816,39 @@ def cryptography_compare_private_keys(key1, key2):
) )
if res is not None: if res is not None:
return res return res
return key1.private_numbers() == key2.private_numbers() return (
t.cast(PrivateKeyTypesWOEdwards, key1).private_numbers()
== t.cast(PrivateKeyTypesWOEdwards, key2).private_numbers()
)
def parse_pkcs12(pkcs12_bytes, passphrase=None): def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
"""Returns a tuple (private_key, certificate, additional_certificates, friendly_name).""" """Returns a tuple (private_key, certificate, additional_certificates, friendly_name)."""
passphrase_bytes = None
if passphrase is not None: if passphrase is not None:
passphrase = to_bytes(passphrase) passphrase_bytes = to_bytes(passphrase)
# Main code for cryptography 36.0.0 and forward # Main code for cryptography 36.0.0 and forward
if _load_pkcs12 is not None: if _load_pkcs12 is not None:
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase) return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"): if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase) return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase) return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None): def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
# Requires cryptography 36.0.0 or newer # Requires cryptography 36.0.0 or newer
pkcs12 = _load_pkcs12(pkcs12_bytes, passphrase) pkcs12 = _load_pkcs12(pkcs12_bytes, passphrase)
additional_certificates = [cert.certificate for cert in pkcs12.additional_certs] additional_certificates = [cert.certificate for cert in pkcs12.additional_certs]
@@ -745,7 +861,12 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
return private_key, certificate, additional_certificates, friendly_name return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None): def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
# Backwards compatibility code for cryptography 35.x # Backwards compatibility code for cryptography 35.x
private_key, certificate, additional_certificates = _load_key_and_certificates( private_key, certificate, additional_certificates = _load_key_and_certificates(
pkcs12_bytes, passphrase pkcs12_bytes, passphrase
@@ -787,7 +908,12 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
return private_key, certificate, additional_certificates, friendly_name return private_key, certificate, additional_certificates, friendly_name
def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None): def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
bytes | None,
]:
# Backwards compatibility code for cryptography < 35.0.0 # Backwards compatibility code for cryptography < 35.0.0
private_key, certificate, additional_certificates = _load_key_and_certificates( private_key, certificate, additional_certificates = _load_key_and_certificates(
pkcs12_bytes, passphrase pkcs12_bytes, passphrase
@@ -796,14 +922,19 @@ def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
friendly_name = None friendly_name = None
if certificate: if certificate:
# See https://github.com/pyca/cryptography/issues/5760#issuecomment-842687238 # See https://github.com/pyca/cryptography/issues/5760#issuecomment-842687238
backend = certificate._backend backend = certificate._backend # type: ignore
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) # type: ignore
if maybe_name != backend._ffi.NULL: if maybe_name != backend._ffi.NULL:
friendly_name = backend._ffi.string(maybe_name) friendly_name = backend._ffi.string(maybe_name)
return private_key, certificate, additional_certificates, friendly_name return private_key, certificate, additional_certificates, friendly_name
def cryptography_verify_signature(signature, data, hash_algorithm, signer_public_key): def cryptography_verify_signature(
signature: bytes,
data: bytes,
hash_algorithm: hashes.HashAlgorithm | None,
signer_public_key: PublicKeyTypes,
) -> bool:
""" """
Check whether the given signature of the given data was signed by the given public key object. Check whether the given signature of the given data was signed by the given public key object.
""" """
@@ -812,6 +943,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
signer_public_key, signer_public_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey,
): ):
if hash_algorithm is None:
raise OpenSSLObjectError("Need hash_algorithm for RSA keys")
signer_public_key.verify( signer_public_key.verify(
signature, data, padding.PKCS1v15(), hash_algorithm signature, data, padding.PKCS1v15(), hash_algorithm
) )
@@ -820,6 +953,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
signer_public_key, signer_public_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
): ):
if hash_algorithm is None:
raise OpenSSLObjectError("Need hash_algorithm for ECC keys")
signer_public_key.verify( signer_public_key.verify(
signature, signature,
data, data,
@@ -830,6 +965,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
signer_public_key, signer_public_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey,
): ):
if hash_algorithm is None:
raise OpenSSLObjectError("Need hash_algorithm for DSA keys")
signer_public_key.verify(signature, data, hash_algorithm) signer_public_key.verify(signature, data, hash_algorithm)
return True return True
if isinstance( if isinstance(
@@ -851,7 +988,9 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
return False return False
def cryptography_verify_certificate_signature(certificate, signer_public_key): def cryptography_verify_certificate_signature(
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
) -> bool:
""" """
Check whether the given X509 certificate object was signed by the given public key object. Check whether the given X509 certificate object was signed by the given public key object.
""" """
@@ -863,21 +1002,65 @@ def cryptography_verify_certificate_signature(certificate, signer_public_key):
) )
def get_not_valid_after(obj): def get_not_valid_after(obj: x509.Certificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE: if CRYPTOGRAPHY_TIMEZONE:
return obj.not_valid_after_utc return obj.not_valid_after_utc
return obj.not_valid_after return obj.not_valid_after
def get_not_valid_before(obj): def get_not_valid_before(obj: x509.Certificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE: if CRYPTOGRAPHY_TIMEZONE:
return obj.not_valid_before_utc return obj.not_valid_before_utc
return obj.not_valid_before return obj.not_valid_before
def set_not_valid_after(builder, value): def set_not_valid_after(
builder: x509.CertificateBuilder, value: datetime.datetime
) -> x509.CertificateBuilder:
return builder.not_valid_after(value) return builder.not_valid_after(value)
def set_not_valid_before(builder, value): def set_not_valid_before(
builder: x509.CertificateBuilder, value: datetime.datetime
) -> x509.CertificateBuilder:
return builder.not_valid_before(value) return builder.not_valid_before(value)
def is_potential_certificate_private_key(
key: PrivateKeyTypes,
) -> t.TypeGuard[CertificatePrivateKeyTypes]:
return not isinstance(
key, cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey
)
def is_potential_certificate_issuer_private_key(
key: PrivateKeyTypes,
) -> t.TypeGuard[CertificateIssuerPrivateKeyTypes]:
return not isinstance(
key,
(
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
),
)
def is_potential_certificate_public_key(
key: PublicKeyTypes,
) -> t.TypeGuard[CertificatePublicKeyTypes]:
return not isinstance(key, DHPublicKey)
def is_potential_certificate_issuer_public_key(
key: PublicKeyTypes,
) -> t.TypeGuard[CertificateIssuerPublicKeyTypes]:
return not isinstance(
key,
(
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey,
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
),
)

View File

@@ -5,7 +5,7 @@
from __future__ import annotations from __future__ import annotations
def binary_exp_mod(f, e, m): def binary_exp_mod(f: int, e: int, m: int) -> int:
"""Computes f^e mod m in O(log e) multiplications modulo m.""" """Computes f^e mod m in O(log e) multiplications modulo m."""
# Compute len_e = floor(log_2(e)) # Compute len_e = floor(log_2(e))
len_e = -1 len_e = -1
@@ -22,14 +22,14 @@ def binary_exp_mod(f, e, m):
return result return result
def simple_gcd(a, b): def simple_gcd(a: int, b: int) -> int:
"""Compute GCD of its two inputs.""" """Compute GCD of its two inputs."""
while b != 0: while b != 0:
a, b = b, a % b a, b = b, a % b
return a return a
def quick_is_not_prime(n): def quick_is_not_prime(n: int) -> bool:
"""Does some quick checks to see if we can poke a hole into the primality of n. """Does some quick checks to see if we can poke a hole into the primality of n.
A result of `False` does **not** mean that the number is prime; it just means A result of `False` does **not** mean that the number is prime; it just means
@@ -97,7 +97,7 @@ def quick_is_not_prime(n):
return False return False
def count_bytes(no): def count_bytes(no: int) -> int:
""" """
Given an integer, compute the number of bytes necessary to store its absolute value. Given an integer, compute the number of bytes necessary to store its absolute value.
""" """
@@ -107,7 +107,7 @@ def count_bytes(no):
return (no.bit_length() + 7) // 8 return (no.bit_length() + 7) // 8
def count_bits(no): def count_bits(no: int) -> int:
""" """
Given an integer, compute the number of bits necessary to store its absolute value. Given an integer, compute the number of bits necessary to store its absolute value.
""" """
@@ -117,19 +117,7 @@ def count_bits(no):
return no.bit_length() return no.bit_length()
def _convert_int_to_bytes(count, no): def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
return no.to_bytes(count, byteorder="big")
def _convert_bytes_to_int(data):
return int.from_bytes(data, byteorder="big", signed=False)
def _to_hex(no):
return f"{no:x}"
def convert_int_to_bytes(no, count=None):
""" """
Convert the absolute value of an integer to a byte string in network byte order. Convert the absolute value of an integer to a byte string in network byte order.
@@ -142,10 +130,10 @@ def convert_int_to_bytes(no, count=None):
no = abs(no) no = abs(no)
if count is None: if count is None:
count = count_bytes(no) count = count_bytes(no)
return _convert_int_to_bytes(count, no) return no.to_bytes(count, byteorder="big")
def convert_int_to_hex(no, digits=None): def convert_int_to_hex(no: int, digits: int | None = None) -> str:
""" """
Convert the absolute value of an integer to a string of hexadecimal digits. Convert the absolute value of an integer to a string of hexadecimal digits.
@@ -154,14 +142,14 @@ def convert_int_to_hex(no, digits=None):
the string will be longer. the string will be longer.
""" """
no = abs(no) no = abs(no)
value = _to_hex(no) value = f"{no:x}"
if digits is not None and len(value) < digits: if digits is not None and len(value) < digits:
value = "0" * (digits - len(value)) + value value = "0" * (digits - len(value)) + value
return value return value
def convert_bytes_to_int(data): def convert_bytes_to_int(data: bytes) -> int:
""" """
Convert a byte string to an unsigned integer in network byte order. Convert a byte string to an unsigned integer in network byte order.
""" """
return _convert_bytes_to_int(data) return int.from_bytes(data, byteorder="big", signed=False)

View File

@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.argspec import ( from ansible_collections.community.crypto.plugins.module_utils.argspec import (
ArgumentSpec, ArgumentSpec,
@@ -24,8 +25,8 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_certificate, load_certificate,
load_certificate_privatekey,
load_certificate_request, load_certificate_request,
load_privatekey,
) )
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
from ..cryptography_support import CertificatePrivateKeyTypes
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -47,41 +59,45 @@ class CertificateError(OpenSSLObjectError):
class CertificateBackend(metaclass=abc.ABCMeta): class CertificateBackend(metaclass=abc.ABCMeta):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self.module = module self.module = module
self.force = module.params["force"] self.force: bool = module.params["force"]
self.ignore_timestamps = module.params["ignore_timestamps"] self.ignore_timestamps: bool = module.params["ignore_timestamps"]
self.privatekey_path = module.params["privatekey_path"] self.privatekey_path: str | None = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"] privatekey_content: str | None = module.params["privatekey_content"]
if self.privatekey_content is not None: if privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8") self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"] else:
self.csr_path = module.params["csr_path"] self.privatekey_content = None
self.csr_content = module.params["csr_content"] self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
if self.csr_content is not None: self.csr_path: str | None = module.params["csr_path"]
self.csr_content = self.csr_content.encode("utf-8") csr_content = module.params["csr_content"]
if csr_content is not None:
self.csr_content: bytes | None = csr_content.encode("utf-8")
else:
self.csr_content = None
# The following are default values which make sure check() works as # The following are default values which make sure check() works as
# before if providers do not explicitly change these properties. # before if providers do not explicitly change these properties.
self.create_subject_key_identifier = "never_create" self.create_subject_key_identifier: str = "never_create"
self.create_authority_key_identifier = False self.create_authority_key_identifier: bool = False
self.privatekey = None self.privatekey: CertificatePrivateKeyTypes | None = None
self.csr = None self.csr: x509.CertificateSigningRequest | None = None
self.cert = None self.cert: x509.Certificate | None = None
self.existing_certificate = None self.existing_certificate: x509.Certificate | None = None
self.existing_certificate_bytes = None self.existing_certificate_bytes: bytes | None = None
self.check_csr_subject = True self.check_csr_subject: bool = True
self.check_csr_extensions = True self.check_csr_extensions: bool = True
self.diff_before = self._get_info(None) self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None) self.diff_after = self._get_info(None)
def _get_info(self, data): def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None: if data is None:
return dict() return {}
try: try:
result = get_certificate_info( result = get_certificate_info(
self.module, data, prefer_one_fingerprint=True self.module, data, prefer_one_fingerprint=True
@@ -92,34 +108,34 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return dict(can_parse_certificate=False) return dict(can_parse_certificate=False)
@abc.abstractmethod @abc.abstractmethod
def generate_certificate(self): def generate_certificate(self) -> None:
"""(Re-)Generate certificate.""" """(Re-)Generate certificate."""
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_certificate_data(self): def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert.""" """Return bytes for self.cert."""
pass pass
def set_existing(self, certificate_bytes): def set_existing(self, certificate_bytes: bytes | None) -> None:
"""Set existing certificate bytes. None indicates that the key does not exist.""" """Set existing certificate bytes. None indicates that the key does not exist."""
self.existing_certificate_bytes = certificate_bytes self.existing_certificate_bytes = certificate_bytes
self.diff_after = self.diff_before = self._get_info( self.diff_after = self.diff_before = self._get_info(
self.existing_certificate_bytes self.existing_certificate_bytes
) )
def has_existing(self): def has_existing(self) -> bool:
"""Query whether an existing certificate is/has been there.""" """Query whether an existing certificate is/has been there."""
return self.existing_certificate_bytes is not None return self.existing_certificate_bytes is not None
def _ensure_private_key_loaded(self): def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey.""" """Load the provided private key into self.privatekey."""
if self.privatekey is not None: if self.privatekey is not None:
return return
if self.privatekey_path is None and self.privatekey_content is None: if self.privatekey_path is None and self.privatekey_content is None:
return return
try: try:
self.privatekey = load_privatekey( self.privatekey = load_certificate_privatekey(
path=self.privatekey_path, path=self.privatekey_path,
content=self.privatekey_content, content=self.privatekey_content,
passphrase=self.privatekey_passphrase, passphrase=self.privatekey_passphrase,
@@ -127,7 +143,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
except OpenSSLBadPassphraseError as exc: except OpenSSLBadPassphraseError as exc:
raise CertificateError(exc) raise CertificateError(exc)
def _ensure_csr_loaded(self): def _ensure_csr_loaded(self) -> None:
"""Load the CSR into self.csr.""" """Load the CSR into self.csr."""
if self.csr is not None: if self.csr is not None:
return return
@@ -138,7 +154,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
content=self.csr_content, content=self.csr_content,
) )
def _ensure_existing_certificate_loaded(self): def _ensure_existing_certificate_loaded(self) -> None:
"""Load the existing certificate into self.existing_certificate.""" """Load the existing certificate into self.existing_certificate."""
if self.existing_certificate is not None: if self.existing_certificate is not None:
return return
@@ -149,14 +165,28 @@ class CertificateBackend(metaclass=abc.ABCMeta):
content=self.existing_certificate_bytes, content=self.existing_certificate_bytes,
) )
def _check_privatekey(self): def _check_privatekey(self) -> bool:
"""Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated.""" """Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.privatekey is None:
raise AssertionError(
"Contract violation: privatekey has not been populated"
)
return cryptography_compare_public_keys( return cryptography_compare_public_keys(
self.existing_certificate.public_key(), self.privatekey.public_key() self.existing_certificate.public_key(), self.privatekey.public_key()
) )
def _check_csr(self): def _check_csr(self) -> bool:
"""Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated.""" """Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
# Verify that CSR is signed by certificate's private key # Verify that CSR is signed by certificate's private key
if not self.csr.is_signature_valid: if not self.csr.is_signature_valid:
return False return False
@@ -214,8 +244,14 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return False return False
return True return True
def _check_subject_key_identifier(self): def _check_subject_key_identifier(self) -> bool:
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate has been populated.""" """Check whether Subject Key Identifier matches, assuming self.existing_certificate and self.csr have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
# Get hold of certificate's SKI # Get hold of certificate's SKI
try: try:
ext = self.existing_certificate.extensions.get_extension_for_class( ext = self.existing_certificate.extensions.get_extension_for_class(
@@ -247,7 +283,11 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return False return False
return True return True
def needs_regeneration(self, not_before=None, not_after=None): def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
"""Check whether a regeneration is necessary.""" """Check whether a regeneration is necessary."""
if self.force or self.existing_certificate_bytes is None: if self.force or self.existing_certificate_bytes is None:
return True return True
@@ -256,6 +296,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
self._ensure_existing_certificate_loaded() self._ensure_existing_certificate_loaded()
except Exception: except Exception:
return True return True
assert self.existing_certificate is not None
# Check whether private key matches # Check whether private key matches
self._ensure_private_key_loaded() self._ensure_private_key_loaded()
@@ -285,9 +326,12 @@ class CertificateBackend(metaclass=abc.ABCMeta):
return True return True
return False return False
def dump(self, include_certificate): def dump(self, include_certificate: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = {"privatekey": self.privatekey_path, "csr": self.csr_path} result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"csr": self.csr_path,
}
# Get hold of certificate bytes # Get hold of certificate bytes
certificate_bytes = self.existing_certificate_bytes certificate_bytes = self.existing_certificate_bytes
if self.cert is not None: if self.cert is not None:
@@ -299,35 +343,33 @@ class CertificateBackend(metaclass=abc.ABCMeta):
certificate_bytes.decode("utf-8") if certificate_bytes else None certificate_bytes.decode("utf-8") if certificate_bytes else None
) )
result["diff"] = dict( result["diff"] = {
before=self.diff_before, "before": self.diff_before,
after=self.diff_after, "after": self.diff_after,
) }
return result return result
class CertificateProvider(metaclass=abc.ABCMeta): class CertificateProvider(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def validate_module_args(self, module): def validate_module_args(self, module: AnsibleModule) -> None:
"""Check module arguments""" """Check module arguments"""
@abc.abstractmethod @abc.abstractmethod
def needs_version_two_certs(self, module): def needs_version_two_certs(self, module: AnsibleModule) -> bool:
"""Whether the provider needs to create a version 2 certificate.""" """Whether the provider needs to create a version 2 certificate."""
@abc.abstractmethod @abc.abstractmethod
def create_backend(self, module): def create_backend(self, module: AnsibleModule) -> CertificateBackend:
"""Create an implementation for a backend. """Create an implementation for a backend.
Return value must be instance of CertificateBackend. Return value must be instance of CertificateBackend.
""" """
def select_backend(module, provider): def select_backend(
""" module: AnsibleModule, provider: CertificateProvider
:type module: AnsibleModule ) -> CertificateBackend:
:type provider: CertificateProvider
"""
provider.validate_module_args(module) provider.validate_module_args(module)
assert_required_cryptography_version( assert_required_cryptography_version(
@@ -343,7 +385,7 @@ def select_backend(module, provider):
return provider.create_backend(module) return provider.create_backend(module)
def get_certificate_argument_spec(): def get_certificate_argument_spec() -> ArgumentSpec:
return ArgumentSpec( return ArgumentSpec(
argument_spec=dict( argument_spec=dict(
provider=dict( provider=dict(

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import os import os
import tempfile import tempfile
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
@@ -17,22 +18,30 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
class AcmeCertificateBackend(CertificateBackend): if t.TYPE_CHECKING:
def __init__(self, module): from ansible.module_utils.basic import AnsibleModule
super(AcmeCertificateBackend, self).__init__(module)
self.accountkey_path = module.params["acme_accountkey_path"]
self.challenge_path = module.params["acme_challenge_path"]
self.use_chain = module.params["acme_chain"]
self.acme_directory = module.params["acme_directory"]
if self.csr_content is None and self.csr_path is None: from ...argspec import ArgumentSpec
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
) class AcmeCertificateBackend(CertificateBackend):
if self.csr_content is None and not os.path.exists(self.csr_path): def __init__(self, module: AnsibleModule) -> None:
raise CertificateError( super(AcmeCertificateBackend, self).__init__(module)
f"The certificate signing request file {self.csr_path} does not exist" self.accountkey_path: str = module.params["acme_accountkey_path"]
) self.challenge_path: str = module.params["acme_challenge_path"]
self.use_chain: bool = module.params["acme_chain"]
self.acme_directory: str = module.params["acme_directory"]
self.cert_bytes: bytes | None = None
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if not os.path.exists(self.accountkey_path): if not os.path.exists(self.accountkey_path):
raise CertificateError( raise CertificateError(
@@ -46,7 +55,7 @@ class AcmeCertificateBackend(CertificateBackend):
self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True) self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True)
def generate_certificate(self): def generate_certificate(self) -> None:
"""(Re-)Generate certificate.""" """(Re-)Generate certificate."""
command = [self.acme_tiny_path] command = [self.acme_tiny_path]
@@ -77,22 +86,26 @@ class AcmeCertificateBackend(CertificateBackend):
command.extend(["--directory-url", self.acme_directory]) command.extend(["--directory-url", self.acme_directory])
try: try:
self.cert = to_bytes(self.module.run_command(command, check_rc=True)[1]) self.cert_bytes = to_bytes(
self.module.run_command(command, check_rc=True)[1]
)
except OSError as exc: except OSError as exc:
raise CertificateError(exc) raise CertificateError(exc)
def get_certificate_data(self): def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert.""" """Return bytes for self.cert."""
return self.cert if self.cert_bytes is None:
raise AssertionError("Contract violation: cert_bytes is None")
return self.cert_bytes
def dump(self, include_certificate): def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(include_certificate) result = super(AcmeCertificateBackend, self).dump(include_certificate)
result["accountkey"] = self.accountkey_path result["accountkey"] = self.accountkey_path
return result return result
class AcmeCertificateProvider(CertificateProvider): class AcmeCertificateProvider(CertificateProvider):
def validate_module_args(self, module): def validate_module_args(self, module: AnsibleModule) -> None:
if module.params["acme_accountkey_path"] is None: if module.params["acme_accountkey_path"] is None:
module.fail_json( module.fail_json(
msg="The acme_accountkey_path option must be specified for the acme provider." msg="The acme_accountkey_path option must be specified for the acme provider."
@@ -102,14 +115,14 @@ class AcmeCertificateProvider(CertificateProvider):
msg="The acme_challenge_path option must be specified for the acme provider." msg="The acme_challenge_path option must be specified for the acme provider."
) )
def needs_version_two_certs(self, module): def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return False return False
def create_backend(self, module): def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
return AcmeCertificateBackend(module) return AcmeCertificateBackend(module)
def add_acme_provider_to_argument_spec(argument_spec): def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("acme") argument_spec.argument_spec["provider"]["choices"].append("acme")
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import datetime import datetime
import os import os
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -32,6 +33,12 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ...argspec import ArgumentSpec
try: try:
from cryptography.x509.oid import NameOID from cryptography.x509.oid import NameOID
except ImportError: except ImportError:
@@ -39,7 +46,7 @@ except ImportError:
class EntrustCertificateBackend(CertificateBackend): class EntrustCertificateBackend(CertificateBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module) super(EntrustCertificateBackend, self).__init__(module)
self.trackingId = None self.trackingId = None
self.notAfter = get_relative_time_option( self.notAfter = get_relative_time_option(
@@ -48,16 +55,19 @@ class EntrustCertificateBackend(CertificateBackend):
with_timezone=CRYPTOGRAPHY_TIMEZONE, with_timezone=CRYPTOGRAPHY_TIMEZONE,
) )
if self.csr_content is None and self.csr_path is None: if self.csr_content is None:
raise CertificateError( if self.csr_path is None:
"csr_path or csr_content is required for entrust provider" raise CertificateError(
) "csr_path or csr_content is required for entrust provider"
if self.csr_content is None and not os.path.exists(self.csr_path): )
raise CertificateError( if not os.path.exists(self.csr_path):
f"The certificate signing request file {self.csr_path} does not exist" raise CertificateError(
) f"The certificate signing request file {self.csr_path} does not exist"
)
self._ensure_csr_loaded() self._ensure_csr_loaded()
if self.csr is None:
raise CertificateError("CSR not provided")
# ECS API defaults to using the validated organization tied to the account. # ECS API defaults to using the validated organization tied to the account.
# We want to always force behavior of trying to use the organization provided in the CSR. # We want to always force behavior of trying to use the organization provided in the CSR.
@@ -93,9 +103,9 @@ class EntrustCertificateBackend(CertificateBackend):
], ],
) )
except SessionConfigurationException as e: except SessionConfigurationException as e:
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e.message}") module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
def generate_certificate(self): def generate_certificate(self) -> None:
"""(Re-)Generate certificate.""" """(Re-)Generate certificate."""
body = {} body = {}
@@ -104,6 +114,7 @@ class EntrustCertificateBackend(CertificateBackend):
# csr_content contains bytes # csr_content contains bytes
body["csr"] = to_native(self.csr_content) body["csr"] = to_native(self.csr_content)
else: else:
assert self.csr_path is not None
with open(self.csr_path, "r") as csr_file: with open(self.csr_path, "r") as csr_file:
body["csr"] = csr_file.read() body["csr"] = csr_file.read()
@@ -138,11 +149,15 @@ class EntrustCertificateBackend(CertificateBackend):
content=self.cert_bytes, content=self.cert_bytes,
) )
def get_certificate_data(self): def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert.""" """Return bytes for self.cert."""
return self.cert_bytes return self.cert_bytes
def needs_regeneration(self): def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
parent_check = super(EntrustCertificateBackend, self).needs_regeneration() parent_check = super(EntrustCertificateBackend, self).needs_regeneration()
try: try:
@@ -167,12 +182,12 @@ class EntrustCertificateBackend(CertificateBackend):
return parent_check return parent_check
def _get_cert_details(self): def _get_cert_details(self) -> dict[str, t.Any]:
cert_details = {} cert_details: dict[str, t.Any] = {}
try: try:
self._ensure_existing_certificate_loaded() self._ensure_existing_certificate_loaded()
except Exception: except Exception:
return return cert_details
if self.existing_certificate: if self.existing_certificate:
serial_number = f"{self.existing_certificate.serial_number:X}" serial_number = f"{self.existing_certificate.serial_number:X}"
expiry = get_not_valid_after(self.existing_certificate) expiry = get_not_valid_after(self.existing_certificate)
@@ -203,17 +218,17 @@ class EntrustCertificateBackend(CertificateBackend):
class EntrustCertificateProvider(CertificateProvider): class EntrustCertificateProvider(CertificateProvider):
def validate_module_args(self, module): def validate_module_args(self, module: AnsibleModule) -> None:
pass pass
def needs_version_two_certs(self, module): def needs_version_two_certs(self, module: AnsibleModule) -> t.Literal[False]:
return False return False
def create_backend(self, module): def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
return EntrustCertificateBackend(module) return EntrustCertificateBackend(module)
def add_entrust_provider_to_argument_spec(argument_spec): def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("entrust") argument_spec.argument_spec["provider"]["choices"].append("entrust")
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(
@@ -248,7 +263,7 @@ def add_entrust_provider_to_argument_spec(argument_spec):
) )
) )
argument_spec.required_if.append( argument_spec.required_if.append(
[ (
"provider", "provider",
"entrust", "entrust",
[ [
@@ -260,5 +275,5 @@ def add_entrust_provider_to_argument_spec(argument_spec):
"entrust_api_client_cert_path", "entrust_api_client_cert_path",
"entrust_api_client_cert_key_path", "entrust_api_client_cert_key_path",
], ],
] )
) )

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import abc import abc
import binascii import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -34,6 +35,19 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
from ...argspec import ArgumentSpec
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -48,93 +62,97 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta): class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module, content): def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string # content must be a bytes string
self.module = module self.module = module
self.content = content self.content = content
@abc.abstractmethod @abc.abstractmethod
def _get_der_bytes(self): def _get_der_bytes(self) -> bytes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_signature_algorithm(self): def _get_signature_algorithm(self) -> str:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_subject_ordered(self): def _get_subject_ordered(self) -> list[list[str]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_issuer_ordered(self): def _get_issuer_ordered(self) -> list[list[str]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_version(self): def _get_version(self) -> int | str:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_key_usage(self): def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_extended_key_usage(self): def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_basic_constraints(self): def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_ocsp_must_staple(self): def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_subject_alt_name(self): def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_not_before(self): def get_not_before(self) -> datetime.datetime:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_not_after(self): def get_not_after(self) -> datetime.datetime:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_public_key_pem(self): def _get_public_key_pem(self) -> bytes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_public_key_object(self): def _get_public_key_object(self) -> PublicKeyTypes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_subject_key_identifier(self): def _get_subject_key_identifier(self) -> bytes | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_authority_key_identifier(self): def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_serial_number(self): def _get_serial_number(self) -> int:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_all_extensions(self): def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_ocsp_uri(self): def _get_ocsp_uri(self) -> str | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_issuer_uri(self): def _get_issuer_uri(self) -> str | None:
pass pass
def get_info(self, prefer_one_fingerprint=False, der_support_enabled=False): def get_info(
result = dict() self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate( self.cert = load_certificate(
None, None,
content=self.content, content=self.content,
@@ -194,16 +212,20 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
self._get_der_bytes(), prefer_one=prefer_one_fingerprint self._get_der_bytes(), prefer_one=prefer_one_fingerprint
) )
ski = self._get_subject_key_identifier() ski_bytes = self._get_subject_key_identifier()
if ski is not None: if ski_bytes is not None:
ski = binascii.hexlify(ski).decode("ascii") ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)]) ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski result["subject_key_identifier"] = ski
aki, aci, acsn = self._get_authority_key_identifier() aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki is not None: if aki_bytes is not None:
aki = binascii.hexlify(aki).decode("ascii") aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)]) aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn result["authority_cert_serial_number"] = acsn
@@ -219,36 +241,40 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval): class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend""" """Validate the supplied cert, using the cryptography backend"""
def __init__(self, module, content): def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(module, content) super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
self.name_encoding = module.params.get("name_encoding", "ignore") self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self): def _get_der_bytes(self) -> bytes:
return self.cert.public_bytes(serialization.Encoding.DER) return self.cert.public_bytes(serialization.Encoding.DER)
def _get_signature_algorithm(self): def _get_signature_algorithm(self) -> str:
return cryptography_oid_to_name(self.cert.signature_algorithm_oid) return cryptography_oid_to_name(self.cert.signature_algorithm_oid)
def _get_subject_ordered(self): def _get_subject_ordered(self) -> list[list[str]]:
result = [] result: list[list[str]] = []
for attribute in self.cert.subject: for attribute in self.cert.subject:
result.append([cryptography_oid_to_name(attribute.oid), attribute.value]) result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result return result
def _get_issuer_ordered(self): def _get_issuer_ordered(self) -> list[list[str]]:
result = [] result = []
for attribute in self.cert.issuer: for attribute in self.cert.issuer:
result.append([cryptography_oid_to_name(attribute.oid), attribute.value]) result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result return result
def _get_version(self): def _get_version(self) -> int | str:
if self.cert.version == x509.Version.v1: if self.cert.version == x509.Version.v1:
return 1 return 1
if self.cert.version == x509.Version.v3: if self.cert.version == x509.Version.v3:
return 3 return 3
return "unknown" return "unknown"
def _get_key_usage(self): def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try: try:
current_key_ext = self.cert.extensions.get_extension_for_class( current_key_ext = self.cert.extensions.get_extension_for_class(
x509.KeyUsage x509.KeyUsage
@@ -297,7 +323,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_extended_key_usage(self): def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
try: try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class( ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.ExtendedKeyUsage x509.ExtendedKeyUsage
@@ -311,7 +337,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_basic_constraints(self): def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
try: try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class( ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.BasicConstraints x509.BasicConstraints
@@ -324,7 +350,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_ocsp_must_staple(self): def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
try: try:
tlsfeature_ext = self.cert.extensions.get_extension_for_class( tlsfeature_ext = self.cert.extensions.get_extension_for_class(
x509.TLSFeature x509.TLSFeature
@@ -336,7 +362,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_subject_alt_name(self): def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
try: try:
san_ext = self.cert.extensions.get_extension_for_class( san_ext = self.cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName x509.SubjectAlternativeName
@@ -349,22 +375,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def get_not_before(self): def get_not_before(self) -> datetime.datetime:
return get_not_valid_before(self.cert) return get_not_valid_before(self.cert)
def get_not_after(self): def get_not_after(self) -> datetime.datetime:
return get_not_valid_after(self.cert) return get_not_valid_after(self.cert)
def _get_public_key_pem(self): def _get_public_key_pem(self) -> bytes:
return self.cert.public_key().public_bytes( return self.cert.public_key().public_bytes(
serialization.Encoding.PEM, serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo, serialization.PublicFormat.SubjectPublicKeyInfo,
) )
def _get_public_key_object(self): def _get_public_key_object(self) -> PublicKeyTypes:
return self.cert.public_key() return self.cert.public_key()
def _get_subject_key_identifier(self): def _get_subject_key_identifier(self) -> bytes | None:
try: try:
ext = self.cert.extensions.get_extension_for_class( ext = self.cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier x509.SubjectKeyIdentifier
@@ -373,7 +399,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None return None
def _get_authority_key_identifier(self): def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
try: try:
ext = self.cert.extensions.get_extension_for_class( ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier x509.AuthorityKeyIdentifier
@@ -392,13 +420,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, None, None return None, None, None
def _get_serial_number(self): def _get_serial_number(self) -> int:
return self.cert.serial_number return self.cert.serial_number
def _get_all_extensions(self): def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
return cryptography_get_extensions_from_cert(self.cert) return cryptography_get_extensions_from_cert(self.cert)
def _get_ocsp_uri(self): def _get_ocsp_uri(self) -> str | None:
try: try:
ext = self.cert.extensions.get_extension_for_class( ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess x509.AuthorityInformationAccess
@@ -411,7 +439,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
pass pass
return None return None
def _get_issuer_uri(self): def _get_issuer_uri(self) -> str | None:
try: try:
ext = self.cert.extensions.get_extension_for_class( ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess x509.AuthorityInformationAccess
@@ -428,12 +456,16 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
return None return None
def get_certificate_info(module, content, prefer_one_fingerprint=False): def get_certificate_info(
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module, content) info = CertificateInfoRetrievalCryptography(module, content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint) return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, content): def select_backend(
module: GeneralAnsibleModule, content: bytes
) -> CertificateInfoRetrieval:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )

View File

@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import typing as t
from random import randrange from random import randrange
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -18,6 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
cryptography_verify_certificate_signature, cryptography_verify_certificate_signature,
get_not_valid_after, get_not_valid_after,
get_not_valid_before, get_not_valid_before,
is_potential_certificate_issuer_public_key,
set_not_valid_after, set_not_valid_after,
set_not_valid_before, set_not_valid_before,
) )
@@ -28,7 +30,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_certificate, load_certificate,
load_privatekey, load_certificate_issuer_privatekey,
select_message_digest, select_message_digest,
) )
from ansible_collections.community.crypto.plugins.module_utils.time import ( from ansible_collections.community.crypto.plugins.module_utils.time import (
@@ -36,6 +38,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
from ...argspec import ArgumentSpec
try: try:
import cryptography import cryptography
from cryptography import x509 from cryptography import x509
@@ -45,13 +58,13 @@ except ImportError:
class OwnCACertificateBackendCryptography(CertificateBackend): class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module) super(OwnCACertificateBackendCryptography, self).__init__(module)
self.create_subject_key_identifier = module.params[ self.create_subject_key_identifier: t.Literal[
"ownca_create_subject_key_identifier" "create_if_not_provided", "always_create", "never_create"
] ] = module.params["ownca_create_subject_key_identifier"]
self.create_authority_key_identifier = module.params[ self.create_authority_key_identifier: bool = module.params[
"ownca_create_authority_key_identifier" "ownca_create_authority_key_identifier"
] ]
self.notBefore = get_relative_time_option( self.notBefore = get_relative_time_option(
@@ -65,31 +78,40 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
with_timezone=CRYPTOGRAPHY_TIMEZONE, with_timezone=CRYPTOGRAPHY_TIMEZONE,
) )
self.digest = select_message_digest(module.params["ownca_digest"]) self.digest = select_message_digest(module.params["ownca_digest"])
self.version = module.params["ownca_version"] self.version: int = module.params["ownca_version"]
self.serial_number = x509.random_serial_number() self.serial_number = x509.random_serial_number()
self.ca_cert_path = module.params["ownca_path"] self.ca_cert_path: str | None = module.params["ownca_path"]
self.ca_cert_content = module.params["ownca_content"] ca_cert_content: str | None = module.params["ownca_content"]
if self.ca_cert_content is not None: if ca_cert_content is not None:
self.ca_cert_content = self.ca_cert_content.encode("utf-8") self.ca_cert_content: bytes | None = ca_cert_content.encode("utf-8")
self.ca_privatekey_path = module.params["ownca_privatekey_path"] else:
self.ca_privatekey_content = module.params["ownca_privatekey_content"] self.ca_cert_content = None
if self.ca_privatekey_content is not None: self.ca_privatekey_path: str | None = module.params["ownca_privatekey_path"]
self.ca_privatekey_content = self.ca_privatekey_content.encode("utf-8") ca_privatekey_content: str | None = module.params["ownca_privatekey_content"]
self.ca_privatekey_passphrase = module.params["ownca_privatekey_passphrase"] if ca_privatekey_content is not None:
self.ca_privatekey_content: bytes | None = ca_privatekey_content.encode(
"utf-8"
)
else:
self.ca_privatekey_content = None
self.ca_privatekey_passphrase: str | None = module.params[
"ownca_privatekey_passphrase"
]
if self.csr_content is None and self.csr_path is None: if self.csr_content is None:
raise CertificateError( if self.csr_path is None:
"csr_path or csr_content is required for ownca provider" raise CertificateError(
) "csr_path or csr_content is required for ownca provider"
if self.csr_content is None and not os.path.exists(self.csr_path): )
raise CertificateError( if not os.path.exists(self.csr_path):
f"The certificate signing request file {self.csr_path} does not exist" raise CertificateError(
) f"The certificate signing request file {self.csr_path} does not exist"
if self.ca_cert_content is None and not os.path.exists(self.ca_cert_path): )
if self.ca_cert_path is not None and not os.path.exists(self.ca_cert_path):
raise CertificateError( raise CertificateError(
f"The CA certificate file {self.ca_cert_path} does not exist" f"The CA certificate file {self.ca_cert_path} does not exist"
) )
if self.ca_privatekey_content is None and not os.path.exists( if self.ca_privatekey_path is not None and not os.path.exists(
self.ca_privatekey_path self.ca_privatekey_path
): ):
raise CertificateError( raise CertificateError(
@@ -101,8 +123,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
path=self.ca_cert_path, path=self.ca_cert_path,
content=self.ca_cert_content, content=self.ca_cert_content,
) )
if not is_potential_certificate_issuer_public_key(self.ca_cert.public_key()):
raise CertificateError(
"CA certificate's public key cannot be used to sign certificates"
)
try: try:
self.ca_private_key = load_privatekey( self.ca_private_key = load_certificate_issuer_privatekey(
path=self.ca_privatekey_path, path=self.ca_privatekey_path,
content=self.ca_privatekey_content, content=self.ca_privatekey_content,
passphrase=self.ca_privatekey_passphrase, passphrase=self.ca_privatekey_passphrase,
@@ -125,8 +151,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
else: else:
self.digest = None self.digest = None
def generate_certificate(self): def generate_certificate(self) -> None:
"""(Re-)Generate certificate.""" """(Re-)Generate certificate."""
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
cert_builder = x509.CertificateBuilder() cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(self.csr.subject) cert_builder = cert_builder.subject_name(self.csr.subject)
cert_builder = cert_builder.issuer_name(self.ca_cert.subject) cert_builder = cert_builder.issuer_name(self.ca_cert.subject)
@@ -166,10 +194,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
critical=False, critical=False,
) )
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
public_key = self.ca_cert.public_key()
assert is_potential_certificate_issuer_public_key(public_key)
cert_builder = cert_builder.add_extension( cert_builder = cert_builder.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key( x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
self.ca_cert.public_key()
),
critical=False, critical=False,
) )
@@ -180,17 +208,24 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
self.cert = certificate self.cert = certificate
def get_certificate_data(self): def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert.""" """Return bytes for self.cert."""
if self.cert is None:
raise AssertionError("Contract violation: cert has not been populated")
return self.cert.public_bytes(Encoding.PEM) return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(self): def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
if super(OwnCACertificateBackendCryptography, self).needs_regeneration( if super(OwnCACertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter not_before=self.notBefore, not_after=self.notAfter
): ):
return True return True
self._ensure_existing_certificate_loaded() self._ensure_existing_certificate_loaded()
assert self.existing_certificate is not None
# Check whether certificate is signed by CA certificate # Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature( if not cryptography_verify_certificate_signature(
@@ -205,31 +240,33 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
# Check AuthorityKeyIdentifier # Check AuthorityKeyIdentifier
if self.create_authority_key_identifier: if self.create_authority_key_identifier:
try: try:
ext = self.ca_cert.extensions.get_extension_for_class( ext_ski = self.ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier x509.SubjectKeyIdentifier
) )
expected_ext = ( expected_ext = (
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier( x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext.value ext_ski.value
) )
) )
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
public_key = self.ca_cert.public_key()
assert is_potential_certificate_issuer_public_key(public_key)
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key( expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
self.ca_cert.public_key() public_key
) )
try: try:
ext = self.existing_certificate.extensions.get_extension_for_class( ext_aki = self.existing_certificate.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier x509.AuthorityKeyIdentifier
) )
if ext.value != expected_ext: if ext_aki.value != expected_ext:
return True return True
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return True return True
return False return False
def dump(self, include_certificate): def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(OwnCACertificateBackendCryptography, self).dump( result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate include_certificate
) )
@@ -251,6 +288,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
else: else:
if self.cert is None: if self.cert is None:
self.cert = self.existing_certificate self.cert = self.existing_certificate
assert self.cert is not None
result.update( result.update(
{ {
"notBefore": get_not_valid_before(self.cert).strftime( "notBefore": get_not_valid_before(self.cert).strftime(
@@ -266,7 +304,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return result return result
def generate_serial_number(): def generate_serial_number() -> int:
"""Generate a serial number for a certificate""" """Generate a serial number for a certificate"""
while True: while True:
result = randrange(0, 1 << 160) result = randrange(0, 1 << 160)
@@ -275,7 +313,7 @@ def generate_serial_number():
class OwnCACertificateProvider(CertificateProvider): class OwnCACertificateProvider(CertificateProvider):
def validate_module_args(self, module): def validate_module_args(self, module: AnsibleModule) -> None:
if ( if (
module.params["ownca_path"] is None module.params["ownca_path"] is None
and module.params["ownca_content"] is None and module.params["ownca_content"] is None
@@ -291,14 +329,16 @@ class OwnCACertificateProvider(CertificateProvider):
msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider." msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider."
) )
def needs_version_two_certs(self, module): def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return module.params["ownca_version"] == 2 return module.params["ownca_version"] == 2
def create_backend(self, module): def create_backend(
self, module: AnsibleModule
) -> OwnCACertificateBackendCryptography:
return OwnCACertificateBackendCryptography(module) return OwnCACertificateBackendCryptography(module)
def add_ownca_provider_to_argument_spec(argument_spec): def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("ownca") argument_spec.argument_spec["provider"]["choices"].append("ownca")
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(

View File

@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import typing as t
from random import randrange from random import randrange
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -14,6 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
cryptography_verify_certificate_signature, cryptography_verify_certificate_signature,
get_not_valid_after, get_not_valid_after,
get_not_valid_before, get_not_valid_before,
is_potential_certificate_issuer_private_key,
set_not_valid_after, set_not_valid_after,
set_not_valid_before, set_not_valid_before,
) )
@@ -30,6 +32,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
from ...argspec import ArgumentSpec
try: try:
import cryptography import cryptography
from cryptography import x509 from cryptography import x509
@@ -39,12 +52,14 @@ except ImportError:
class SelfSignedCertificateBackendCryptography(CertificateBackend): class SelfSignedCertificateBackendCryptography(CertificateBackend):
def __init__(self, module): privatekey: CertificateIssuerPrivateKeyTypes
def __init__(self, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module) super(SelfSignedCertificateBackendCryptography, self).__init__(module)
self.create_subject_key_identifier = module.params[ self.create_subject_key_identifier: t.Literal[
"selfsigned_create_subject_key_identifier" "create_if_not_provided", "always_create", "never_create"
] ] = module.params["selfsigned_create_subject_key_identifier"]
self.notBefore = get_relative_time_option( self.notBefore = get_relative_time_option(
module.params["selfsigned_not_before"], module.params["selfsigned_not_before"],
"selfsigned_not_before", "selfsigned_not_before",
@@ -56,14 +71,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
with_timezone=CRYPTOGRAPHY_TIMEZONE, with_timezone=CRYPTOGRAPHY_TIMEZONE,
) )
self.digest = select_message_digest(module.params["selfsigned_digest"]) self.digest = select_message_digest(module.params["selfsigned_digest"])
self.version = module.params["selfsigned_version"] self.version: int = module.params["selfsigned_version"]
self.serial_number = x509.random_serial_number() self.serial_number = x509.random_serial_number()
if self.csr_path is not None and not os.path.exists(self.csr_path): if self.csr_path is not None and not os.path.exists(self.csr_path):
raise CertificateError( raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist" f"The certificate signing request file {self.csr_path} does not exist"
) )
if self.privatekey_content is None and not os.path.exists(self.privatekey_path): if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
raise CertificateError( raise CertificateError(
f"The private key file {self.privatekey_path} does not exist" f"The private key file {self.privatekey_path} does not exist"
) )
@@ -71,20 +88,10 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
self._module = module self._module = module
self._ensure_private_key_loaded() self._ensure_private_key_loaded()
if self.privatekey is None:
self._ensure_csr_loaded() raise CertificateError("Private key has not been provided")
if self.csr is None: if not is_potential_certificate_issuer_private_key(self.privatekey):
# Create empty CSR on the fly raise CertificateError("Private key cannot be used to sign certificates")
csr = cryptography.x509.CertificateSigningRequestBuilder()
csr = csr.subject_name(cryptography.x509.Name([]))
digest = None
if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = self.digest
if digest is None:
self.module.fail_json(
msg=f'Unsupported digest "{module.params["selfsigned_digest"]}"'
)
self.csr = csr.sign(self.privatekey, digest)
if cryptography_key_needs_digest_for_signing(self.privatekey): if cryptography_key_needs_digest_for_signing(self.privatekey):
if self.digest is None: if self.digest is None:
@@ -94,8 +101,21 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
else: else:
self.digest = None self.digest = None
def generate_certificate(self): self._ensure_csr_loaded()
if self.csr is None:
# Create empty CSR on the fly
csr = cryptography.x509.CertificateSigningRequestBuilder()
csr = csr.subject_name(cryptography.x509.Name([]))
self.csr = csr.sign(self.privatekey, self.digest)
def generate_certificate(self) -> None:
"""(Re-)Generate certificate.""" """(Re-)Generate certificate."""
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
if self.privatekey is None:
raise AssertionError(
"Contract violation: privatekey has not been populated"
)
try: try:
cert_builder = x509.CertificateBuilder() cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(self.csr.subject) cert_builder = cert_builder.subject_name(self.csr.subject)
@@ -130,17 +150,26 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
self.cert = certificate self.cert = certificate
def get_certificate_data(self): def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert.""" """Return bytes for self.cert."""
if self.cert is None:
raise AssertionError("Contract violation: cert has not been populated")
return self.cert.public_bytes(Encoding.PEM) return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(self): def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
assert self.privatekey is not None
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration( if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter not_before=self.notBefore, not_after=self.notAfter
): ):
return True return True
self._ensure_existing_certificate_loaded() self._ensure_existing_certificate_loaded()
assert self.existing_certificate is not None
# Check whether certificate is signed by private key # Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature( if not cryptography_verify_certificate_signature(
@@ -150,7 +179,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
return False return False
def dump(self, include_certificate): def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(SelfSignedCertificateBackendCryptography, self).dump( result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate include_certificate
) )
@@ -166,6 +195,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
else: else:
if self.cert is None: if self.cert is None:
self.cert = self.existing_certificate self.cert = self.existing_certificate
assert self.cert is not None
result.update( result.update(
{ {
"notBefore": get_not_valid_before(self.cert).strftime( "notBefore": get_not_valid_before(self.cert).strftime(
@@ -181,7 +211,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
return result return result
def generate_serial_number(): def generate_serial_number() -> int:
"""Generate a serial number for a certificate""" """Generate a serial number for a certificate"""
while True: while True:
result = randrange(0, 1 << 160) result = randrange(0, 1 << 160)
@@ -190,7 +220,7 @@ def generate_serial_number():
class SelfSignedCertificateProvider(CertificateProvider): class SelfSignedCertificateProvider(CertificateProvider):
def validate_module_args(self, module): def validate_module_args(self, module: AnsibleModule) -> None:
if ( if (
module.params["privatekey_path"] is None module.params["privatekey_path"] is None
and module.params["privatekey_content"] is None and module.params["privatekey_content"] is None
@@ -199,14 +229,16 @@ class SelfSignedCertificateProvider(CertificateProvider):
msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider." msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider."
) )
def needs_version_two_certs(self, module): def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return module.params["selfsigned_version"] == 2 return module.params["selfsigned_version"] == 2
def create_backend(self, module): def create_backend(
self, module: AnsibleModule
) -> SelfSignedCertificateBackendCryptography:
return SelfSignedCertificateBackendCryptography(module) return SelfSignedCertificateBackendCryptography(module)
def add_selfsigned_provider_to_argument_spec(argument_spec): def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("selfsigned") argument_spec.argument_spec["provider"]["choices"].append("selfsigned")
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(

View File

@@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import (
TIMESTAMP_FORMAT, TIMESTAMP_FORMAT,
cryptography_decode_revoked_certificate, cryptography_decode_revoked_certificate,
@@ -22,6 +24,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
# crypto_utils # crypto_utils
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
@@ -33,14 +47,19 @@ except ImportError:
class CRLInfoRetrieval: class CRLInfoRetrieval:
def __init__(self, module, content, list_revoked_certificates=True): def __init__(
self,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
) -> None:
# content must be a bytes string # content must be a bytes string
self.module = module self.module = module
self.content = content self.content = content
self.list_revoked_certificates = list_revoked_certificates self.list_revoked_certificates = list_revoked_certificates
self.name_encoding = module.params.get("name_encoding", "ignore") self.name_encoding = module.params.get("name_encoding", "ignore")
def get_info(self): def get_info(self) -> dict[str, t.Any]:
self.crl_pem = identify_pem_format(self.content) self.crl_pem = identify_pem_format(self.content)
try: try:
if self.crl_pem: if self.crl_pem:
@@ -50,7 +69,7 @@ class CRLInfoRetrieval:
except ValueError as e: except ValueError as e:
self.module.fail_json(msg=f"Error while decoding CRL: {e}") self.module.fail_json(msg=f"Error while decoding CRL: {e}")
result = { result: dict[str, t.Any] = {
"changed": False, "changed": False,
"format": "pem" if self.crl_pem else "der", "format": "pem" if self.crl_pem else "der",
"last_update": None, "last_update": None,
@@ -61,7 +80,11 @@ class CRLInfoRetrieval:
} }
result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT) result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = self.crl.next_update.strftime(TIMESTAMP_FORMAT) result["next_update"] = (
self.crl.next_update.strftime(TIMESTAMP_FORMAT)
if self.crl.next_update
else None
)
result["digest"] = cryptography_oid_to_name( result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl) cryptography_get_signature_algorithm_oid_from_crl(self.crl)
) )
@@ -83,7 +106,9 @@ class CRLInfoRetrieval:
return result return result
def get_crl_info(module, content, list_revoked_certificates=True): def get_crl_info(
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
) -> dict[str, t.Any]:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import abc import abc
import binascii import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.argspec import ( from ansible_collections.community.crypto.plugins.module_utils.argspec import (
@@ -26,13 +27,14 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
cryptography_name_to_oid, cryptography_name_to_oid,
cryptography_parse_key_usage_params, cryptography_parse_key_usage_params,
cryptography_parse_relative_distinguished_name, cryptography_parse_relative_distinguished_name,
is_potential_certificate_issuer_public_key,
) )
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import (
get_csr_info, get_csr_info,
) )
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
load_certificate_issuer_privatekey,
load_certificate_request, load_certificate_request,
load_privatekey,
parse_name_field, parse_name_field,
parse_ordered_name_field, parse_ordered_name_field,
select_message_digest, select_message_digest,
@@ -43,6 +45,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
from ..cryptography_support import CertificatePrivateKeyTypes
_ET = t.TypeVar("_ET", bound="cryptography.x509.ExtensionType")
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -69,49 +83,58 @@ class CertificateSigningRequestError(OpenSSLObjectError):
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta): class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self.module = module self.module = module
self.digest = module.params["digest"] self.digest: str = module.params["digest"]
self.privatekey_path = module.params["privatekey_path"] self.privatekey_path: str | None = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"] privatekey_content: str | None = module.params["privatekey_content"]
if self.privatekey_content is not None: if privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8") self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"] else:
self.version = module.params["version"] self.privatekey_content = None
self.subjectAltName = module.params["subject_alt_name"] self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.subjectAltName_critical = module.params["subject_alt_name_critical"] self.version: t.Literal[1] = module.params["version"]
self.keyUsage = module.params["key_usage"] self.subjectAltName: list[str] | None = module.params["subject_alt_name"]
self.keyUsage_critical = module.params["key_usage_critical"] self.subjectAltName_critical: bool = module.params["subject_alt_name_critical"]
self.extendedKeyUsage = module.params["extended_key_usage"] self.keyUsage: list[str] | None = module.params["key_usage"]
self.extendedKeyUsage_critical = module.params["extended_key_usage_critical"] self.keyUsage_critical: bool = module.params["key_usage_critical"]
self.basicConstraints = module.params["basic_constraints"] self.extendedKeyUsage: list[str] | None = module.params["extended_key_usage"]
self.basicConstraints_critical = module.params["basic_constraints_critical"] self.extendedKeyUsage_critical: bool = module.params[
self.ocspMustStaple = module.params["ocsp_must_staple"] "extended_key_usage_critical"
self.ocspMustStaple_critical = module.params["ocsp_must_staple_critical"] ]
self.name_constraints_permitted = ( self.basicConstraints: list[str] | None = module.params["basic_constraints"]
self.basicConstraints_critical: bool = module.params[
"basic_constraints_critical"
]
self.ocspMustStaple: bool = module.params["ocsp_must_staple"]
self.ocspMustStaple_critical: bool = module.params["ocsp_must_staple_critical"]
self.name_constraints_permitted: list[str] = (
module.params["name_constraints_permitted"] or [] module.params["name_constraints_permitted"] or []
) )
self.name_constraints_excluded = ( self.name_constraints_excluded: list[str] = (
module.params["name_constraints_excluded"] or [] module.params["name_constraints_excluded"] or []
) )
self.name_constraints_critical = module.params["name_constraints_critical"] self.name_constraints_critical: bool = module.params[
self.create_subject_key_identifier = module.params[ "name_constraints_critical"
]
self.create_subject_key_identifier: bool = module.params[
"create_subject_key_identifier" "create_subject_key_identifier"
] ]
self.subject_key_identifier = module.params["subject_key_identifier"] subject_key_identifier: str | None = module.params["subject_key_identifier"]
self.authority_key_identifier = module.params["authority_key_identifier"] authority_key_identifier: str | None = module.params["authority_key_identifier"]
self.authority_cert_issuer = module.params["authority_cert_issuer"] self.authority_cert_issuer: list[str] | None = module.params[
self.authority_cert_serial_number = module.params[ "authority_cert_issuer"
]
self.authority_cert_serial_number: int = module.params[
"authority_cert_serial_number" "authority_cert_serial_number"
] ]
self.crl_distribution_points = module.params["crl_distribution_points"] self.crl_distribution_points: (
self.csr = None list[cryptography.x509.DistributionPoint] | None
self.privatekey = None ) = None
self.csr: cryptography.x509.CertificateSigningRequest | None = None
self.privatekey: CertificateIssuerPrivateKeyTypes | None = None
if ( if self.create_subject_key_identifier and subject_key_identifier is not None:
self.create_subject_key_identifier
and self.subject_key_identifier is not None
):
module.fail_json( module.fail_json(
msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true" msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true"
) )
@@ -153,35 +176,37 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.using_common_name_for_san = True self.using_common_name_for_san = True
break break
if self.subject_key_identifier is not None: self.subject_key_identifier: bytes | None = None
if subject_key_identifier is not None:
try: try:
self.subject_key_identifier = binascii.unhexlify( self.subject_key_identifier = binascii.unhexlify(
self.subject_key_identifier.replace(":", "") subject_key_identifier.replace(":", "")
) )
except Exception as e: except Exception as e:
raise CertificateSigningRequestError( raise CertificateSigningRequestError(
f"Cannot parse subject_key_identifier: {e}" f"Cannot parse subject_key_identifier: {e}"
) )
if self.authority_key_identifier is not None: self.authority_key_identifier: bytes | None = None
if authority_key_identifier is not None:
try: try:
self.authority_key_identifier = binascii.unhexlify( self.authority_key_identifier = binascii.unhexlify(
self.authority_key_identifier.replace(":", "") authority_key_identifier.replace(":", "")
) )
except Exception as e: except Exception as e:
raise CertificateSigningRequestError( raise CertificateSigningRequestError(
f"Cannot parse authority_key_identifier: {e}" f"Cannot parse authority_key_identifier: {e}"
) )
self.existing_csr = None self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
self.existing_csr_bytes = None self.existing_csr_bytes: bytes | None = None
self.diff_before = self._get_info(None) self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None) self.diff_after = self._get_info(None)
def _get_info(self, data): def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None: if data is None:
return dict() return {}
try: try:
result = get_csr_info( result = get_csr_info(
self.module, self.module,
@@ -195,30 +220,28 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return dict(can_parse_csr=False) return dict(can_parse_csr=False)
@abc.abstractmethod @abc.abstractmethod
def generate_csr(self): def generate_csr(self) -> None:
"""(Re-)Generate CSR.""" """(Re-)Generate CSR."""
pass
@abc.abstractmethod @abc.abstractmethod
def get_csr_data(self): def get_csr_data(self) -> bytes:
"""Return bytes for self.csr.""" """Return bytes for self.csr."""
pass
def set_existing(self, csr_bytes): def set_existing(self, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist.""" """Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes) self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
def has_existing(self): def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there.""" """Query whether an existing CSR is/has been there."""
return self.existing_csr_bytes is not None return self.existing_csr_bytes is not None
def _ensure_private_key_loaded(self): def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey.""" """Load the provided private key into self.privatekey."""
if self.privatekey is not None: if self.privatekey is not None:
return return
try: try:
self.privatekey = load_privatekey( self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path, path=self.privatekey_path,
content=self.privatekey_content, content=self.privatekey_content,
passphrase=self.privatekey_passphrase, passphrase=self.privatekey_passphrase,
@@ -227,11 +250,10 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
raise CertificateSigningRequestError(exc) raise CertificateSigningRequestError(exc)
@abc.abstractmethod @abc.abstractmethod
def _check_csr(self): def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated.""" """Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
pass
def needs_regeneration(self): def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary.""" """Check whether a regeneration is necessary."""
if self.existing_csr_bytes is None: if self.existing_csr_bytes is None:
return True return True
@@ -245,9 +267,9 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self._ensure_private_key_loaded() self._ensure_private_key_loaded()
return not self._check_csr() return not self._check_csr()
def dump(self, include_csr): def dump(self, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = { result: dict[str, t.Any] = {
"privatekey": self.privatekey_path, "privatekey": self.privatekey_path,
"subject": self.subject, "subject": self.subject,
"subjectAltName": self.subjectAltName, "subjectAltName": self.subjectAltName,
@@ -274,44 +296,49 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
return result return result
def parse_crl_distribution_points(module, crl_distribution_points): def parse_crl_distribution_points(
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = [] result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points): for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
try: try:
params = dict( full_name = None
full_name=None, relative_name = None
relative_name=None, crl_issuer = None
crl_issuer=None, reasons = None
reasons=None,
)
if parse_crl_distribution_point["full_name"] is not None: if parse_crl_distribution_point["full_name"] is not None:
if not parse_crl_distribution_point["full_name"]: if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty") raise OpenSSLObjectError("full_name must not be empty")
params["full_name"] = [ full_name = [
cryptography_get_name(name, "full name") cryptography_get_name(name, "full name")
for name in parse_crl_distribution_point["full_name"] for name in parse_crl_distribution_point["full_name"]
] ]
if parse_crl_distribution_point["relative_name"] is not None: if parse_crl_distribution_point["relative_name"] is not None:
if not parse_crl_distribution_point["relative_name"]: if not parse_crl_distribution_point["relative_name"]:
raise OpenSSLObjectError("relative_name must not be empty") raise OpenSSLObjectError("relative_name must not be empty")
params["relative_name"] = ( relative_name = cryptography_parse_relative_distinguished_name(
cryptography_parse_relative_distinguished_name( parse_crl_distribution_point["relative_name"]
parse_crl_distribution_point["relative_name"]
)
) )
if parse_crl_distribution_point["crl_issuer"] is not None: if parse_crl_distribution_point["crl_issuer"] is not None:
if not parse_crl_distribution_point["crl_issuer"]: if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty") raise OpenSSLObjectError("crl_issuer must not be empty")
params["crl_issuer"] = [ crl_issuer = [
cryptography_get_name(name, "CRL issuer") cryptography_get_name(name, "CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"] for name in parse_crl_distribution_point["crl_issuer"]
] ]
if parse_crl_distribution_point["reasons"] is not None: if parse_crl_distribution_point["reasons"] is not None:
reasons = [] reasons_list = []
for reason in parse_crl_distribution_point["reasons"]: for reason in parse_crl_distribution_point["reasons"]:
reasons.append(REVOCATION_REASON_MAP[reason]) reasons_list.append(REVOCATION_REASON_MAP[reason])
params["reasons"] = frozenset(reasons) reasons = frozenset(reasons_list)
result.append(cryptography.x509.DistributionPoint(**params)) result.append(
cryptography.x509.DistributionPoint(
full_name=full_name,
relative_name=relative_name,
crl_issuer=crl_issuer,
reasons=reasons,
)
)
except (OpenSSLObjectError, ValueError) as e: except (OpenSSLObjectError, ValueError) as e:
raise OpenSSLObjectError( raise OpenSSLObjectError(
f"Error while parsing CRL distribution point #{index}: {e}" f"Error while parsing CRL distribution point #{index}: {e}"
@@ -321,21 +348,25 @@ def parse_crl_distribution_points(module, crl_distribution_points):
# Implementation with using cryptography # Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend): class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(module) super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
if self.version != 1: if self.version != 1:
module.warn( module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)" "The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
) )
if self.crl_distribution_points: crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
"crl_distribution_points"
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points( self.crl_distribution_points = parse_crl_distribution_points(
module, self.crl_distribution_points module, crl_distribution_points
) )
def generate_csr(self): def generate_csr(self) -> None:
"""(Re-)Generate CSR.""" """(Re-)Generate CSR."""
self._ensure_private_key_loaded() self._ensure_private_key_loaded()
assert self.privatekey is not None
csr = cryptography.x509.CertificateSigningRequestBuilder() csr = cryptography.x509.CertificateSigningRequestBuilder()
try: try:
@@ -412,6 +443,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
raise OpenSSLObjectError(f"Error while parsing name constraint: {e}") raise OpenSSLObjectError(f"Error while parsing name constraint: {e}")
if self.create_subject_key_identifier: if self.create_subject_key_identifier:
if not is_potential_certificate_issuer_public_key(
self.privatekey.public_key()
):
raise OpenSSLObjectError(
"Private key can not be used to create subject key identifier"
)
csr = csr.add_extension( csr = csr.add_extension(
cryptography.x509.SubjectKeyIdentifier.from_public_key( cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key() self.privatekey.public_key()
@@ -450,7 +487,10 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
critical=False, critical=False,
) )
digest = None # csr.sign() does not accept some digests we theoretically could have in digest.
# For that reason we use type t.Any here. csr.sign() will complain if
# the digest is not acceptable.
digest: t.Any | None = None
if cryptography_key_needs_digest_for_signing(self.privatekey): if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = select_message_digest(self.digest) digest = select_message_digest(self.digest)
if digest is None: if digest is None:
@@ -482,16 +522,22 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
+ "This is probably caused by an invalid Subject Alternative DNS Name." + "This is probably caused by an invalid Subject Alternative DNS Name."
) )
def get_csr_data(self): def get_csr_data(self) -> bytes:
"""Return bytes for self.csr.""" """Return bytes for self.csr."""
if self.csr is None:
raise AssertionError("Violated contract: csr is not populated")
return self.csr.public_bytes( return self.csr.public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM cryptography.hazmat.primitives.serialization.Encoding.PEM
) )
def _check_csr(self): def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated.""" """Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
if self.existing_csr is None:
raise AssertionError("Violated contract: existing_csr is not populated")
if self.privatekey is None:
raise AssertionError("Violated contract: privatekey is not populated")
def _check_subject(csr): def _check_subject(csr: cryptography.x509.CertificateSigningRequest) -> bool:
subject = [ subject = [
(cryptography_name_to_oid(entry[0]), to_text(entry[1])) (cryptography_name_to_oid(entry[0]), to_text(entry[1]))
for entry in self.subject for entry in self.subject
@@ -502,12 +548,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else: else:
return set(subject) == set(current_subject) return set(subject) == set(current_subject)
def _find_extension(extensions, exttype): def _find_extension(
extensions: cryptography.x509.Extensions, exttype: type[_ET]
) -> cryptography.x509.Extension[_ET] | None:
return next( return next(
(ext for ext in extensions if isinstance(ext.value, exttype)), None (ext for ext in extensions if isinstance(ext.value, exttype)), None
) )
def _check_subjectAltName(extensions): def _check_subjectAltName(extensions: cryptography.x509.Extensions) -> bool:
current_altnames_ext = _find_extension( current_altnames_ext = _find_extension(
extensions, cryptography.x509.SubjectAlternativeName extensions, cryptography.x509.SubjectAlternativeName
) )
@@ -526,12 +574,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
) )
if set(altnames) != set(current_altnames): if set(altnames) != set(current_altnames):
return False return False
if altnames: if altnames and current_altnames_ext:
if current_altnames_ext.critical != self.subjectAltName_critical: if current_altnames_ext.critical != self.subjectAltName_critical:
return False return False
return True return True
def _check_keyUsage(extensions): def _check_keyUsage(extensions: cryptography.x509.Extensions) -> bool:
current_keyusage_ext = _find_extension( current_keyusage_ext = _find_extension(
extensions, cryptography.x509.KeyUsage extensions, cryptography.x509.KeyUsage
) )
@@ -547,7 +595,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
return False return False
return True return True
def _check_extenededKeyUsage(extensions): def _check_extenededKeyUsage(extensions: cryptography.x509.Extensions) -> bool:
current_usages_ext = _find_extension( current_usages_ext = _find_extension(
extensions, cryptography.x509.ExtendedKeyUsage extensions, cryptography.x509.ExtendedKeyUsage
) )
@@ -566,12 +614,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
) )
if set(current_usages) != set(usages): if set(current_usages) != set(usages):
return False return False
if usages: if usages and current_usages_ext:
if current_usages_ext.critical != self.extendedKeyUsage_critical: if current_usages_ext.critical != self.extendedKeyUsage_critical:
return False return False
return True return True
def _check_basicConstraints(extensions): def _check_basicConstraints(extensions: cryptography.x509.Extensions) -> bool:
bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints) bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints)
current_ca = bc_ext.value.ca if bc_ext else False current_ca = bc_ext.value.ca if bc_ext else False
current_path_length = bc_ext.value.path_length if bc_ext else None current_path_length = bc_ext.value.path_length if bc_ext else None
@@ -591,7 +639,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else: else:
return bc_ext is None return bc_ext is None
def _check_ocspMustStaple(extensions): def _check_ocspMustStaple(extensions: cryptography.x509.Extensions) -> bool:
tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature) tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
if self.ocspMustStaple: if self.ocspMustStaple:
if ( if (
@@ -606,7 +654,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else: else:
return tlsfeature_ext is None return tlsfeature_ext is None
def _check_nameConstraints(extensions): def _check_nameConstraints(extensions: cryptography.x509.Extensions) -> bool:
current_nc_ext = _find_extension( current_nc_ext = _find_extension(
extensions, cryptography.x509.NameConstraints extensions, cryptography.x509.NameConstraints
) )
@@ -638,12 +686,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
current_nc_excl current_nc_excl
): ):
return False return False
if nc_perm or nc_excl: if (nc_perm or nc_excl) and current_nc_ext:
if current_nc_ext.critical != self.name_constraints_critical: if current_nc_ext.critical != self.name_constraints_critical:
return False return False
return True return True
def _check_subject_key_identifier(extensions): def _check_subject_key_identifier(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier) ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
if ( if (
self.create_subject_key_identifier self.create_subject_key_identifier
@@ -652,6 +702,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
if not ext or ext.critical: if not ext or ext.critical:
return False return False
if self.create_subject_key_identifier: if self.create_subject_key_identifier:
assert self.privatekey is not None
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key( digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key() self.privatekey.public_key()
).digest ).digest
@@ -661,7 +712,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else: else:
return ext is None return ext is None
def _check_authority_key_identifier(extensions): def _check_authority_key_identifier(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier) ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
if ( if (
self.authority_key_identifier is not None self.authority_key_identifier is not None
@@ -688,7 +741,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
else: else:
return ext is None return ext is None
def _check_crl_distribution_points(extensions): def _check_crl_distribution_points(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints) ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints)
if self.crl_distribution_points is None: if self.crl_distribution_points is None:
return ext is None return ext is None
@@ -696,7 +751,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
return False return False
return list(ext.value) == self.crl_distribution_points return list(ext.value) == self.crl_distribution_points
def _check_extensions(csr): def _check_extensions(csr: cryptography.x509.CertificateSigningRequest) -> bool:
extensions = csr.extensions extensions = csr.extensions
return ( return (
_check_subjectAltName(extensions) _check_subjectAltName(extensions)
@@ -710,7 +765,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
and _check_crl_distribution_points(extensions) and _check_crl_distribution_points(extensions)
) )
def _check_signature(csr): def _check_signature(csr: cryptography.x509.CertificateSigningRequest) -> bool:
if not csr.is_signature_valid: if not csr.is_signature_valid:
return False return False
# To check whether public key of CSR belongs to private key, # To check whether public key of CSR belongs to private key,
@@ -719,6 +774,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
cryptography.hazmat.primitives.serialization.Encoding.PEM, cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo, cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
) )
assert self.privatekey is not None
key_b = self.privatekey.public_key().public_bytes( key_b = self.privatekey.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM, cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo, cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
@@ -732,14 +788,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
) )
def select_backend(module): def select_backend(
module: AnsibleModule,
) -> CertificateSigningRequestCryptographyBackend:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )
return CertificateSigningRequestCryptographyBackend(module) return CertificateSigningRequestCryptographyBackend(module)
def get_csr_argument_spec(): def get_csr_argument_spec() -> ArgumentSpec:
return ArgumentSpec( return ArgumentSpec(
argument_spec=dict( argument_spec=dict(
digest=dict(type="str", default="sha256"), digest=dict(type="str", default="sha256"),

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import abc import abc
import binascii import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
@@ -27,6 +28,19 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificatePublicKeyTypes,
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -41,66 +55,69 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta): class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module, content, validate_signature): def __init__(
# content must be a bytes string self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
self.module = module self.module = module
self.content = content self.content = content
self.validate_signature = validate_signature self.validate_signature = validate_signature
@abc.abstractmethod @abc.abstractmethod
def _get_subject_ordered(self): def _get_subject_ordered(self) -> list[list[str]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_key_usage(self): def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_extended_key_usage(self): def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_basic_constraints(self): def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_ocsp_must_staple(self): def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_subject_alt_name(self): def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_name_constraints(self): def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_public_key_pem(self): def _get_public_key_pem(self) -> bytes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_public_key_object(self): def _get_public_key_object(self) -> CertificatePublicKeyTypes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_subject_key_identifier(self): def _get_subject_key_identifier(self) -> bytes | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_authority_key_identifier(self): def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_all_extensions(self): def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _is_signature_valid(self): def _is_signature_valid(self) -> bool:
pass pass
def get_info(self, prefer_one_fingerprint=False): def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result = dict() result: dict[str, t.Any] = {}
self.csr = load_certificate_request( self.csr = load_certificate_request(
None, None,
content=self.content, content=self.content,
@@ -145,15 +162,17 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
} }
) )
ski = self._get_subject_key_identifier() ski_bytes = self._get_subject_key_identifier()
if ski is not None: ski = None
ski = binascii.hexlify(ski).decode("ascii") if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)]) ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski result["subject_key_identifier"] = ski
aki, aci, acsn = self._get_authority_key_identifier() aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki is not None: aki = None
aki = binascii.hexlify(aki).decode("ascii") if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)]) aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci result["authority_cert_issuer"] = aci
@@ -170,19 +189,25 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
class CSRInfoRetrievalCryptography(CSRInfoRetrieval): class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend""" """Validate the supplied CSR, using the cryptography backend"""
def __init__(self, module, content, validate_signature): def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super(CSRInfoRetrievalCryptography, self).__init__( super(CSRInfoRetrievalCryptography, self).__init__(
module, content, validate_signature module, content, validate_signature
) )
self.name_encoding = module.params.get("name_encoding", "ignore") self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
)
def _get_subject_ordered(self): def _get_subject_ordered(self) -> list[list[str]]:
result = [] result: list[list[str]] = []
for attribute in self.csr.subject: for attribute in self.csr.subject:
result.append([cryptography_oid_to_name(attribute.oid), attribute.value]) result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result return result
def _get_key_usage(self): def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try: try:
current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage) current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage)
current_key_usage = current_key_ext.value current_key_usage = current_key_ext.value
@@ -229,7 +254,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_extended_key_usage(self): def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
try: try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class( ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.ExtendedKeyUsage x509.ExtendedKeyUsage
@@ -243,7 +268,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_basic_constraints(self): def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
try: try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class( ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.BasicConstraints x509.BasicConstraints
@@ -255,7 +280,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_ocsp_must_staple(self): def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
try: try:
# This only works with cryptography >= 2.1 # This only works with cryptography >= 2.1
tlsfeature_ext = self.csr.extensions.get_extension_for_class( tlsfeature_ext = self.csr.extensions.get_extension_for_class(
@@ -268,7 +293,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_subject_alt_name(self): def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
try: try:
san_ext = self.csr.extensions.get_extension_for_class( san_ext = self.csr.extensions.get_extension_for_class(
x509.SubjectAlternativeName x509.SubjectAlternativeName
@@ -281,7 +306,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, False return None, False
def _get_name_constraints(self): def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
try: try:
nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints) nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints)
permitted = [ permitted = [
@@ -296,23 +321,25 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, None, False return None, None, False
def _get_public_key_pem(self): def _get_public_key_pem(self) -> bytes:
return self.csr.public_key().public_bytes( return self.csr.public_key().public_bytes(
serialization.Encoding.PEM, serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo, serialization.PublicFormat.SubjectPublicKeyInfo,
) )
def _get_public_key_object(self): def _get_public_key_object(self) -> CertificatePublicKeyTypes:
return self.csr.public_key() return self.csr.public_key()
def _get_subject_key_identifier(self): def _get_subject_key_identifier(self) -> bytes | None:
try: try:
ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier) ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
return ext.value.digest return ext.value.digest
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None return None
def _get_authority_key_identifier(self): def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
try: try:
ext = self.csr.extensions.get_extension_for_class( ext = self.csr.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier x509.AuthorityKeyIdentifier
@@ -331,23 +358,28 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
return None, None, None return None, None, None
def _get_all_extensions(self): def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
return cryptography_get_extensions_from_csr(self.csr) return cryptography_get_extensions_from_csr(self.csr)
def _is_signature_valid(self): def _is_signature_valid(self) -> bool:
return self.csr.is_signature_valid return self.csr.is_signature_valid
def get_csr_info( def get_csr_info(
module, content, validate_signature=True, prefer_one_fingerprint=False module: GeneralAnsibleModule,
): content: bytes,
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography( info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature module, content, validate_signature=validate_signature
) )
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint) return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, content, validate_signature=True): def select_backend(
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
) -> CSRInfoRetrieval:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import abc import abc
import base64 import base64
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.argspec import ( from ansible_collections.community.crypto.plugins.module_utils.argspec import (
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -64,29 +76,37 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyBackend(metaclass=abc.ABCMeta): class PrivateKeyBackend(metaclass=abc.ABCMeta):
def __init__(self, module): def __init__(self, module: GeneralAnsibleModule) -> None:
self.module = module self.module = module
self.type = module.params["type"] self.type: t.Literal[
self.size = module.params["size"] "DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
self.curve = module.params["curve"] ] = module.params["type"]
self.passphrase = module.params["passphrase"] self.size: int = module.params["size"]
self.cipher = module.params["cipher"] self.curve: str | None = module.params["curve"]
self.format = module.params["format"] self.passphrase: str | None = module.params["passphrase"]
self.format_mismatch = module.params.get("format_mismatch", "regenerate") self.cipher: str = module.params["cipher"]
self.regenerate = module.params.get("regenerate", "full_idempotence") self.format: t.Literal["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"] = (
module.params["format"]
)
self.format_mismatch: t.Literal["regenerate", "convert"] = module.params.get(
"format_mismatch", "regenerate"
)
self.regenerate: t.Literal[
"never", "fail", "partial_idempotence", "full_idempotence", "always"
] = module.params.get("regenerate", "full_idempotence")
self.private_key = None self.private_key: PrivateKeyTypes | None = None
self.existing_private_key = None self.existing_private_key: PrivateKeyTypes | None = None
self.existing_private_key_bytes = None self.existing_private_key_bytes: bytes | None = None
self.diff_before = self._get_info(None) self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None) self.diff_after = self._get_info(None)
def _get_info(self, data): def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None: if data is None:
return dict() return {}
result = dict(can_parse_key=False) result: dict[str, t.Any] = {"can_parse_key": False}
try: try:
result.update( result.update(
get_privatekey_info( get_privatekey_info(
@@ -106,11 +126,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return result return result
@abc.abstractmethod @abc.abstractmethod
def generate_private_key(self): def generate_private_key(self) -> None:
"""(Re-)Generate private key.""" """(Re-)Generate private key."""
pass pass
def convert_private_key(self): def convert_private_key(self) -> None:
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key). """Convert existing private key (self.existing_private_key) to new private key (self.private_key).
This is effectively a copy without active conversion. The conversion is done This is effectively a copy without active conversion. The conversion is done
@@ -121,42 +141,37 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
self.private_key = self.existing_private_key self.private_key = self.existing_private_key
@abc.abstractmethod @abc.abstractmethod
def get_private_key_data(self): def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key.""" """Return bytes for self.private_key."""
pass
def set_existing(self, privatekey_bytes): def set_existing(self, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist.""" """Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info( self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes self.existing_private_key_bytes
) )
def has_existing(self): def has_existing(self) -> bool:
"""Query whether an existing private key is/has been there.""" """Query whether an existing private key is/has been there."""
return self.existing_private_key_bytes is not None return self.existing_private_key_bytes is not None
@abc.abstractmethod @abc.abstractmethod
def _check_passphrase(self): def _check_passphrase(self) -> bool:
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated.""" """Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
pass
@abc.abstractmethod @abc.abstractmethod
def _ensure_existing_private_key_loaded(self): def _ensure_existing_private_key_loaded(self) -> None:
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes.""" """Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
pass
@abc.abstractmethod @abc.abstractmethod
def _check_size_and_type(self): def _check_size_and_type(self) -> bool:
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated.""" """Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
pass
@abc.abstractmethod @abc.abstractmethod
def _check_format(self): def _check_format(self) -> bool:
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated.""" """Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
pass
def needs_regeneration(self): def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary.""" """Check whether a regeneration is necessary."""
if self.regenerate == "always": if self.regenerate == "always":
return True return True
@@ -194,7 +209,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
) )
return False return False
def needs_conversion(self): def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False.""" """Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
# During conversion step, convert if format does not match and format_mismatch == 'convert' # During conversion step, convert if format does not match and format_mismatch == 'convert'
self._ensure_existing_private_key_loaded() self._ensure_existing_private_key_loaded()
@@ -204,7 +219,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
and not self._check_format() and not self._check_format()
) )
def _get_fingerprint(self): def _get_fingerprint(self) -> dict[str, str] | None:
if self.private_key: if self.private_key:
return get_fingerprint_of_privatekey(self.private_key) return get_fingerprint_of_privatekey(self.private_key)
try: try:
@@ -214,8 +229,9 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
pass pass
if self.existing_private_key: if self.existing_private_key:
return get_fingerprint_of_privatekey(self.existing_private_key) return get_fingerprint_of_privatekey(self.existing_private_key)
return None
def dump(self, include_key): def dump(self, include_key: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
if not self.private_key: if not self.private_key:
@@ -224,7 +240,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
except Exception: except Exception:
# Ignore errors # Ignore errors
pass pass
result = { result: dict[str, t.Any] = {
"type": self.type, "type": self.type,
"size": self.size, "size": self.size,
"fingerprint": self._get_fingerprint(), "fingerprint": self._get_fingerprint(),
@@ -253,38 +269,57 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
return result return result
# Implementation with using cryptography class _Curve:
class PrivateKeyCryptographyBackend(PrivateKeyBackend): def __init__(
self,
name: str,
ectype: str,
deprecated: bool,
) -> None:
self.name = name
self.ectype = ectype
self.deprecated = deprecated
def _get_ec_class(self, ectype): def _get_ec_class(
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype) self, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
if ecclass is None: if ecclass is None:
self.module.fail_json( module.fail_json(
msg=f"Your cryptography version does not support {ectype}" msg=f"Your cryptography version does not support {self.ectype}"
) )
return ecclass return ecclass
def _add_curve(self, name, ectype, deprecated=False): def create(
def create(size): self, size: int, module: GeneralAnsibleModule
ecclass = self._get_ec_class(ectype) ) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
return ecclass() ecclass = self._get_ec_class(module)
return ecclass()
def verify(privatekey): def verify(
ecclass = self._get_ec_class(ectype) self,
return isinstance( privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
privatekey.private_numbers().public_numbers.curve, ecclass module: GeneralAnsibleModule,
) ) -> bool:
ecclass = self._get_ec_class(module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
self.curves[name] = {
"create": create,
"verify": verify,
"deprecated": deprecated,
}
def __init__(self, module): # Implementation with using cryptography
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _add_curve(
self,
name: str,
ectype: str,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
def __init__(self, module: GeneralAnsibleModule) -> None:
super(PrivateKeyCryptographyBackend, self).__init__(module=module) super(PrivateKeyCryptographyBackend, self).__init__(module=module)
self.curves = dict() self.curves: dict[str, _Curve] = {}
self._add_curve("secp224r1", "SECP224R1") self._add_curve("secp224r1", "SECP224R1")
self._add_curve("secp256k1", "SECP256K1") self._add_curve("secp256k1", "SECP256K1")
self._add_curve("secp256r1", "SECP256R1") self._add_curve("secp256r1", "SECP256R1")
@@ -305,15 +340,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True) self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True) self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
def _get_wanted_format(self): def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
if self.format not in ("auto", "auto_ignore"): if self.format not in ("auto", "auto_ignore"):
return self.format return self.format # type: ignore
if self.type in ("X25519", "X448", "Ed25519", "Ed448"): if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
return "pkcs8" return "pkcs8"
else: else:
return "pkcs1" return "pkcs1"
def generate_private_key(self): def generate_private_key(self) -> None:
"""(Re-)Generate private key.""" """(Re-)Generate private key."""
try: try:
if self.type == "RSA": if self.type == "RSA":
@@ -346,13 +381,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate() cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
) )
if self.type == "ECC" and self.curve in self.curves: if self.type == "ECC" and self.curve in self.curves:
if self.curves[self.curve]["deprecated"]: if self.curves[self.curve].deprecated:
self.module.warn( self.module.warn(
f"Elliptic curves of type {self.curve} should not be used for new keys!" f"Elliptic curves of type {self.curve} should not be used for new keys!"
) )
self.private_key = ( self.private_key = (
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key( cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve]["create"](self.size), curve=self.curves[self.curve].create(
size=self.size, module=self.module
),
) )
) )
except cryptography.exceptions.UnsupportedAlgorithm: except cryptography.exceptions.UnsupportedAlgorithm:
@@ -360,22 +397,24 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
msg=f"Cryptography backend does not support the algorithm required for {self.type}" msg=f"Cryptography backend does not support the algorithm required for {self.type}"
) )
def get_private_key_data(self): def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key""" """Return bytes for self.private_key"""
if self.private_key is None:
raise AssertionError("private_key not set")
# Select export format and encoding # Select export format and encoding
try: try:
export_format = self._get_wanted_format() export_format_txt = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format == "pkcs1": if export_format_txt == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1 # "TraditionalOpenSSL" format is PKCS1
export_format = ( export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
) )
elif export_format == "pkcs8": elif export_format_txt == "pkcs8":
export_format = ( export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8 cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
) )
elif export_format == "raw": elif export_format_txt == "raw":
export_format = ( export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
) )
@@ -388,9 +427,9 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
) )
# Select key encryption # Select key encryption
encryption_algorithm = ( encryption_algorithm: (
cryptography.hazmat.primitives.serialization.NoEncryption() cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) ) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.cipher and self.passphrase: if self.cipher and self.passphrase:
if self.cipher == "auto": if self.cipher == "auto":
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption( encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
@@ -418,8 +457,10 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
exception=traceback.format_exc(), exception=traceback.format_exc(),
) )
def _load_privatekey(self): def _load_privatekey(self) -> PrivateKeyTypes:
data = self.existing_private_key_bytes data = self.existing_private_key_bytes
if data is None:
raise AssertionError("existing_private_key_bytes not set")
try: try:
# Interpret bytes depending on format. # Interpret bytes depending on format.
format = identify_private_key_format(data) format = identify_private_key_format(data)
@@ -460,11 +501,13 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
except Exception as e: except Exception as e:
raise PrivateKeyError(e) raise PrivateKeyError(e)
def _ensure_existing_private_key_loaded(self): def _ensure_existing_private_key_loaded(self) -> None:
if self.existing_private_key is None and self.has_existing(): if self.existing_private_key is None and self.has_existing():
self.existing_private_key = self._load_privatekey() self.existing_private_key = self._load_privatekey()
def _check_passphrase(self): def _check_passphrase(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
try: try:
format = identify_private_key_format(self.existing_private_key_bytes) format = identify_private_key_format(self.existing_private_key_bytes)
if format == "raw": if format == "raw":
@@ -475,7 +518,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
# provided. # provided.
return self.passphrase is None return self.passphrase is None
else: else:
return ( return bool(
cryptography.hazmat.primitives.serialization.load_pem_private_key( cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes, self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase), None if self.passphrase is None else to_bytes(self.passphrase),
@@ -484,7 +527,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
except Exception: except Exception:
return False return False
def _check_size_and_type(self): def _check_size_and_type(self) -> bool:
if isinstance( if isinstance(
self.existing_private_key, self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
@@ -527,11 +570,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
return False return False
if self.curve not in self.curves: if self.curve not in self.curves:
return False return False
return self.curves[self.curve]["verify"](self.existing_private_key) return self.curves[self.curve].verify(
self.existing_private_key, module=self.module
)
return False return False
def _check_format(self): def _check_format(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
if self.format == "auto_ignore": if self.format == "auto_ignore":
return True return True
try: try:
@@ -541,14 +588,14 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
return False return False
def select_backend(module): def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )
return PrivateKeyCryptographyBackend(module) return PrivateKeyCryptographyBackend(module)
def get_privatekey_argument_spec(): def get_privatekey_argument_spec() -> ArgumentSpec:
return ArgumentSpec( return ArgumentSpec(
argument_spec=dict( argument_spec=dict(
size=dict(type="int", default=4096), size=dict(type="int", default=4096),
@@ -607,6 +654,6 @@ def get_privatekey_argument_spec():
), ),
), ),
required_if=[ required_if=[
["type", "ECC", ["curve"]], ("type", "ECC", ["curve"]),
], ],
) )

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import abc import abc
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.argspec import ( from ansible_collections.community.crypto.plugins.module_utils.argspec import (
@@ -27,6 +28,13 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
from ansible_collections.community.crypto.plugins.module_utils.io import load_file from ansible_collections.community.crypto.plugins.module_utils.io import load_file
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -58,42 +66,48 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta): class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self.module = module self.module = module
self.src_path = module.params["src_path"] self.src_path: str | None = module.params["src_path"]
self.src_content = module.params["src_content"] self.src_content: str | None = module.params["src_content"]
self.src_passphrase = module.params["src_passphrase"] self.src_passphrase: str | None = module.params["src_passphrase"]
self.format = module.params["format"] self.format: t.Literal["pkcs1", "pkcs8", "raw"] = module.params["format"]
self.dest_passphrase = module.params["dest_passphrase"] self.dest_passphrase: str | None = module.params["dest_passphrase"]
self.src_private_key = None self.src_private_key: PrivateKeyTypes | None = None
if self.src_path is not None: if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module) self.src_private_key_bytes = load_file(self.src_path, module)
else: else:
if self.src_content is None:
raise AssertionError("src_content is None")
self.src_private_key_bytes = self.src_content.encode("utf-8") self.src_private_key_bytes = self.src_content.encode("utf-8")
self.dest_private_key = None self.dest_private_key: PrivateKeyTypes | None = None
self.dest_private_key_bytes = None self.dest_private_key_bytes: bytes | None = None
@abc.abstractmethod @abc.abstractmethod
def get_private_key_data(self): def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format.""" """Return bytes for self.src_private_key in output format."""
pass pass
def set_existing_destination(self, privatekey_bytes): def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist.""" """Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes self.dest_private_key_bytes = privatekey_bytes
def has_existing_destination(self): def has_existing_destination(self) -> bool:
"""Query whether an existing private key is/has been there.""" """Query whether an existing private key is/has been there."""
return self.dest_private_key_bytes is not None return self.dest_private_key_bytes is not None
@abc.abstractmethod @abc.abstractmethod
def _load_private_key(self, data, passphrase, current_hint=None): def _load_private_key(
self,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key).""" """Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
pass
def needs_conversion(self): def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False.""" """Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key( dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase self.src_private_key_bytes, self.src_passphrase
@@ -101,6 +115,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
if not self.has_existing_destination(): if not self.has_existing_destination():
return True return True
assert self.dest_private_key_bytes is not None
try: try:
format, self.dest_private_key = self._load_private_key( format, self.dest_private_key = self._load_private_key(
@@ -115,18 +130,20 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
self.dest_private_key, self.src_private_key self.dest_private_key, self.src_private_key
) )
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
return {} return {}
# Implementation with using cryptography # Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend): class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module) super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
def get_private_key_data(self): def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format""" """Return bytes for self.src_private_key in output format"""
if self.src_private_key is None:
raise AssertionError("src_private_key not set")
# Select export format and encoding # Select export format and encoding
try: try:
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
@@ -152,9 +169,9 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
) )
# Select key encryption # Select key encryption
encryption_algorithm = ( encryption_algorithm: (
cryptography.hazmat.primitives.serialization.NoEncryption() cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) ) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.dest_passphrase: if self.dest_passphrase:
encryption_algorithm = ( encryption_algorithm = (
cryptography.hazmat.primitives.serialization.BestAvailableEncryption( cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
@@ -179,7 +196,12 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
exception=traceback.format_exc(), exception=traceback.format_exc(),
) )
def _load_private_key(self, data, passphrase, current_hint=None): def _load_private_key(
self,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
try: try:
# Interpret bytes depending on format. # Interpret bytes depending on format.
format = identify_private_key_format(data) format = identify_private_key_format(data)
@@ -247,14 +269,14 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
raise PrivateKeyError(e) raise PrivateKeyError(e)
def select_backend(module): def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )
return PrivateKeyConvertCryptographyBackend(module) return PrivateKeyConvertCryptographyBackend(module)
def get_privatekey_argument_spec(): def get_privatekey_argument_spec() -> ArgumentSpec:
return ArgumentSpec( return ArgumentSpec(
argument_spec=dict( argument_spec=dict(
src_path=dict(type="path"), src_path=dict(type="path"),

View File

@@ -7,6 +7,7 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -29,6 +30,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -40,38 +53,49 @@ except ImportError:
SIGNATURE_TEST_DATA = b"1234" SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(key, need_private_key_data=False): def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key()) key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data = dict() key_private_data: dict[str, t.Any] = {}
if need_private_key_data: if need_private_key_data:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey): if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
private_numbers = key.private_numbers() rsa_private_numbers = key.private_numbers()
key_private_data["p"] = private_numbers.p key_private_data["p"] = rsa_private_numbers.p
key_private_data["q"] = private_numbers.q key_private_data["q"] = rsa_private_numbers.q
key_private_data["exponent"] = private_numbers.d key_private_data["exponent"] = rsa_private_numbers.d
elif isinstance( elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey
): ):
private_numbers = key.private_numbers() dsa_private_numbers = key.private_numbers()
key_private_data["x"] = private_numbers.x key_private_data["x"] = dsa_private_numbers.x
elif isinstance( elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
): ):
private_numbers = key.private_numbers() ecc_private_numbers = key.private_numbers()
key_private_data["multiplier"] = private_numbers.private_value key_private_data["multiplier"] = ecc_private_numbers.private_value
return key_type, key_public_data, key_private_data return key_type, key_public_data, key_private_data
def _check_dsa_consistency(key_public_data, key_private_data): def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters # Get parameters
p = key_public_data.get("p") p: int | None = key_public_data.get("p")
q = key_public_data.get("q") if p is None:
g = key_public_data.get("g") return None
y = key_public_data.get("y") q: int | None = key_public_data.get("q")
x = key_private_data.get("x") if q is None:
for v in (p, q, g, y, x): return None
if v is None: g: int | None = key_public_data.get("g")
return None if g is None:
return None
y: int | None = key_public_data.get("y")
if y is None:
return None
x: int | None = key_private_data.get("x")
if x is None:
return None
# Make sure that g is not 0, 1 or -1 in Z/pZ # Make sure that g is not 0, 1 or -1 in Z/pZ
if g < 2 or g >= p - 1: if g < 2 or g >= p - 1:
return False return False
@@ -94,13 +118,16 @@ def _check_dsa_consistency(key_public_data, key_private_data):
def _is_cryptography_key_consistent( def _is_cryptography_key_consistent(
key, key_public_data, key_private_data, warn_func=None key: PrivateKeyTypes,
): key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
) -> bool | None:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey): if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
# key._backend was removed in cryptography 42.0.0 # key._backend was removed in cryptography 42.0.0
backend = getattr(key, "_backend", None) backend = getattr(key, "_backend", None)
if backend is not None: if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey): if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data) result = _check_dsa_consistency(key_public_data, key_private_data)
if result is not None: if result is not None:
@@ -145,9 +172,9 @@ def _is_cryptography_key_consistent(
if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey): if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
has_simple_sign_function = True has_simple_sign_function = True
if has_simple_sign_function: if has_simple_sign_function:
signature = key.sign(SIGNATURE_TEST_DATA) signature = key.sign(SIGNATURE_TEST_DATA) # type: ignore
try: try:
key.public_key().verify(signature, SIGNATURE_TEST_DATA) key.public_key().verify(signature, SIGNATURE_TEST_DATA) # type: ignore
return True return True
except cryptography.exceptions.InvalidSignature: except cryptography.exceptions.InvalidSignature:
return False return False
@@ -158,14 +185,14 @@ def _is_cryptography_key_consistent(
class PrivateKeyConsistencyError(OpenSSLObjectError): class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg, result): def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg) super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg self.error_message = msg
self.result = result self.result = result
class PrivateKeyParseError(OpenSSLObjectError): class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg, result): def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg) super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg self.error_message = msg
self.result = result self.result = result
@@ -174,13 +201,12 @@ class PrivateKeyParseError(OpenSSLObjectError):
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta): class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__( def __init__(
self, self,
module, module: GeneralAnsibleModule,
content, content: bytes,
passphrase=None, passphrase: str | None = None,
return_private_key_data=False, return_private_key_data: bool = False,
check_consistency=False, check_consistency: bool = False,
): ):
# content must be a bytes string
self.module = module self.module = module
self.content = content self.content = content
self.passphrase = passphrase self.passphrase = passphrase
@@ -188,22 +214,26 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.check_consistency = check_consistency self.check_consistency = check_consistency
@abc.abstractmethod @abc.abstractmethod
def _get_public_key(self, binary): def _get_public_key(self, binary: bool) -> bytes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_key_info(self, need_private_key_data=False): def _get_key_info(
self, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _is_key_consistent(self, key_public_data, key_private_data): def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass pass
def get_info(self, prefer_one_fingerprint=False): def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result = dict( result: dict[str, t.Any] = {
can_parse_key=False, "can_parse_key": False,
key_is_consistent=None, "key_is_consistent": None,
) }
priv_key_detail = self.content priv_key_detail = self.content
try: try:
self.key = load_privatekey( self.key = load_privatekey(
@@ -252,35 +282,39 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval): class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend""" """Validate the supplied private key, using the cryptography backend"""
def __init__(self, module, content, **kwargs): def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__( super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs module, content, **kwargs
) )
def _get_public_key(self, binary): def _get_public_key(self, binary: bool) -> bytes:
return self.key.public_key().public_bytes( return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM, serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo, serialization.PublicFormat.SubjectPublicKeyInfo,
) )
def _get_key_info(self, need_private_key_data=False): def _get_key_info(
self, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info( return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data self.key, need_private_key_data=need_private_key_data
) )
def _is_key_consistent(self, key_public_data, key_private_data): def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent( return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn self.key, key_public_data, key_private_data, warn_func=self.module.warn
) )
def get_privatekey_info( def get_privatekey_info(
module, module: GeneralAnsibleModule,
content, content: bytes,
passphrase=None, passphrase: str | None = None,
return_private_key_data=False, return_private_key_data: bool = False,
prefer_one_fingerprint=False, prefer_one_fingerprint: bool = False,
): ) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography( info = PrivateKeyInfoRetrievalCryptography(
module, module,
content, content,
@@ -291,12 +325,12 @@ def get_privatekey_info(
def select_backend( def select_backend(
module, module: GeneralAnsibleModule,
content, content: bytes,
passphrase=None, passphrase: str | None = None,
return_private_key_data=False, return_private_key_data: bool = False,
check_consistency=False, check_consistency: bool = False,
): ) -> PrivateKeyInfoRetrieval:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )

View File

@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
@@ -19,6 +20,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes,
)
from ....plugin_utils.action_module import AnsibleActionModule
from ....plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try: try:
@@ -32,23 +45,25 @@ except ImportError:
pass pass
def _get_cryptography_public_key_info(key): def _get_cryptography_public_key_info(
key_public_data = dict() key: PublicKeyTypes,
) -> tuple[str, dict[str, t.Any]]:
key_public_data: dict[str, t.Any] = {}
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey): if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
key_type = "RSA" key_type = "RSA"
public_numbers = key.public_numbers() rsa_public_numbers = key.public_numbers()
key_public_data["size"] = key.key_size key_public_data["size"] = key.key_size
key_public_data["modulus"] = public_numbers.n key_public_data["modulus"] = rsa_public_numbers.n
key_public_data["exponent"] = public_numbers.e key_public_data["exponent"] = rsa_public_numbers.e
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey): elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
key_type = "DSA" key_type = "DSA"
parameter_numbers = key.parameters().parameter_numbers() dsa_parameter_numbers = key.parameters().parameter_numbers()
public_numbers = key.public_numbers() dsa_public_numbers = key.public_numbers()
key_public_data["size"] = key.key_size key_public_data["size"] = key.key_size
key_public_data["p"] = parameter_numbers.p key_public_data["p"] = dsa_parameter_numbers.p
key_public_data["q"] = parameter_numbers.q key_public_data["q"] = dsa_parameter_numbers.q
key_public_data["g"] = parameter_numbers.g key_public_data["g"] = dsa_parameter_numbers.g
key_public_data["y"] = public_numbers.y key_public_data["y"] = dsa_public_numbers.y
elif isinstance( elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey
): ):
@@ -67,10 +82,10 @@ def _get_cryptography_public_key_info(key):
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey
): ):
key_type = "ECC" key_type = "ECC"
public_numbers = key.public_numbers() ecc_public_numbers = key.public_numbers()
key_public_data["curve"] = key.curve.name key_public_data["curve"] = key.curve.name
key_public_data["x"] = public_numbers.x key_public_data["x"] = ecc_public_numbers.x
key_public_data["y"] = public_numbers.y key_public_data["y"] = ecc_public_numbers.y
key_public_data["exponent_size"] = key.curve.key_size key_public_data["exponent_size"] = key.curve.key_size
else: else:
key_type = f"unknown ({type(key)})" key_type = f"unknown ({type(key)})"
@@ -78,29 +93,34 @@ def _get_cryptography_public_key_info(key):
class PublicKeyParseError(OpenSSLObjectError): class PublicKeyParseError(OpenSSLObjectError):
def __init__(self, msg, result): def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PublicKeyParseError, self).__init__(msg) super(PublicKeyParseError, self).__init__(msg)
self.error_message = msg self.error_message = msg
self.result = result self.result = result
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta): class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module, content=None, key=None): def __init__(
self,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
# content must be a bytes string # content must be a bytes string
self.module = module self.module = module
self.content = content self.content = content
self.key = key self.key = key
@abc.abstractmethod @abc.abstractmethod
def _get_public_key(self, binary): def _get_public_key(self, binary: bool) -> bytes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _get_key_info(self): def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass pass
def get_info(self, prefer_one_fingerprint=False): def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result = dict() result: dict[str, t.Any] = {}
if self.key is None: if self.key is None:
try: try:
self.key = load_publickey(content=self.content) self.key = load_publickey(content=self.content)
@@ -123,27 +143,45 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval): class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
"""Validate the supplied public key, using the cryptography backend""" """Validate the supplied public key, using the cryptography backend"""
def __init__(self, module, content=None, key=None): def __init__(
self,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super(PublicKeyInfoRetrievalCryptography, self).__init__( super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, content=content, key=key module, content=content, key=key
) )
def _get_public_key(self, binary): def _get_public_key(self, binary: bool) -> bytes:
if self.key is None:
raise AssertionError("key must be set")
return self.key.public_bytes( return self.key.public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM, serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo, serialization.PublicFormat.SubjectPublicKeyInfo,
) )
def _get_key_info(self): def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
if self.key is None:
raise AssertionError("key must be set")
return _get_cryptography_public_key_info(self.key) return _get_cryptography_public_key_info(self.key)
def get_publickey_info(module, content=None, key=None, prefer_one_fingerprint=False): def get_publickey_info(
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key) info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint) return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, content=None, key=None): def select_backend(
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> PublicKeyInfoRetrieval:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )

View File

@@ -8,3 +8,6 @@ from __future__ import annotations
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( # noqa: F401, pylint: disable=unused-import from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( # noqa: F401, pylint: disable=unused-import
parse_openssh_version, parse_openssh_version,
) )
# TODO: delete!

View File

@@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
PEM_START = "-----BEGIN " PEM_START = "-----BEGIN "
PEM_END_START = "-----END " PEM_END_START = "-----END "
@@ -12,7 +14,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY" PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
def identify_pem_format(content, encoding="utf-8"): def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
"""Given the contents of a binary file, tests whether this could be a PEM file.""" """Given the contents of a binary file, tests whether this could be a PEM file."""
try: try:
first_pem = extract_first_pem(content.decode(encoding)) first_pem = extract_first_pem(content.decode(encoding))
@@ -30,7 +32,9 @@ def identify_pem_format(content, encoding="utf-8"):
return False return False
def identify_private_key_format(content, encoding="utf-8"): def identify_private_key_format(
content: bytes, encoding: str = "utf-8"
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
"""Given the contents of a private key file, identifies its format.""" """Given the contents of a private key file, identifies its format."""
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85 # See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
# (PEM_read_bio_PrivateKey) # (PEM_read_bio_PrivateKey)
@@ -59,12 +63,12 @@ def identify_private_key_format(content, encoding="utf-8"):
return "raw" return "raw"
def split_pem_list(text, keep_inbetween=False): def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
""" """
Split concatenated PEM objects into a list of strings, where each is one PEM object. Split concatenated PEM objects into a list of strings, where each is one PEM object.
""" """
result = [] result = []
current = [] if keep_inbetween else None current: list[str] | None = [] if keep_inbetween else None
for line in text.splitlines(True): for line in text.splitlines(True):
if line.strip(): if line.strip():
if not keep_inbetween and line.startswith("-----BEGIN "): if not keep_inbetween and line.startswith("-----BEGIN "):
@@ -77,7 +81,7 @@ def split_pem_list(text, keep_inbetween=False):
return result return result
def extract_first_pem(text): def extract_first_pem(text: str) -> str | None:
""" """
Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none. Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none.
""" """
@@ -87,7 +91,7 @@ def extract_first_pem(text):
return all_pems[0] return all_pems[0]
def _extract_type(line, start=PEM_START): def _extract_type(line: str, start: str = PEM_START) -> str | None:
if not line.startswith(start): if not line.startswith(start):
return None return None
if not line.endswith(PEM_END): if not line.endswith(PEM_END):
@@ -95,7 +99,7 @@ def _extract_type(line, start=PEM_START):
return line[len(start) : -len(PEM_END)] return line[len(start) : -len(PEM_END)]
def extract_pem(content, strict=False): def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
lines = content.splitlines() lines = content.splitlines()
if len(lines) < 3: if len(lines) < 3:
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}") raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
@@ -117,5 +121,4 @@ def extract_pem(content, strict=False):
raise ValueError( raise ValueError(
f"Last line has length {len(lines[-2])}, should be in (0, 64]" f"Last line has length {len(lines[-2])}, should be in (0, 64]"
) )
content = lines[1:-1] return header_type, "".join(lines[1:-1])
return header_type, "".join(content)

View File

@@ -8,8 +8,13 @@ import abc
import errno import errno
import hashlib import hashlib
import os import os
import typing as t
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
is_potential_certificate_issuer_private_key,
is_potential_certificate_private_key,
)
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
identify_pem_format, identify_pem_format,
) )
@@ -34,6 +39,17 @@ except ImportError:
from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
PublicKeyTypes,
)
from .cryptography_support import CertificatePrivateKeyTypes
# This list of preferred fingerprints is used when prefer_one=True is supplied to the # This list of preferred fingerprints is used when prefer_one=True is supplied to the
# fingerprinting methods. # fingerprinting methods.
PREFERRED_FINGERPRINTS = ( PREFERRED_FINGERPRINTS = (
@@ -48,18 +64,12 @@ PREFERRED_FINGERPRINTS = (
) )
def get_fingerprint_of_bytes(source, prefer_one=False): def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
"""Generate the fingerprint of the given bytes.""" """Generate the fingerprint of the given bytes."""
fingerprint = {} fingerprint = {}
try: algorithms: t.Iterable[str] = hashlib.algorithms_guaranteed
algorithms = hashlib.algorithms
except AttributeError:
try:
algorithms = hashlib.algorithms_guaranteed
except AttributeError:
return None
if prefer_one: if prefer_one:
# Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning # Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning
@@ -97,7 +107,9 @@ def get_fingerprint_of_bytes(source, prefer_one=False):
return fingerprint return fingerprint
def get_fingerprint_of_privatekey(privatekey, prefer_one=False): def get_fingerprint_of_privatekey(
privatekey: PrivateKeyTypes, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the public key.""" """Generate the fingerprint of the public key."""
publickey = privatekey.public_key().public_bytes( publickey = privatekey.public_key().public_bytes(
@@ -107,11 +119,16 @@ def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one) return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
def get_fingerprint(path, passphrase=None, content=None, prefer_one=False): def get_fingerprint(
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
content: bytes | None = None,
prefer_one: bool = False,
) -> dict[str, str]:
"""Generate the fingerprint of the public key.""" """Generate the fingerprint of the public key."""
privatekey = load_privatekey( privatekey = load_privatekey(
path, path=path,
passphrase=passphrase, passphrase=passphrase,
content=content, content=content,
check_passphrase=False, check_passphrase=False,
@@ -121,11 +138,11 @@ def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
def load_privatekey( def load_privatekey(
path, path: os.PathLike | str | None = None,
passphrase=None, passphrase: str | bytes | None = None,
check_passphrase=True, check_passphrase: bool = True,
content=None, content: bytes | None = None,
): ) -> PrivateKeyTypes:
"""Load the specified OpenSSL private key. """Load the specified OpenSSL private key.
The content can also be specified via content; in that case, The content can also be specified via content; in that case,
@@ -134,6 +151,8 @@ def load_privatekey(
try: try:
if content is None: if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as b_priv_key_fh: with open(path, "rb") as b_priv_key_fh:
priv_key_detail = b_priv_key_fh.read() priv_key_detail = b_priv_key_fh.read()
else: else:
@@ -154,7 +173,55 @@ def load_privatekey(
raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key") raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key")
def load_publickey(path=None, content=None): def load_certificate_privatekey(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
) -> CertificatePrivateKeyTypes:
"""
Load the specified OpenSSL private key that can be used as a private key for certificates.
"""
private_key = load_privatekey(
path=path,
passphrase=passphrase,
check_passphrase=check_passphrase,
content=content,
)
if not is_potential_certificate_private_key(private_key):
raise OpenSSLObjectError(
f"Key of type {type(private_key)} not supported for certificates"
)
return private_key
def load_certificate_issuer_privatekey(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
) -> CertificateIssuerPrivateKeyTypes:
"""
Load the specified OpenSSL private key that can be used for issuing certificates.
"""
private_key = load_privatekey(
path=path,
passphrase=passphrase,
check_passphrase=check_passphrase,
content=content,
)
if not is_potential_certificate_issuer_private_key(private_key):
raise OpenSSLObjectError(
f"Key of type {type(private_key)} not supported for issuing certificates"
)
return private_key
def load_publickey(
path: os.PathLike | str | None = None, content: bytes | None = None
) -> PublicKeyTypes:
if content is None: if content is None:
if path is None: if path is None:
raise OpenSSLObjectError("Must provide either path or content") raise OpenSSLObjectError("Must provide either path or content")
@@ -170,11 +237,17 @@ def load_publickey(path=None, content=None):
raise OpenSSLObjectError(f"Error while deserializing key: {e}") raise OpenSSLObjectError(f"Error while deserializing key: {e}")
def load_certificate(path, content=None, der_support_enabled=False): def load_certificate(
path: os.PathLike | str | None = None,
content: bytes | None = None,
der_support_enabled: bool = False,
) -> x509.Certificate:
"""Load the specified certificate.""" """Load the specified certificate."""
try: try:
if content is None: if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as cert_fh: with open(path, "rb") as cert_fh:
cert_content = cert_fh.read() cert_content = cert_fh.read()
else: else:
@@ -193,10 +266,14 @@ def load_certificate(path, content=None, der_support_enabled=False):
raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}") raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}")
def load_certificate_request(path, content=None): def load_certificate_request(
path: os.PathLike | str | None = None, content: bytes | None = None
) -> x509.CertificateSigningRequest:
"""Load the specified certificate signing request.""" """Load the specified certificate signing request."""
try: try:
if content is None: if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as csr_fh: with open(path, "rb") as csr_fh:
csr_content = csr_fh.read() csr_content = csr_fh.read()
else: else:
@@ -209,45 +286,44 @@ def load_certificate_request(path, content=None):
raise OpenSSLObjectError(exc) raise OpenSSLObjectError(exc)
def parse_name_field(input_dict, name_field_name=None): def parse_name_field(
input_dict: dict[str, list[str | bytes] | str | bytes],
name_field_name: str | None = None,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" """Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
error_str = "{key}" if name_field_name is None else "{key} in {name}"
def error_str(key: str) -> str:
if name_field_name is None:
return f"{key}"
return f"{key} in {name_field_name}"
result = [] result = []
for key, value in input_dict.items(): for key, value in input_dict.items():
if isinstance(value, list): if isinstance(value, list):
for entry in value: for entry in value:
if not isinstance(entry, (str, bytes)): if not isinstance(entry, (str, bytes)):
raise TypeError( raise TypeError(f"Values {error_str(key)} must be strings")
f"Values {error_str} must be strings".format(
key=key, name=name_field_name
)
)
if not entry: if not entry:
raise ValueError( raise ValueError(
f"Values for {error_str} must not be empty strings".format( f"Values for {error_str(key)} must not be empty strings"
key=key, name=name_field_name
)
) )
result.append((key, entry)) result.append((key, entry))
elif isinstance(value, (str, bytes)): elif isinstance(value, (str, bytes)):
if not value: if not value:
raise ValueError( raise ValueError(
f"Value for {error_str} must not be an empty string".format( f"Value for {error_str(key)} must not be an empty string"
key=key, name=name_field_name
)
) )
result.append((key, value)) result.append((key, value))
else: else:
raise TypeError( raise TypeError(
( f"Value for {error_str(key)} must be either a string or a list of strings"
f"Value for {error_str} must be either a string or a list of strings"
).format(key=key, name=name_field_name)
) )
return result return result
def parse_ordered_name_field(input_list, name_field_name): def parse_ordered_name_field(
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" """Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
result = [] result = []
@@ -265,24 +341,39 @@ def parse_ordered_name_field(input_list, name_field_name):
return result return result
def select_message_digest(digest_string): @t.overload
digest = None def select_message_digest(
digest_string: t.Literal["sha256", "sha384", "sha512", "sha1", "md5"],
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5: ...
@t.overload
def select_message_digest(
digest_string: str,
) -> (
hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None
): ...
def select_message_digest(
digest_string: str,
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None:
if digest_string == "sha256": if digest_string == "sha256":
digest = hashes.SHA256() return hashes.SHA256()
elif digest_string == "sha384": if digest_string == "sha384":
digest = hashes.SHA384() return hashes.SHA384()
elif digest_string == "sha512": if digest_string == "sha512":
digest = hashes.SHA512() return hashes.SHA512()
elif digest_string == "sha1": if digest_string == "sha1":
digest = hashes.SHA1() return hashes.SHA1()
elif digest_string == "md5": if digest_string == "md5":
digest = hashes.MD5() return hashes.MD5()
return digest return None
class OpenSSLObject(metaclass=abc.ABCMeta): class OpenSSLObject(metaclass=abc.ABCMeta):
def __init__(self, path, state, force, check_mode): def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
self.path = path self.path = path
self.state = state self.state = state
self.force = force self.force = force
@@ -290,13 +381,13 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
self.changed = False self.changed = False
self.check_mode = check_mode self.check_mode = check_mode
def check(self, module, perms_required=True): def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
def _check_state(): def _check_state() -> bool:
return os.path.exists(self.path) return os.path.exists(self.path)
def _check_perms(module): def _check_perms(module: AnsibleModule) -> bool:
file_args = module.load_file_common_arguments(module.params) file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args["path"]): if module.check_file_absent_if_check_mode(file_args["path"]):
return False return False
@@ -308,18 +399,14 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
return _check_state() and _check_perms(module) return _check_state() and _check_perms(module)
@abc.abstractmethod @abc.abstractmethod
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
pass
@abc.abstractmethod @abc.abstractmethod
def generate(self): def generate(self, module: AnsibleModule) -> None:
"""Generate the resource.""" """Generate the resource."""
pass def remove(self, module: AnsibleModule) -> None:
def remove(self, module):
"""Remove the resource from the filesystem.""" """Remove the resource from the filesystem."""
if self.check_mode: if self.check_mode:
if os.path.exists(self.path): if os.path.exists(self.path):

View File

@@ -11,6 +11,7 @@ Must be kept in sync with plugins/doc_fragments/cryptography_dep.py.
from __future__ import annotations from __future__ import annotations
import traceback import traceback
import typing as t
from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.basic import missing_required_lib
from ansible_collections.community.crypto.plugins.module_utils.version import ( from ansible_collections.community.crypto.plugins.module_utils.version import (
@@ -18,19 +19,29 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
) )
_CRYPTOGRAPHY_IMP_ERR = None if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ..plugin_utils.action_module import AnsibleActionModule
from ..plugin_utils.filter_module import FilterModuleMock
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
_CRYPTOGRAPHY_IMP_ERR: str | None = None
_CRYPTOGRAPHY_FILE: str | None = None
try: try:
import cryptography import cryptography
from cryptography import x509 # noqa: F401, pylint: disable=unused-import from cryptography import x509 # noqa: F401, pylint: disable=unused-import
_CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
_CRYPTOGRAPHY_FILE = cryptography.__file__ _CRYPTOGRAPHY_FILE = cryptography.__file__
except ImportError: except ImportError:
_CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() _CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
_CRYPTOGRAPHY_FOUND = False CRYPTOGRAPHY_FOUND = False
_CRYPTOGRAPHY_FILE = None CRYPTOGRAPHY_VERSION = LooseVersion("0.0")
else: else:
_CRYPTOGRAPHY_FOUND = True CRYPTOGRAPHY_FOUND = True
# Corresponds to the community.crypto.cryptography_dep.minimum doc fragment # Corresponds to the community.crypto.cryptography_dep.minimum doc fragment
@@ -38,25 +49,27 @@ COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION = "3.3"
def assert_required_cryptography_version( def assert_required_cryptography_version(
module, module: GeneralAnsibleModule,
*, *,
minimum_cryptography_version: str = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, minimum_cryptography_version: str = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
) -> None: ) -> None:
if not _CRYPTOGRAPHY_FOUND: if not CRYPTOGRAPHY_FOUND:
module.fail_json( module.fail_json(
msg=missing_required_lib(f"cryptography >= {minimum_cryptography_version}"), msg=missing_required_lib(f"cryptography >= {minimum_cryptography_version}"),
exception=_CRYPTOGRAPHY_IMP_ERR, exception=_CRYPTOGRAPHY_IMP_ERR,
) )
if _CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version): if CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
module.fail_json( module.fail_json(
msg=( msg=(
f"Cannot detect the required Python library cryptography (>= {minimum_cryptography_version})." f"Cannot detect the required Python library cryptography (>= {minimum_cryptography_version})."
f" Only found a too old version ({_CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}." f" Only found a too old version ({CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
), ),
) )
__all__ = ( __all__ = (
"COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION", "COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION",
"CRYPTOGRAPHY_FOUND",
"CRYPTOGRAPHY_VERSION",
"assert_required_cryptography_version", "assert_required_cryptography_version",
) )

View File

@@ -14,6 +14,7 @@ import json
import os import os
import re import re
import traceback import traceback
import typing as t
from urllib.error import HTTPError from urllib.error import HTTPError
from urllib.parse import urlencode from urllib.parse import urlencode
@@ -34,7 +35,7 @@ else:
valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$") valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$")
def ecs_client_argument_spec(): def ecs_client_argument_spec() -> dict[str, t.Any]:
return dict( return dict(
entrust_api_user=dict(type="str", required=True), entrust_api_user=dict(type="str", required=True),
entrust_api_key=dict(type="str", required=True, no_log=True), entrust_api_key=dict(type="str", required=True, no_log=True),
@@ -50,19 +51,17 @@ def ecs_client_argument_spec():
class SessionConfigurationException(Exception): class SessionConfigurationException(Exception):
"""Raised if we cannot configure a session with the API""" """Raised if we cannot configure a session with the API"""
pass
class RestOperationException(Exception): class RestOperationException(Exception):
"""Encapsulate a REST API error""" """Encapsulate a REST API error"""
def __init__(self, error): def __init__(self, error: dict[str, t.Any]) -> None:
self.status = to_native(error.get("status", None)) self.status = to_native(error.get("status", None))
self.errors = [to_native(err.get("message")) for err in error.get("errors", {})] self.errors = [to_native(err.get("message")) for err in error.get("errors", {})]
self.message = " ".join(self.errors) self.message = " ".join(self.errors)
def generate_docstring(operation_spec): def generate_docstring(operation_spec: dict[str, t.Any]) -> str:
"""Generate a docstring for an operation defined in operation_spec (swagger)""" """Generate a docstring for an operation defined in operation_spec (swagger)"""
# Description of the operation # Description of the operation
docs = operation_spec.get("description", "No Description") docs = operation_spec.get("description", "No Description")

View File

@@ -14,7 +14,9 @@ class GPGError(Exception):
class GPGRunner(metaclass=abc.ABCMeta): class GPGRunner(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def run_command(self, command, check_rc=True, data=None): def run_command(
self, command: list[str], check_rc: bool = True, data: bytes | None = None
) -> tuple[int, str, str]:
""" """
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``. Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
@@ -29,7 +31,7 @@ class GPGRunner(metaclass=abc.ABCMeta):
pass pass
def get_fingerprint_from_stdout(stdout): def get_fingerprint_from_stdout(stdout: str) -> str:
lines = stdout.splitlines(False) lines = stdout.splitlines(False)
for line in lines: for line in lines:
if line.startswith("fpr:"): if line.startswith("fpr:"):
@@ -42,7 +44,7 @@ def get_fingerprint_from_stdout(stdout):
raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"') raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"')
def get_fingerprint_from_file(gpg_runner, path): def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
if not os.path.exists(path): if not os.path.exists(path):
raise GPGError(f"{path} does not exist") raise GPGError(f"{path} does not exist")
stdout = gpg_runner.run_command( stdout = gpg_runner.run_command(
@@ -59,7 +61,7 @@ def get_fingerprint_from_file(gpg_runner, path):
return get_fingerprint_from_stdout(stdout) return get_fingerprint_from_stdout(stdout)
def get_fingerprint_from_bytes(gpg_runner, content): def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
stdout = gpg_runner.run_command( stdout = gpg_runner.run_command(
[ [
"--no-keyring", "--no-keyring",

View File

@@ -7,9 +7,14 @@ from __future__ import annotations
import errno import errno
import os import os
import tempfile import tempfile
import typing as t
def load_file(path, module=None): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
""" """
Load the file as a bytes string. Load the file as a bytes string.
""" """
@@ -22,7 +27,11 @@ def load_file(path, module=None):
module.fail_json(f"Error while loading {path} - {exc}") module.fail_json(f"Error while loading {path} - {exc}")
def load_file_if_exists(path, module=None, ignore_errors=False): def load_file_if_exists(
path: str | os.PathLike,
module: AnsibleModule | None = None,
ignore_errors: bool = False,
) -> bytes | None:
""" """
Load the file as a bytes string. If the file does not exist, ``None`` is returned. Load the file as a bytes string. If the file does not exist, ``None`` is returned.
@@ -49,7 +58,12 @@ def load_file_if_exists(path, module=None, ignore_errors=False):
module.fail_json(f"Error while loading {path} - {exc}") module.fail_json(f"Error while loading {path} - {exc}")
def write_file(module, content, default_mode=None, path=None): def write_file(
module: AnsibleModule,
content: bytes,
default_mode: str | int | None = None,
path: str | os.PathLike | None = None,
) -> None:
""" """
Writes content into destination file as securely as possible. Writes content into destination file as securely as possible.
Uses file arguments from module. Uses file arguments from module.

View File

@@ -8,14 +8,31 @@ import abc
import os import os
import stat import stat
import traceback import traceback
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
parse_openssh_version, parse_openssh_version,
) )
def restore_on_failure(f): if t.TYPE_CHECKING:
def backup_and_restore(module, path, *args, **kwargs): from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
from ..certificate import OpensshCertificateTimeParameters
Param = t.ParamSpec("Param")
def restore_on_failure(
f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None],
) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]:
def backup_and_restore(
module: AnsibleModule, path: str | os.PathLike, *args, **kwargs
) -> None:
backup_file = module.backup_local(path) if os.path.exists(path) else None backup_file = module.backup_local(path) if os.path.exists(path) else None
try: try:
@@ -31,12 +48,31 @@ def restore_on_failure(f):
@restore_on_failure @restore_on_failure
def safe_atomic_move(module, path, destination): def safe_atomic_move(
module: AnsibleModule, path: str | os.PathLike, destination: str | os.PathLike
) -> None:
module.atomic_move(os.path.abspath(path), os.path.abspath(destination)) module.atomic_move(os.path.abspath(path), os.path.abspath(destination))
def _restore_all_on_failure(f): def _restore_all_on_failure(
def backup_and_restore(self, sources_and_destinations, *args, **kwargs): f: t.Callable[
t.Concatenate[
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
],
None,
],
) -> t.Callable[
t.Concatenate[
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
],
None,
]:
def backup_and_restore(
self: OpensshModule,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
*args,
**kwargs,
) -> None:
backups = [ backups = [
(d, self.module.backup_local(d)) (d, self.module.backup_local(d))
for s, d in sources_and_destinations for s, d in sources_and_destinations
@@ -59,13 +95,13 @@ def _restore_all_on_failure(f):
class OpensshModule(metaclass=abc.ABCMeta): class OpensshModule(metaclass=abc.ABCMeta):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self.module = module self.module = module
self.changed = False self.changed: bool = False
self.check_mode = self.module.check_mode self.check_mode: bool = self.module.check_mode
def execute(self): def execute(self) -> t.NoReturn:
try: try:
self._execute() self._execute()
except Exception as e: except Exception as e:
@@ -77,11 +113,11 @@ class OpensshModule(metaclass=abc.ABCMeta):
self.module.exit_json(**self.result) self.module.exit_json(**self.result)
@abc.abstractmethod @abc.abstractmethod
def _execute(self): def _execute(self) -> None:
pass pass
@property @property
def result(self): def result(self) -> dict[str, t.Any]:
result = self._result result = self._result
result["changed"] = self.changed result["changed"] = self.changed
@@ -93,31 +129,31 @@ class OpensshModule(metaclass=abc.ABCMeta):
@property @property
@abc.abstractmethod @abc.abstractmethod
def _result(self): def _result(self) -> dict[str, t.Any]:
pass pass
@property @property
@abc.abstractmethod @abc.abstractmethod
def diff(self): def diff(self) -> dict[str, t.Any]:
pass pass
@staticmethod @staticmethod
def skip_if_check_mode(f): def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs) -> None:
if not self.check_mode: if not self.check_mode:
f(self, *args, **kwargs) f(self, *args, **kwargs)
return wrapper return wrapper # type: ignore
@staticmethod @staticmethod
def trigger_change(f): def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs) -> None:
f(self, *args, **kwargs) f(self, *args, **kwargs)
self.changed = True self.changed = True
return wrapper return wrapper # type: ignore
def _check_if_base_dir(self, path): def _check_if_base_dir(self, path: str | os.PathLike) -> None:
base_dir = os.path.dirname(path) or "." base_dir = os.path.dirname(path) or "."
if not os.path.isdir(base_dir): if not os.path.isdir(base_dir):
self.module.fail_json( self.module.fail_json(
@@ -125,16 +161,19 @@ class OpensshModule(metaclass=abc.ABCMeta):
msg=f"The directory {base_dir} does not exist or the file is not a directory", msg=f"The directory {base_dir} does not exist or the file is not a directory",
) )
def _get_ssh_version(self): def _get_ssh_version(self) -> str | None:
ssh_bin = self.module.get_bin_path("ssh") ssh_bin = self.module.get_bin_path("ssh")
if not ssh_bin: if not ssh_bin:
return "" return None
return parse_openssh_version( return parse_openssh_version(
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip() self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
) )
@_restore_all_on_failure @_restore_all_on_failure
def _safe_secure_move(self, sources_and_destinations): def _safe_secure_move(
self,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
) -> None:
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure. """Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
If 'destination' does not already exist, then 'source' permissions are preserved to prevent If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for exposing protected data ('atomic_move' uses the 'destination' base directory mask for
@@ -148,7 +187,7 @@ class OpensshModule(metaclass=abc.ABCMeta):
else: else:
self.module.preserved_copy(source, destination) self.module.preserved_copy(source, destination)
def _update_permissions(self, path): def _update_permissions(self, path: str | os.PathLike) -> None:
file_args = self.module.load_file_common_arguments(self.module.params) file_args = self.module.load_file_common_arguments(self.module.params)
file_args["path"] = path file_args["path"] = path
@@ -161,25 +200,25 @@ class OpensshModule(metaclass=abc.ABCMeta):
class KeygenCommand: class KeygenCommand:
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self._bin_path = module.get_bin_path("ssh-keygen", True) self._bin_path = module.get_bin_path("ssh-keygen", True)
self._run_command = module.run_command self._run_command = module.run_command
def generate_certificate( def generate_certificate(
self, self,
certificate_path, certificate_path: str,
identifier, identifier: str,
options, options: list[str] | None,
pkcs11_provider, pkcs11_provider: str | None,
principals, principals: list[str] | None,
serial_number, serial_number: int | None,
signature_algorithm, signature_algorithm: str | None,
signing_key_path, signing_key_path: str,
type, type: t.Literal["host", "user"] | None,
time_parameters, time_parameters: OpensshCertificateTimeParameters,
use_agent, use_agent: bool,
**kwargs, **kwargs,
): ) -> tuple[int, str, str]:
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier] args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
if options: if options:
@@ -203,7 +242,9 @@ class KeygenCommand:
return self._run_command(args, **kwargs) return self._run_command(args, **kwargs)
def generate_keypair(self, private_key_path, size, type, comment, **kwargs): def generate_keypair(
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
) -> tuple[int, str, str]:
args = [ args = [
self._bin_path, self._bin_path,
"-q", "-q",
@@ -224,32 +265,40 @@ class KeygenCommand:
return self._run_command(args, data=data, **kwargs) return self._run_command(args, data=data, **kwargs)
def get_certificate_info(self, certificate_path, **kwargs): def get_certificate_info(
self, certificate_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command( return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs [self._bin_path, "-L", "-f", certificate_path], **kwargs
) )
def get_matching_public_key(self, private_key_path, **kwargs): def get_matching_public_key(
self, private_key_path: str, **kwargs
) -> tuple[int, str, str]:
return self._run_command( return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs [self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
) )
def get_private_key(self, private_key_path, **kwargs): def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
return self._run_command( return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs [self._bin_path, "-l", "-f", private_key_path], **kwargs
) )
def update_comment( def update_comment(
self, private_key_path, comment, force_new_format=True, **kwargs self,
): private_key_path: str,
comment: str,
force_new_format: bool = True,
**kwargs,
) -> tuple[int, str, str]:
if os.path.exists(private_key_path) and not os.access( if os.path.exists(private_key_path) and not os.access(
private_key_path, os.W_OK private_key_path, os.W_OK
): ):
try: try:
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR) os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
except (IOError, OSError) as e: except (IOError, OSError) as e:
raise e( raise ValueError(
f"The private key at {private_key_path} is not writeable preventing a comment update" f"The private key at {private_key_path} is not writeable preventing a comment update ({e})"
) )
command = [self._bin_path, "-q"] command = [self._bin_path, "-q"]
@@ -259,31 +308,36 @@ class KeygenCommand:
return self._run_command(command, **kwargs) return self._run_command(command, **kwargs)
_PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
class PrivateKey: class PrivateKey:
def __init__(self, size, key_type, fingerprint, format=""): def __init__(
self, size: int, key_type: str, fingerprint: str, format: str = ""
) -> None:
self._size = size self._size = size
self._type = key_type self._type = key_type
self._fingerprint = fingerprint self._fingerprint = fingerprint
self._format = format self._format = format
@property @property
def size(self): def size(self) -> int:
return self._size return self._size
@property @property
def type(self): def type(self) -> str:
return self._type return self._type
@property @property
def fingerprint(self): def fingerprint(self) -> str:
return self._fingerprint return self._fingerprint
@property @property
def format(self): def format(self) -> str:
return self._format return self._format
@classmethod @classmethod
def from_string(cls, string): def from_string(cls: t.Type[_PrivateKey], string: str) -> _PrivateKey:
properties = string.split() properties = string.split()
return cls( return cls(
@@ -292,7 +346,7 @@ class PrivateKey:
fingerprint=properties[1], fingerprint=properties[1],
) )
def to_dict(self): def to_dict(self) -> dict[str, t.Any]:
return { return {
"size": self._size, "size": self._size,
"type": self._type, "type": self._type,
@@ -301,13 +355,16 @@ class PrivateKey:
} }
_PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
class PublicKey: class PublicKey:
def __init__(self, type_string, data, comment): def __init__(self, type_string: str, data: str, comment: str | None) -> None:
self._type_string = type_string self._type_string = type_string
self._data = data self._data = data
self._comment = comment self._comment = comment
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)): if not isinstance(other, type(self)):
return NotImplemented return NotImplemented
@@ -323,30 +380,30 @@ class PublicKey:
] ]
) )
def __ne__(self, other): def __ne__(self, other: object) -> bool:
return not self == other return not self == other
def __str__(self): def __str__(self) -> str:
return f"{self._type_string} {self._data}" return f"{self._type_string} {self._data}"
@property @property
def comment(self): def comment(self) -> str | None:
return self._comment return self._comment
@comment.setter @comment.setter
def comment(self, value): def comment(self, value: str | None) -> None:
self._comment = value self._comment = value
@property @property
def data(self): def data(self) -> str:
return self._data return self._data
@property @property
def type_string(self): def type_string(self) -> str:
return self._type_string return self._type_string
@classmethod @classmethod
def from_string(cls, string): def from_string(cls: t.Type[_PublicKey], string: str) -> _PublicKey:
properties = string.strip("\n").split(" ", 2) properties = string.strip("\n").split(" ", 2)
return cls( return cls(
@@ -356,7 +413,7 @@ class PublicKey:
) )
@classmethod @classmethod
def load(cls, path): def load(cls: t.Type[_PublicKey], path: str | os.PathLike) -> _PublicKey | None:
try: try:
with open(path, "r") as f: with open(path, "r") as f:
properties = f.read().strip(" \n").split(" ", 2) properties = f.read().strip(" \n").split(" ", 2)
@@ -372,14 +429,16 @@ class PublicKey:
comment="" if len(properties) <= 2 else properties[2], comment="" if len(properties) <= 2 else properties[2],
) )
def to_dict(self): def to_dict(self) -> dict[str, t.Any]:
return { return {
"comment": self._comment, "comment": self._comment,
"public_key": self._data, "public_key": self._data,
} }
def parse_private_key_format(path): def parse_private_key_format(
path: str | os.PathLike,
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
with open(path, "r") as file: with open(path, "r") as file:
header = file.readline().strip() header = file.readline().strip()

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import abc import abc
import os import os
import typing as t
from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -39,31 +40,43 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta): class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackend, self).__init__(module) super(KeypairBackend, self).__init__(module)
self.comment = self.module.params["comment"] self.comment: str | None = self.module.params["comment"]
self.private_key_path = self.module.params["path"] self.private_key_path: str = self.module.params["path"]
self.public_key_path = self.private_key_path + ".pub" self.public_key_path = self.private_key_path + ".pub"
self.regenerate = ( self.regenerate: t.Literal[
"never", "fail", "partial_idempotence", "full_idempotence", "always"
] = (
self.module.params["regenerate"] self.module.params["regenerate"]
if not self.module.params["force"] if not self.module.params["force"]
else "always" else "always"
) )
self.state = self.module.params["state"] self.state: t.Literal["present", "absent"] = self.module.params["state"]
self.type = self.module.params["type"] self.type: t.Literal["rsa", "dsa", "rsa1", "ecdsa", "ed25519"] = (
self.module.params["type"]
)
self.size = self._get_size(self.module.params["size"]) self.size: int = self._get_size(self.module.params["size"])
self._validate_path() self._validate_path()
self.original_private_key = None self.original_private_key: PrivateKey | None = None
self.original_public_key = None self.original_public_key: PublicKey | None = None
self.private_key = None self.private_key: PrivateKey | None = None
self.public_key = None self.public_key: PublicKey | None = None
def _get_size(self, size): def _get_size(self, size: int | None) -> int:
if self.type in ("rsa", "rsa1"): if self.type in ("rsa", "rsa1"):
result = 4096 if size is None else size result = 4096 if size is None else size
if result < 1024: if result < 1024:
@@ -96,7 +109,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return result return result
def _validate_path(self): def _validate_path(self) -> None:
self._check_if_base_dir(self.private_key_path) self._check_if_base_dir(self.private_key_path)
if os.path.isdir(self.private_key_path): if os.path.isdir(self.private_key_path):
@@ -104,7 +117,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
msg=f"{self.private_key_path} is a directory. Please specify a path to a file." msg=f"{self.private_key_path} is a directory. Please specify a path to a file."
) )
def _execute(self): def _execute(self) -> None:
self.original_private_key = self._load_private_key() self.original_private_key = self._load_private_key()
self.original_public_key = self._load_public_key() self.original_public_key = self._load_public_key()
@@ -125,7 +138,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
if self._should_remove(): if self._should_remove():
self._remove() self._remove()
def _load_private_key(self): def _load_private_key(self) -> PrivateKey | None:
result = None result = None
if self._private_key_exists(): if self._private_key_exists():
try: try:
@@ -135,14 +148,14 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return result return result
def _private_key_exists(self): def _private_key_exists(self) -> bool:
return os.path.exists(self.private_key_path) return os.path.exists(self.private_key_path)
@abc.abstractmethod @abc.abstractmethod
def _get_private_key(self): def _get_private_key(self) -> PrivateKey:
pass pass
def _load_public_key(self): def _load_public_key(self) -> PublicKey | None:
result = None result = None
if self._public_key_exists(): if self._public_key_exists():
try: try:
@@ -151,10 +164,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
pass pass
return result return result
def _public_key_exists(self): def _public_key_exists(self) -> bool:
return os.path.exists(self.public_key_path) return os.path.exists(self.public_key_path)
def _validate_key_load(self): def _validate_key_load(self) -> None:
if ( if (
self._private_key_exists() self._private_key_exists()
and self.regenerate in ("never", "fail", "partial_idempotence") and self.regenerate in ("never", "fail", "partial_idempotence")
@@ -167,10 +180,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
) )
@abc.abstractmethod @abc.abstractmethod
def _private_key_readable(self): def _private_key_readable(self) -> bool:
pass pass
def _should_generate(self): def _should_generate(self) -> bool:
if self.original_private_key is None: if self.original_private_key is None:
return True return True
elif self.regenerate == "never": elif self.regenerate == "never":
@@ -188,7 +201,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
else: else:
return True return True
def _private_key_valid(self): def _private_key_valid(self) -> bool:
if self.original_private_key is None: if self.original_private_key is None:
return False return False
@@ -196,17 +209,17 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
[ [
self.size == self.original_private_key.size, self.size == self.original_private_key.size,
self.type == self.original_private_key.type, self.type == self.original_private_key.type,
self._private_key_valid_backend(), self._private_key_valid_backend(self.original_private_key),
] ]
) )
@abc.abstractmethod @abc.abstractmethod
def _private_key_valid_backend(self): def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
pass pass
@OpensshModule.trigger_change @OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode @OpensshModule.skip_if_check_mode
def _generate(self): def _generate(self) -> None:
temp_private_key, temp_public_key = self._generate_temp_keypair() temp_private_key, temp_public_key = self._generate_temp_keypair()
try: try:
@@ -219,7 +232,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
except OSError as e: except OSError as e:
self.module.fail_json(msg=str(e)) self.module.fail_json(msg=str(e))
def _generate_temp_keypair(self): def _generate_temp_keypair(self) -> tuple[str, str]:
temp_private_key = os.path.join( temp_private_key = os.path.join(
self.module.tmpdir, os.path.basename(self.private_key_path) self.module.tmpdir, os.path.basename(self.private_key_path)
) )
@@ -236,25 +249,26 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return temp_private_key, temp_public_key return temp_private_key, temp_public_key
@abc.abstractmethod @abc.abstractmethod
def _generate_keypair(self, private_key_path): def _generate_keypair(self, private_key_path: str) -> None:
pass pass
def _public_key_valid(self): def _public_key_valid(self) -> bool:
if self.original_public_key is None: if self.original_public_key is None:
return False return False
valid_public_key = self._get_public_key() valid_public_key = self._get_public_key()
valid_public_key.comment = self.comment if valid_public_key:
valid_public_key.comment = self.comment
return self.original_public_key == valid_public_key return self.original_public_key == valid_public_key
@abc.abstractmethod @abc.abstractmethod
def _get_public_key(self): def _get_public_key(self) -> PublicKey | t.Literal[""]:
pass pass
@OpensshModule.trigger_change @OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode @OpensshModule.skip_if_check_mode
def _restore_public_key(self): def _restore_public_key(self) -> None:
try: try:
temp_public_key = self._create_temp_public_key( temp_public_key = self._create_temp_public_key(
str(self._get_public_key()) + "\n" str(self._get_public_key()) + "\n"
@@ -269,7 +283,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
if self.comment: if self.comment:
self._update_comment() self._update_comment()
def _create_temp_public_key(self, content): def _create_temp_public_key(self, content: str | bytes) -> str:
temp_public_key = os.path.join( temp_public_key = os.path.join(
self.module.tmpdir, os.path.basename(self.public_key_path) self.module.tmpdir, os.path.basename(self.public_key_path)
) )
@@ -290,15 +304,15 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
return temp_public_key return temp_public_key
@abc.abstractmethod @abc.abstractmethod
def _update_comment(self): def _update_comment(self) -> None:
pass pass
def _should_remove(self): def _should_remove(self) -> bool:
return self._private_key_exists() or self._public_key_exists() return self._private_key_exists() or self._public_key_exists()
@OpensshModule.trigger_change @OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode @OpensshModule.skip_if_check_mode
def _remove(self): def _remove(self) -> None:
try: try:
if self._private_key_exists(): if self._private_key_exists():
os.remove(self.private_key_path) os.remove(self.private_key_path)
@@ -308,7 +322,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
self.module.fail_json(msg=str(e)) self.module.fail_json(msg=str(e))
@property @property
def _result(self): def _result(self) -> dict[str, t.Any]:
private_key = self.private_key or self.original_private_key private_key = self.private_key or self.original_private_key
public_key = self.public_key or self.original_public_key public_key = self.public_key or self.original_public_key
@@ -322,7 +336,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
} }
@property @property
def diff(self): def diff(self) -> dict[str, t.Any]:
before = ( before = (
self.original_private_key.to_dict() if self.original_private_key else {} self.original_private_key.to_dict() if self.original_private_key else {}
) )
@@ -340,7 +354,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
class KeypairBackendOpensshBin(KeypairBackend): class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendOpensshBin, self).__init__(module) super(KeypairBackendOpensshBin, self).__init__(module)
if self.module.params["private_key_format"] != "auto": if self.module.params["private_key_format"] != "auto":
@@ -350,12 +364,12 @@ class KeypairBackendOpensshBin(KeypairBackend):
self.ssh_keygen = KeygenCommand(self.module) self.ssh_keygen = KeygenCommand(self.module)
def _generate_keypair(self, private_key_path): def _generate_keypair(self, private_key_path: str) -> None:
self.ssh_keygen.generate_keypair( self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True private_key_path, self.size, self.type, self.comment, check_rc=True
) )
def _get_private_key(self): def _get_private_key(self) -> PrivateKey:
rc, private_key_content, err = self.ssh_keygen.get_private_key( rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False self.private_key_path, check_rc=False
) )
@@ -363,13 +377,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
raise ValueError(err) raise ValueError(err)
return PrivateKey.from_string(private_key_content) return PrivateKey.from_string(private_key_content)
def _get_public_key(self): def _get_public_key(self) -> PublicKey | t.Literal[""]:
public_key_content = self.ssh_keygen.get_matching_public_key( public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True self.private_key_path, check_rc=True
)[1] )[1]
return PublicKey.from_string(public_key_content) return PublicKey.from_string(public_key_content)
def _private_key_readable(self): def _private_key_readable(self) -> bool:
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key( rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False self.private_key_path, check_rc=False
) )
@@ -383,7 +397,7 @@ class KeypairBackendOpensshBin(KeypairBackend):
) )
) )
def _update_comment(self): def _update_comment(self) -> None:
try: try:
ssh_version = self._get_ssh_version() or "7.8" ssh_version = self._get_ssh_version() or "7.8"
force_new_format = ( force_new_format = (
@@ -391,19 +405,19 @@ class KeypairBackendOpensshBin(KeypairBackend):
) )
self.ssh_keygen.update_comment( self.ssh_keygen.update_comment(
self.private_key_path, self.private_key_path,
self.comment, self.comment or "",
force_new_format=force_new_format, force_new_format=force_new_format,
check_rc=True, check_rc=True,
) )
except (IOError, OSError) as e: except (IOError, OSError) as e:
self.module.fail_json(msg=str(e)) self.module.fail_json(msg=str(e))
def _private_key_valid_backend(self): def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
return True return True
class KeypairBackendCryptography(KeypairBackend): class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(KeypairBackendCryptography, self).__init__(module) super(KeypairBackendCryptography, self).__init__(module)
if self.type == "rsa1": if self.type == "rsa1":
@@ -416,12 +430,15 @@ class KeypairBackendCryptography(KeypairBackend):
if module.params["passphrase"] if module.params["passphrase"]
else None else None
) )
self.private_key_format = self._get_key_format( key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] = module.params[
module.params["private_key_format"] "private_key_format"
) ]
self.private_key_format = self._get_key_format(key_format)
def _get_key_format(self, key_format): def _get_key_format(
result = "SSH" self, key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"]
) -> t.Literal["SSH", "PKCS1", "PKCS8"]:
result: t.Literal["SSH", "PKCS1", "PKCS8"] = "SSH"
if key_format == "auto": if key_format == "auto":
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed # Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
@@ -435,11 +452,12 @@ class KeypairBackendCryptography(KeypairBackend):
# but still defaulted to PKCS1 format with the exception of ed25519 keys # but still defaulted to PKCS1 format with the exception of ed25519 keys
result = "PKCS1" result = "PKCS1"
else: else:
result = key_format.upper() result = key_format.upper() # type: ignore
return result return result
def _generate_keypair(self, private_key_path): def _generate_keypair(self, private_key_path: str) -> None:
assert self.type != "rsa1"
keypair = OpensshKeypair.generate( keypair = OpensshKeypair.generate(
keytype=self.type, keytype=self.type,
size=self.size, size=self.size,
@@ -455,7 +473,7 @@ class KeypairBackendCryptography(KeypairBackend):
public_key_path = private_key_path + ".pub" public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key) secure_write(public_key_path, 0o644, keypair.public_key)
def _get_private_key(self): def _get_private_key(self) -> PrivateKey:
keypair = OpensshKeypair.load( keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
) )
@@ -467,7 +485,7 @@ class KeypairBackendCryptography(KeypairBackend):
format=parse_private_key_format(self.private_key_path), format=parse_private_key_format(self.private_key_path),
) )
def _get_public_key(self): def _get_public_key(self) -> PublicKey | t.Literal[""]:
try: try:
keypair = OpensshKeypair.load( keypair = OpensshKeypair.load(
path=self.private_key_path, path=self.private_key_path,
@@ -480,7 +498,7 @@ class KeypairBackendCryptography(KeypairBackend):
return PublicKey.from_string(to_text(keypair.public_key)) return PublicKey.from_string(to_text(keypair.public_key))
def _private_key_readable(self): def _private_key_readable(self) -> bool:
try: try:
OpensshKeypair.load( OpensshKeypair.load(
path=self.private_key_path, path=self.private_key_path,
@@ -504,7 +522,7 @@ class KeypairBackendCryptography(KeypairBackend):
return True return True
def _update_comment(self): def _update_comment(self) -> None:
keypair = OpensshKeypair.load( keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
) )
@@ -519,16 +537,18 @@ class KeypairBackendCryptography(KeypairBackend):
except (IOError, OSError) as e: except (IOError, OSError) as e:
self.module.fail_json(msg=str(e)) self.module.fail_json(msg=str(e))
def _private_key_valid_backend(self): def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
# avoids breaking behavior and prevents # avoids breaking behavior and prevents
# automatic conversions with OpenSSH upgrades # automatic conversions with OpenSSH upgrades
if self.module.params["private_key_format"] == "auto": if self.module.params["private_key_format"] == "auto":
return True return True
return self.private_key_format == self.original_private_key.format return self.private_key_format == original_private_key.format
def select_backend(module, backend): def select_backend(
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
) -> KeypairBackend:
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion( can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
CRYPTOGRAPHY_VERSION CRYPTOGRAPHY_VERSION
) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION) ) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION)

View File

@@ -8,6 +8,7 @@ import abc
import binascii import binascii
import datetime as _datetime import datetime as _datetime
import os import os
import typing as t
from base64 import b64encode from base64 import b64encode
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
@@ -26,6 +27,16 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
if t.TYPE_CHECKING:
from .cryptography import KeyType
DateFormat = t.Literal["human_readable", "openssh", "timestamp"]
DateFormatStr = t.Literal["human_readable", "openssh"]
DateFormatInt = t.Literal["timestamp"]
else:
KeyType = None
# Protocol References # Protocol References
# ------------------- # -------------------
# https://datatracker.ietf.org/doc/html/rfc4251 # https://datatracker.ietf.org/doc/html/rfc4251
@@ -44,7 +55,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
_USER_TYPE = 1 _USER_TYPE = 1
_HOST_TYPE = 2 _HOST_TYPE = 2
_SSH_TYPE_STRINGS = { _SSH_TYPE_STRINGS: dict[KeyType | str, bytes] = {
"rsa": b"ssh-rsa", "rsa": b"ssh-rsa",
"dsa": b"ssh-dss", "dsa": b"ssh-dss",
"ecdsa-nistp256": b"ecdsa-sha2-nistp256", "ecdsa-nistp256": b"ecdsa-sha2-nistp256",
@@ -94,16 +105,18 @@ _EXTENSIONS = (
class OpensshCertificateTimeParameters: class OpensshCertificateTimeParameters:
def __init__(self, valid_from, valid_to): def __init__(
self, valid_from: str | bytes | int, valid_to: str | bytes | int
) -> None:
self._valid_from = self.to_datetime(valid_from) self._valid_from = self.to_datetime(valid_from)
self._valid_to = self.to_datetime(valid_to) self._valid_to = self.to_datetime(valid_to)
if self._valid_from > self._valid_to: if self._valid_from > self._valid_to:
raise ValueError( raise ValueError(
f"Valid from: {valid_from} must not be greater than Valid to: {valid_to}" f"Valid from: {valid_from!r} must not be greater than Valid to: {valid_to!r}"
) )
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)): if not isinstance(other, type(self)):
return NotImplemented return NotImplemented
else: else:
@@ -112,55 +125,83 @@ class OpensshCertificateTimeParameters:
and self._valid_to == other._valid_to and self._valid_to == other._valid_to
) )
def __ne__(self, other): def __ne__(self, other: object) -> bool:
return not self == other return not self == other
@property @property
def validity_string(self): def validity_string(self) -> str:
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER): if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}" return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}"
return "" return ""
def valid_from(self, date_format): @t.overload
def valid_from(self, date_format: DateFormatStr) -> str: ...
@t.overload
def valid_from(self, date_format: DateFormatInt) -> int: ...
@t.overload
def valid_from(self, date_format: DateFormat) -> str | int: ...
def valid_from(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_from, date_format) return self.format_datetime(self._valid_from, date_format)
def valid_to(self, date_format): @t.overload
def valid_to(self, date_format: DateFormatStr) -> str: ...
@t.overload
def valid_to(self, date_format: DateFormatInt) -> int: ...
@t.overload
def valid_to(self, date_format: DateFormat) -> str | int: ...
def valid_to(self, date_format: DateFormat) -> str | int:
return self.format_datetime(self._valid_to, date_format) return self.format_datetime(self._valid_to, date_format)
def within_range(self, valid_at): def within_range(self, valid_at: str | bytes | int | None) -> bool:
if valid_at is not None: if valid_at is not None:
valid_at_datetime = self.to_datetime(valid_at) valid_at_datetime = self.to_datetime(valid_at)
return self._valid_from <= valid_at_datetime <= self._valid_to return self._valid_from <= valid_at_datetime <= self._valid_to
return True return True
@t.overload
@staticmethod @staticmethod
def format_datetime(dt, date_format): def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
@t.overload
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
@staticmethod
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
if date_format in ("human_readable", "openssh"): if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS: if dt == _ALWAYS:
result = "always" return "always"
elif dt == _FOREVER: if dt == _FOREVER:
result = "forever" return "forever"
else: else:
result = ( return (
dt.isoformat().replace("+00:00", "") dt.isoformat().replace("+00:00", "")
if date_format == "human_readable" if date_format == "human_readable"
else dt.strftime("%Y%m%d%H%M%S") else dt.strftime("%Y%m%d%H%M%S")
) )
elif date_format == "timestamp": if date_format == "timestamp":
td = dt - _ALWAYS td = dt - _ALWAYS
result = int( return int(
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6 (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
) )
else: raise ValueError(f"{date_format} is not a valid format")
raise ValueError(f"{date_format} is not a valid format")
return result
@staticmethod @staticmethod
def to_datetime(time_string_or_timestamp): def to_datetime(time_string_or_timestamp: str | bytes | int) -> datetime:
try: try:
if isinstance(time_string_or_timestamp, (str, bytes)): if isinstance(time_string_or_timestamp, (str, bytes)):
result = OpensshCertificateTimeParameters._time_string_to_datetime( result = OpensshCertificateTimeParameters._time_string_to_datetime(
time_string_or_timestamp.strip() to_text(time_string_or_timestamp.strip())
) )
elif isinstance(time_string_or_timestamp, int): elif isinstance(time_string_or_timestamp, int):
result = OpensshCertificateTimeParameters._timestamp_to_datetime( result = OpensshCertificateTimeParameters._timestamp_to_datetime(
@@ -175,43 +216,53 @@ class OpensshCertificateTimeParameters:
return result return result
@staticmethod @staticmethod
def _timestamp_to_datetime(timestamp): def _timestamp_to_datetime(timestamp: int) -> datetime:
if timestamp == 0x0: if timestamp == 0x0:
result = _ALWAYS return _ALWAYS
elif timestamp == 0xFFFFFFFFFFFFFFFF: if timestamp == 0xFFFFFFFFFFFFFFFF:
result = _FOREVER return _FOREVER
else: try:
try: return datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc) except OverflowError:
except OverflowError: raise ValueError
raise ValueError
return result
@staticmethod @staticmethod
def _time_string_to_datetime(time_string): def _time_string_to_datetime(time_string: str) -> datetime:
result = None
if time_string == "always": if time_string == "always":
result = _ALWAYS return _ALWAYS
elif time_string == "forever": if time_string == "forever":
result = _FOREVER return _FOREVER
elif is_relative_time_string(time_string): if is_relative_time_string(time_string):
result = convert_relative_to_datetime(time_string, with_timezone=True) result = convert_relative_to_datetime(time_string, with_timezone=True)
else:
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=True,
)
except ValueError:
pass
if result is None: if result is None:
raise ValueError raise ValueError
return result
result = None
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=True,
)
except ValueError:
pass
if result is None:
raise ValueError
return result return result
_OpensshCertificateOption = t.TypeVar(
"_OpensshCertificateOption", bound="OpensshCertificateOption"
)
class OpensshCertificateOption: class OpensshCertificateOption:
def __init__(self, option_type, name, data): def __init__(
self,
option_type: t.Literal["critical", "extension"],
name: str | bytes,
data: str | bytes,
):
if option_type not in ("critical", "extension"): if option_type not in ("critical", "extension"):
raise ValueError("type must be either 'critical' or 'extension'") raise ValueError("type must be either 'critical' or 'extension'")
@@ -225,7 +276,7 @@ class OpensshCertificateOption:
self._name = name.lower() self._name = name.lower()
self._data = data self._data = data
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)): if not isinstance(other, type(self)):
return NotImplemented return NotImplemented
@@ -237,32 +288,34 @@ class OpensshCertificateOption:
] ]
) )
def __hash__(self): def __hash__(self) -> int:
return hash((self._option_type, self._name, self._data)) return hash((self._option_type, self._name, self._data))
def __ne__(self, other): def __ne__(self, other: object) -> bool:
return not self == other return not self == other
def __str__(self): def __str__(self) -> str:
if self._data: if self._data:
return f"{self._name}={self._data}" return f"{self._name!r}={self._data!r}"
return self._name return f"{self._name!r}"
@property @property
def data(self): def data(self) -> str | bytes:
return self._data return self._data
@property @property
def name(self): def name(self) -> str | bytes:
return self._name return self._name
@property @property
def type(self): def type(self) -> t.Literal["critical", "extension"]:
return self._option_type return self._option_type
@classmethod @classmethod
def from_string(cls, option_string): def from_string(
if not isinstance(option_string, (str, bytes)): cls: t.Type[_OpensshCertificateOption], option_string: str
) -> _OpensshCertificateOption:
if not isinstance(option_string, str):
raise ValueError( raise ValueError(
f"option_string must be a string not {type(option_string)}" f"option_string must be a string not {type(option_string)}"
) )
@@ -280,7 +333,8 @@ class OpensshCertificateOption:
name, data = option_string.strip(), "" name, data = option_string.strip(), ""
return cls( return cls(
option_type=option_type or get_option_type(name.lower()), # We have str, but we're expecting a specific literal:
option_type=option_type or get_option_type(name.lower()), # type: ignore
name=name, name=name,
data=data, data=data,
) )
@@ -291,21 +345,21 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
def __init__( def __init__(
self, self,
nonce=None, nonce: bytes | None = None,
serial=None, serial: int | None = None,
cert_type=None, cert_type: int | None = None,
key_id=None, key_id: bytes | None = None,
principals=None, principals: list[bytes] | None = None,
valid_after=None, valid_after: int | None = None,
valid_before=None, valid_before: int | None = None,
critical_options=None, critical_options: list[tuple[bytes, bytes]] | None = None,
extensions=None, extensions: list[tuple[bytes, bytes]] | None = None,
reserved=None, reserved: bytes | None = None,
signing_key=None, signing_key: bytes | None = None,
): ):
self.nonce = nonce self.nonce = nonce
self.serial = serial self.serial = serial
self._cert_type = cert_type self._cert_type: int | None = cert_type
self.key_id = key_id self.key_id = key_id
self.principals = principals self.principals = principals
self.valid_after = valid_after self.valid_after = valid_after
@@ -315,10 +369,10 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
self.reserved = reserved self.reserved = reserved
self.signing_key = signing_key self.signing_key = signing_key
self.type_string = None self.type_string: bytes | None = None
@property @property
def cert_type(self): def cert_type(self) -> t.Literal["user", "host", ""]:
if self._cert_type == _USER_TYPE: if self._cert_type == _USER_TYPE:
return "user" return "user"
elif self._cert_type == _HOST_TYPE: elif self._cert_type == _HOST_TYPE:
@@ -327,7 +381,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
return "" return ""
@cert_type.setter @cert_type.setter
def cert_type(self, cert_type): def cert_type(self, cert_type: t.Literal["user", "host"] | int) -> None:
if cert_type == "user" or cert_type == _USER_TYPE: if cert_type == "user" or cert_type == _USER_TYPE:
self._cert_type = _USER_TYPE self._cert_type = _USER_TYPE
elif cert_type == "host" or cert_type == _HOST_TYPE: elif cert_type == "host" or cert_type == _HOST_TYPE:
@@ -335,28 +389,30 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
else: else:
raise ValueError(f"{cert_type} is not a valid certificate type") raise ValueError(f"{cert_type} is not a valid certificate type")
def signing_key_fingerprint(self): def signing_key_fingerprint(self) -> bytes:
if self.signing_key is None:
raise ValueError("signing_key not present")
return fingerprint(self.signing_key) return fingerprint(self.signing_key)
@abc.abstractmethod @abc.abstractmethod
def public_key_fingerprint(self): def public_key_fingerprint(self) -> bytes:
pass pass
@abc.abstractmethod @abc.abstractmethod
def parse_public_numbers(self, parser): def parse_public_numbers(self, parser: OpensshParser) -> None:
pass pass
class OpensshRSACertificateInfo(OpensshCertificateInfo): class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e=None, n=None, **kwargs): def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
super(OpensshRSACertificateInfo, self).__init__(**kwargs) super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01 self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e self.e = e
self.n = n self.n = n
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self): def public_key_fingerprint(self) -> bytes:
if any([self.e is None, self.n is None]): if self.e is None or self.n is None:
return b"" return b""
writer = _OpensshWriter() writer = _OpensshWriter()
@@ -366,13 +422,20 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes()) return fingerprint(writer.bytes())
def parse_public_numbers(self, parser): def parse_public_numbers(self, parser: OpensshParser) -> None:
self.e = parser.mpint() self.e = parser.mpint()
self.n = parser.mpint() self.n = parser.mpint()
class OpensshDSACertificateInfo(OpensshCertificateInfo): class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, p=None, q=None, g=None, y=None, **kwargs): def __init__(
self,
p: int | None = None,
q: int | None = None,
g: int | None = None,
y: int | None = None,
**kwargs,
) -> None:
super(OpensshDSACertificateInfo, self).__init__(**kwargs) super(OpensshDSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01 self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
self.p = p self.p = p
@@ -381,8 +444,8 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
self.y = y self.y = y
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self): def public_key_fingerprint(self) -> bytes:
if any([self.p is None, self.q is None, self.g is None, self.y is None]): if self.p is None or self.q is None or self.g is None or self.y is None:
return b"" return b""
writer = _OpensshWriter() writer = _OpensshWriter()
@@ -394,7 +457,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes()) return fingerprint(writer.bytes())
def parse_public_numbers(self, parser): def parse_public_numbers(self, parser: OpensshParser) -> None:
self.p = parser.mpint() self.p = parser.mpint()
self.q = parser.mpint() self.q = parser.mpint()
self.g = parser.mpint() self.g = parser.mpint()
@@ -402,7 +465,9 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
class OpensshECDSACertificateInfo(OpensshCertificateInfo): class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, curve=None, public_key=None, **kwargs): def __init__(
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
):
super(OpensshECDSACertificateInfo, self).__init__(**kwargs) super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
self._curve = None self._curve = None
if curve is not None: if curve is not None:
@@ -411,11 +476,11 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
self.public_key = public_key self.public_key = public_key
@property @property
def curve(self): def curve(self) -> bytes | None:
return self._curve return self._curve
@curve.setter @curve.setter
def curve(self, curve): def curve(self, curve: bytes) -> None:
if curve in _ECDSA_CURVE_IDENTIFIERS.values(): if curve in _ECDSA_CURVE_IDENTIFIERS.values():
self._curve = curve self._curve = curve
self.type_string = ( self.type_string = (
@@ -428,8 +493,8 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
) )
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self): def public_key_fingerprint(self) -> bytes:
if any([self.curve is None, self.public_key is None]): if self.curve is None or self.public_key is None:
return b"" return b""
writer = _OpensshWriter() writer = _OpensshWriter()
@@ -439,18 +504,18 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes()) return fingerprint(writer.bytes())
def parse_public_numbers(self, parser): def parse_public_numbers(self, parser: OpensshParser) -> None:
self.curve = parser.string() self.curve = parser.string()
self.public_key = parser.string() self.public_key = parser.string()
class OpensshED25519CertificateInfo(OpensshCertificateInfo): class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk=None, **kwargs): def __init__(self, pk: bytes | None = None, **kwargs) -> None:
super(OpensshED25519CertificateInfo, self).__init__(**kwargs) super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01 self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk self.pk = pk
def public_key_fingerprint(self): def public_key_fingerprint(self) -> bytes:
if self.pk is None: if self.pk is None:
return b"" return b""
@@ -460,21 +525,26 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
return fingerprint(writer.bytes()) return fingerprint(writer.bytes())
def parse_public_numbers(self, parser): def parse_public_numbers(self, parser: OpensshParser) -> None:
self.pk = parser.string() self.pk = parser.string()
_OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate")
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD # See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate: class OpensshCertificate:
"""Encapsulates a formatted OpenSSH certificate including signature and signing key""" """Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info, signature): def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
self._cert_info = cert_info self._cert_info = cert_info
self.signature = signature self.signature = signature
@classmethod @classmethod
def load(cls, path): def load(
cls: t.Type[_OpensshCertificate], path: str | os.PathLike
) -> _OpensshCertificate:
if not os.path.exists(path): if not os.path.exists(path):
raise ValueError(f"{path} is not a valid path.") raise ValueError(f"{path} is not a valid path.")
@@ -492,11 +562,11 @@ class OpensshCertificate:
for key_type, string in _SSH_TYPE_STRINGS.items(): for key_type, string in _SSH_TYPE_STRINGS.items():
if format_identifier == string + _CERT_SUFFIX_V01: if format_identifier == string + _CERT_SUFFIX_V01:
pub_key_type = key_type pub_key_type = t.cast(KeyType, key_type)
break break
else: else:
raise ValueError( raise ValueError(
f"Invalid certificate format identifier: {format_identifier}" f"Invalid certificate format identifier: {format_identifier!r}"
) )
parser = OpensshParser(cert) parser = OpensshParser(cert)
@@ -521,75 +591,97 @@ class OpensshCertificate:
) )
@property @property
def type_string(self): def type_string(self) -> str:
return to_text(self._cert_info.type_string) return to_text(self._cert_info.type_string)
@property @property
def nonce(self): def nonce(self) -> bytes:
if self._cert_info.nonce is None:
raise ValueError
return self._cert_info.nonce return self._cert_info.nonce
@property @property
def public_key(self): def public_key(self) -> str:
return to_text(self._cert_info.public_key_fingerprint()) return to_text(self._cert_info.public_key_fingerprint())
@property @property
def serial(self): def serial(self) -> int:
if self._cert_info.serial is None:
raise ValueError
return self._cert_info.serial return self._cert_info.serial
@property @property
def type(self): def type(self) -> t.Literal["user", "host"]:
return self._cert_info.cert_type result = self._cert_info.cert_type
if result == "":
raise ValueError
return result
@property @property
def key_id(self): def key_id(self) -> str:
return to_text(self._cert_info.key_id) return to_text(self._cert_info.key_id)
@property @property
def principals(self): def principals(self) -> list[str]:
if self._cert_info.principals is None:
raise ValueError
return [to_text(p) for p in self._cert_info.principals] return [to_text(p) for p in self._cert_info.principals]
@property @property
def valid_after(self): def valid_after(self) -> int:
if self._cert_info.valid_after is None:
raise ValueError
return self._cert_info.valid_after return self._cert_info.valid_after
@property @property
def valid_before(self): def valid_before(self) -> int:
if self._cert_info.valid_before is None:
raise ValueError
return self._cert_info.valid_before return self._cert_info.valid_before
@property @property
def critical_options(self): def critical_options(self) -> list[OpensshCertificateOption]:
if self._cert_info.critical_options is None:
raise ValueError
return [ return [
OpensshCertificateOption("critical", to_text(n), to_text(d)) OpensshCertificateOption("critical", to_text(n), to_text(d))
for n, d in self._cert_info.critical_options for n, d in self._cert_info.critical_options
] ]
@property @property
def extensions(self): def extensions(self) -> list[OpensshCertificateOption]:
if self._cert_info.extensions is None:
raise ValueError
return [ return [
OpensshCertificateOption("extension", to_text(n), to_text(d)) OpensshCertificateOption("extension", to_text(n), to_text(d))
for n, d in self._cert_info.extensions for n, d in self._cert_info.extensions
] ]
@property @property
def reserved(self): def reserved(self) -> bytes:
if self._cert_info.reserved is None:
raise ValueError
return self._cert_info.reserved return self._cert_info.reserved
@property @property
def signing_key(self): def signing_key(self) -> str:
return to_text(self._cert_info.signing_key_fingerprint()) return to_text(self._cert_info.signing_key_fingerprint())
@property @property
def signature_type(self): def signature_type(self) -> str:
signature_data = OpensshParser.signature_data(self.signature) signature_data = OpensshParser.signature_data(self.signature)
return to_text(signature_data["signature_type"]) return to_text(signature_data["signature_type"])
@staticmethod @staticmethod
def _parse_cert_info(pub_key_type, parser): def _parse_cert_info(
pub_key_type: KeyType, parser: OpensshParser
) -> OpensshCertificateInfo:
cert_info = get_cert_info_object(pub_key_type) cert_info = get_cert_info_object(pub_key_type)
cert_info.nonce = parser.string() cert_info.nonce = parser.string()
cert_info.parse_public_numbers(parser) cert_info.parse_public_numbers(parser)
cert_info.serial = parser.uint64() cert_info.serial = parser.uint64()
cert_info.cert_type = parser.uint32() # mypy doesn't understand that the setter accepts other types than the getter:
cert_info.cert_type = parser.uint32() # type: ignore
cert_info.key_id = parser.string() cert_info.key_id = parser.string()
cert_info.principals = parser.string_list() cert_info.principals = parser.string_list()
cert_info.valid_after = parser.uint64() cert_info.valid_after = parser.uint64()
@@ -601,7 +693,7 @@ class OpensshCertificate:
return cert_info return cert_info
def to_dict(self): def to_dict(self) -> dict[str, t.Any]:
time_parameters = OpensshCertificateTimeParameters( time_parameters = OpensshCertificateTimeParameters(
valid_from=self.valid_after, valid_to=self.valid_before valid_from=self.valid_after, valid_to=self.valid_before
) )
@@ -624,7 +716,7 @@ class OpensshCertificate:
} }
def apply_directives(directives): def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOption]:
if any(d not in _DIRECTIVES for d in directives): if any(d not in _DIRECTIVES for d in directives):
raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}") raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}")
@@ -650,50 +742,47 @@ def apply_directives(directives):
) )
def default_options(): def default_options() -> list[OpensshCertificateOption]:
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS] return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
def fingerprint(public_key): def fingerprint(public_key: bytes) -> bytes:
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``""" """Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
h = sha256() h = sha256()
h.update(public_key) h.update(public_key)
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=") return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
def get_cert_info_object(key_type): def get_cert_info_object(key_type: KeyType) -> OpensshCertificateInfo:
if key_type == "rsa": if key_type == "rsa":
cert_info = OpensshRSACertificateInfo() return OpensshRSACertificateInfo()
elif key_type == "dsa": if key_type == "dsa":
cert_info = OpensshDSACertificateInfo() return OpensshDSACertificateInfo()
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"): if key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
cert_info = OpensshECDSACertificateInfo() return OpensshECDSACertificateInfo()
elif key_type == "ed25519": if key_type == "ed25519":
cert_info = OpensshED25519CertificateInfo() return OpensshED25519CertificateInfo()
else: raise ValueError(f"{key_type} is not a valid key type")
raise ValueError(f"{key_type} is not a valid key type")
return cert_info
def get_option_type(name): def get_option_type(name: str) -> t.Literal["critical", "extension"]:
if name in _CRITICAL_OPTIONS: if name in _CRITICAL_OPTIONS:
result = "critical" return "critical"
elif name in _EXTENSIONS: if name in _EXTENSIONS:
result = "extension" return "extension"
else: raise ValueError(
raise ValueError( f"{name} is not a valid option. "
f"{name} is not a valid option. " "Custom options must start with 'critical:' or 'extension:' to indicate type"
"Custom options must start with 'critical:' or 'extension:' to indicate type" )
)
return result
def is_relative_time_string(time_string): def is_relative_time_string(time_string: str) -> bool:
return time_string.startswith("+") or time_string.startswith("-") return time_string.startswith("+") or time_string.startswith("-")
def parse_option_list(option_list): def parse_option_list(
option_list: t.Iterable[str],
) -> tuple[list[OpensshCertificateOption], list[OpensshCertificateOption]]:
critical_options = [] critical_options = []
directives = [] directives = []
extensions = [] extensions = []

View File

@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import typing as t
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from getpass import getuser from getpass import getuser
from socket import gethostname from socket import gethostname
@@ -64,6 +65,27 @@ except ImportError:
CRYPTOGRAPHY_VERSION = "0.0" CRYPTOGRAPHY_VERSION = "0.0"
_ALGORITHM_PARAMETERS = {} _ALGORITHM_PARAMETERS = {}
if t.TYPE_CHECKING:
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"]
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"]
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"]
PrivateKeyTypes = t.Union[
rsa.RSAPrivateKey,
dsa.DSAPrivateKey,
ec.EllipticCurvePrivateKey,
Ed25519PrivateKey,
]
PublicKeyTypes = t.Union[
rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey
]
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes as AllPublicKeyTypes,
)
_TEXT_ENCODING = "UTF-8" _TEXT_ENCODING = "UTF-8"
@@ -111,11 +133,19 @@ class InvalidSignatureError(OpenSSHError):
pass pass
_AsymmetricKeypair = t.TypeVar("_AsymmetricKeypair", bound="AsymmetricKeypair")
class AsymmetricKeypair: class AsymmetricKeypair:
"""Container for newly generated asymmetric key pairs or those loaded from existing files""" """Container for newly generated asymmetric key pairs or those loaded from existing files"""
@classmethod @classmethod
def generate(cls, keytype="rsa", size=None, passphrase=None): def generate(
cls: t.Type[_AsymmetricKeypair],
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
) -> _AsymmetricKeypair:
"""Returns an Asymmetric_Keypair object generated with the supplied parameters """Returns an Asymmetric_Keypair object generated with the supplied parameters
or defaults to an unencrypted RSA-2048 key or defaults to an unencrypted RSA-2048 key
@@ -124,19 +154,21 @@ class AsymmetricKeypair:
:passphrase: Secret of type Bytes used to encrypt the private key being generated :passphrase: Secret of type Bytes used to encrypt the private key being generated
""" """
if keytype not in _ALGORITHM_PARAMETERS.keys(): if keytype not in _ALGORITHM_PARAMETERS:
raise InvalidKeyTypeError( raise InvalidKeyTypeError(
f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}" f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}"
) )
if not size: if not size:
size = _ALGORITHM_PARAMETERS[keytype]["default_size"] size = _ALGORITHM_PARAMETERS[keytype]["default_size"] # type: ignore
else: else:
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: # type: ignore
raise InvalidKeySizeError( raise InvalidKeySizeError(
f"{size} is not a valid key size for {keytype} keys" f"{size} is not a valid key size for {keytype} keys"
) )
size = t.cast(int, size)
privatekey: PrivateKeyTypes
if passphrase: if passphrase:
encryption_algorithm = get_encryption_algorithm(passphrase) encryption_algorithm = get_encryption_algorithm(passphrase)
else: else:
@@ -157,7 +189,7 @@ class AsymmetricKeypair:
privatekey = Ed25519PrivateKey.generate() privatekey = Ed25519PrivateKey.generate()
elif keytype == "ecdsa": elif keytype == "ecdsa":
privatekey = ec.generate_private_key( privatekey = ec.generate_private_key(
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], _ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], # type: ignore
) )
publickey = privatekey.public_key() publickey = privatekey.public_key()
@@ -172,13 +204,13 @@ class AsymmetricKeypair:
@classmethod @classmethod
def load( def load(
cls, cls: t.Type[_AsymmetricKeypair],
path, path: str | os.PathLike,
passphrase=None, passphrase: bytes | None = None,
private_key_format="PEM", private_key_format: KeySerializationFormat = "PEM",
public_key_format="PEM", public_key_format: KeySerializationFormat = "PEM",
no_public_key=False, no_public_key: bool = False,
): ) -> _AsymmetricKeypair:
"""Returns an Asymmetric_Keypair object loaded from the supplied file path """Returns an Asymmetric_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded :path: A path to an existing private key to be loaded
@@ -197,14 +229,17 @@ class AsymmetricKeypair:
if no_public_key: if no_public_key:
publickey = privatekey.public_key() publickey = privatekey.public_key()
else: else:
publickey = load_publickey(path + ".pub", public_key_format) # TODO: BUG: load_publickey() can return unsupported key types
# (Also we should check whether the public key fits the private key...)
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
# Ed25519 keys are always of size 256 and do not have a key_size attribute # Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey): if isinstance(privatekey, Ed25519PrivateKey):
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] size: int = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] # type: ignore
else: else:
size = privatekey.key_size size = privatekey.key_size
keytype: KeyType
if isinstance(privatekey, rsa.RSAPrivateKey): if isinstance(privatekey, rsa.RSAPrivateKey):
keytype = "rsa" keytype = "rsa"
elif isinstance(privatekey, dsa.DSAPrivateKey): elif isinstance(privatekey, dsa.DSAPrivateKey):
@@ -224,7 +259,14 @@ class AsymmetricKeypair:
encryption_algorithm=encryption_algorithm, encryption_algorithm=encryption_algorithm,
) )
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm): def __init__(
self,
keytype: KeyType,
size: int,
privatekey: PrivateKeyTypes,
publickey: PublicKeyTypes,
encryption_algorithm: serialization.KeySerializationEncryption,
) -> None:
""" """
:keytype: One of rsa, dsa, ecdsa, ed25519 :keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair :size: The key length for the private key of this key pair
@@ -246,7 +288,7 @@ class AsymmetricKeypair:
"The private key and public key of this keypair do not match" "The private key and public key of this keypair do not match"
) )
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, AsymmetricKeypair): if not isinstance(other, AsymmetricKeypair):
return NotImplemented return NotImplemented
@@ -256,55 +298,53 @@ class AsymmetricKeypair:
self.encryption_algorithm, other.encryption_algorithm self.encryption_algorithm, other.encryption_algorithm
) )
def __ne__(self, other): def __ne__(self, other: object) -> bool:
return not self == other return not self == other
@property @property
def private_key(self): def private_key(self) -> PrivateKeyTypes:
"""Returns the private key of this key pair""" """Returns the private key of this key pair"""
return self.__privatekey return self.__privatekey
@property @property
def public_key(self): def public_key(self) -> PublicKeyTypes:
"""Returns the public key of this key pair""" """Returns the public key of this key pair"""
return self.__publickey return self.__publickey
@property @property
def size(self): def size(self) -> int:
"""Returns the size of the private key of this key pair""" """Returns the size of the private key of this key pair"""
return self.__size return self.__size
@property @property
def key_type(self): def key_type(self) -> KeyType:
"""Returns the key type of this key pair""" """Returns the key type of this key pair"""
return self.__keytype return self.__keytype
@property @property
def encryption_algorithm(self): def encryption_algorithm(self) -> serialization.KeySerializationEncryption:
"""Returns the key encryption algorithm of this key pair""" """Returns the key encryption algorithm of this key pair"""
return self.__encryption_algorithm return self.__encryption_algorithm
def sign(self, data): def sign(self, data: bytes) -> bytes:
"""Returns signature of data signed with the private key of this key pair """Returns signature of data signed with the private key of this key pair
:data: byteslike data to sign :data: byteslike data to sign
""" """
try: try:
signature = self.__privatekey.sign( return self.__privatekey.sign(
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] # type: ignore
) )
except TypeError as e: except TypeError as e:
raise InvalidDataError(e) raise InvalidDataError(e)
return signature def verify(self, signature: bytes, data: bytes) -> None:
def verify(self, signature, data):
"""Verifies that the signature associated with the provided data was signed """Verifies that the signature associated with the provided data was signed
by the private key of this key pair. by the private key of this key pair.
@@ -312,15 +352,15 @@ class AsymmetricKeypair:
:data: byteslike data signed by the provided signature :data: byteslike data signed by the provided signature
""" """
try: try:
return self.__publickey.verify( self.__publickey.verify(
signature, signature,
data, data,
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], # type: ignore
) )
except InvalidSignature: except InvalidSignature:
raise InvalidSignatureError raise InvalidSignatureError
def update_passphrase(self, passphrase=None): def update_passphrase(self, passphrase: bytes | None = None) -> None:
"""Updates the encryption algorithm of this key pair """Updates the encryption algorithm of this key pair
:passphrase: Byte secret used to encrypt this key pair :passphrase: Byte secret used to encrypt this key pair
@@ -332,11 +372,20 @@ class AsymmetricKeypair:
self.__encryption_algorithm = serialization.NoEncryption() self.__encryption_algorithm = serialization.NoEncryption()
_OpensshKeypair = t.TypeVar("_OpensshKeypair", bound="OpensshKeypair")
class OpensshKeypair: class OpensshKeypair:
"""Container for OpenSSH encoded asymmetric key pairs""" """Container for OpenSSH encoded asymmetric key pairs"""
@classmethod @classmethod
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None): def generate(
cls: t.Type[_OpensshKeypair],
keytype: KeyType = "rsa",
size: int | None = None,
passphrase: bytes | None = None,
comment: str | None = None,
) -> _OpensshKeypair:
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key """Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519 :keytype: One of rsa, dsa, ecdsa, ed25519
@@ -362,7 +411,12 @@ class OpensshKeypair:
) )
@classmethod @classmethod
def load(cls, path, passphrase=None, no_public_key=False): def load(
cls: t.Type[_OpensshKeypair],
path: str | os.PathLike,
passphrase: bytes | None = None,
no_public_key: bool = False,
) -> _OpensshKeypair:
"""Returns an Openssh_Keypair object loaded from the supplied file path """Returns an Openssh_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded :path: A path to an existing private key to be loaded
@@ -373,7 +427,7 @@ class OpensshKeypair:
if no_public_key: if no_public_key:
comment = "" comment = ""
else: else:
comment = extract_comment(path + ".pub") comment = extract_comment(str(path) + ".pub")
asym_keypair = AsymmetricKeypair.load( asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key path, passphrase, "SSH", "SSH", no_public_key
@@ -391,7 +445,9 @@ class OpensshKeypair:
) )
@staticmethod @staticmethod
def encode_openssh_privatekey(asym_keypair, key_format): def encode_openssh_privatekey(
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
) -> bytes:
"""Returns an OpenSSH encoded private key for a given keypair """Returns an OpenSSH encoded private key for a given keypair
:asym_keypair: Asymmetric_Keypair from the private key is extracted :asym_keypair: Asymmetric_Keypair from the private key is extracted
@@ -422,7 +478,9 @@ class OpensshKeypair:
return encoded_privatekey return encoded_privatekey
@staticmethod @staticmethod
def encode_openssh_publickey(asym_keypair, comment): def encode_openssh_publickey(
asym_keypair: AsymmetricKeypair, comment: str
) -> bytes:
"""Returns an OpenSSH encoded public key for a given keypair """Returns an OpenSSH encoded public key for a given keypair
:asym_keypair: Asymmetric_Keypair from the public key is extracted :asym_keypair: Asymmetric_Keypair from the public key is extracted
@@ -436,14 +494,19 @@ class OpensshKeypair:
validate_comment(comment) validate_comment(comment)
encoded_publickey += ( encoded_publickey += (
f" {comment}".encode(encoding=_TEXT_ENCODING) if comment else b"" (b" " + comment.encode(encoding=_TEXT_ENCODING)) if comment else b""
) )
return encoded_publickey return encoded_publickey
def __init__( def __init__(
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment self,
): asym_keypair: AsymmetricKeypair,
openssh_privatekey: bytes,
openssh_publickey: bytes,
fingerprint: str,
comment: str | None,
) -> None:
""" """
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived :asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key :openssh_privatekey: An OpenSSH encoded private key
@@ -458,7 +521,7 @@ class OpensshKeypair:
self.__fingerprint = fingerprint self.__fingerprint = fingerprint
self.__comment = comment self.__comment = comment
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, OpensshKeypair): if not isinstance(other, OpensshKeypair):
return NotImplemented return NotImplemented
@@ -468,49 +531,49 @@ class OpensshKeypair:
) )
@property @property
def asymmetric_keypair(self): def asymmetric_keypair(self) -> AsymmetricKeypair:
"""Returns the underlying asymmetric key pair of this OpenSSH encoded key pair""" """Returns the underlying asymmetric key pair of this OpenSSH encoded key pair"""
return self.__asym_keypair return self.__asym_keypair
@property @property
def private_key(self): def private_key(self) -> bytes:
"""Returns the OpenSSH formatted private key of this key pair""" """Returns the OpenSSH formatted private key of this key pair"""
return self.__openssh_privatekey return self.__openssh_privatekey
@property @property
def public_key(self): def public_key(self) -> bytes:
"""Returns the OpenSSH formatted public key of this key pair""" """Returns the OpenSSH formatted public key of this key pair"""
return self.__openssh_publickey return self.__openssh_publickey
@property @property
def size(self): def size(self) -> int:
"""Returns the size of the private key of this key pair""" """Returns the size of the private key of this key pair"""
return self.__asym_keypair.size return self.__asym_keypair.size
@property @property
def key_type(self): def key_type(self) -> KeyType:
"""Returns the key type of this key pair""" """Returns the key type of this key pair"""
return self.__asym_keypair.key_type return self.__asym_keypair.key_type
@property @property
def fingerprint(self): def fingerprint(self) -> str:
"""Returns the fingerprint (SHA256 Hash) of the public key of this key pair""" """Returns the fingerprint (SHA256 Hash) of the public key of this key pair"""
return self.__fingerprint return self.__fingerprint
@property @property
def comment(self): def comment(self) -> str | None:
"""Returns the comment applied to the OpenSSH formatted public key of this key pair""" """Returns the comment applied to the OpenSSH formatted public key of this key pair"""
return self.__comment return self.__comment
@comment.setter @comment.setter
def comment(self, comment): def comment(self, comment: str) -> bytes:
"""Updates the comment applied to the OpenSSH formatted public key of this key pair """Updates the comment applied to the OpenSSH formatted public key of this key pair
:comment: Text to update the OpenSSH public key comment :comment: Text to update the OpenSSH public key comment
@@ -529,7 +592,7 @@ class OpensshKeypair:
) )
return self.__openssh_publickey return self.__openssh_publickey
def update_passphrase(self, passphrase): def update_passphrase(self, passphrase: bytes | None) -> None:
"""Updates the passphrase used to encrypt the private key of this keypair """Updates the passphrase used to encrypt the private key of this keypair
:passphrase: Text secret used for encryption :passphrase: Text secret used for encryption
@@ -541,18 +604,17 @@ class OpensshKeypair:
) )
def load_privatekey(path, passphrase, key_format): def load_privatekey(
path: str | os.PathLike,
passphrase: bytes | None,
key_format: KeySerializationFormat,
) -> PrivateKeyTypes:
privatekey_loaders = { privatekey_loaders = {
"PEM": serialization.load_pem_private_key, "PEM": serialization.load_pem_private_key,
"DER": serialization.load_der_private_key, "DER": serialization.load_der_private_key,
"SSH": serialization.load_ssh_private_key,
} }
# OpenSSH formatted private keys are not available in Cryptography <3.0
if hasattr(serialization, "load_ssh_private_key"):
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
else:
privatekey_loaders["SSH"] = serialization.load_pem_private_key
try: try:
privatekey_loader = privatekey_loaders[key_format] privatekey_loader = privatekey_loaders[key_format]
except KeyError: except KeyError:
@@ -567,16 +629,16 @@ def load_privatekey(path, passphrase, key_format):
with open(path, "rb") as f: with open(path, "rb") as f:
content = f.read() content = f.read()
privatekey = privatekey_loader( privatekey = privatekey_loader( # type: ignore
data=content, data=content,
password=passphrase, password=passphrase,
) )
except ValueError as e: except ValueError as exc:
# Revert to PEM if key could not be loaded in SSH format # Revert to PEM if key could not be loaded in SSH format
if key_format == "SSH": if key_format == "SSH":
try: try:
privatekey = privatekey_loaders["PEM"]( privatekey = privatekey_loaders["PEM"]( # type: ignore
data=content, data=content,
password=passphrase, password=passphrase,
) )
@@ -587,7 +649,7 @@ def load_privatekey(path, passphrase, key_format):
except UnsupportedAlgorithm as e: except UnsupportedAlgorithm as e:
raise InvalidAlgorithmError(e) raise InvalidAlgorithmError(e)
else: else:
raise InvalidPrivateKeyFileError(e) raise InvalidPrivateKeyFileError(exc)
except TypeError as e: except TypeError as e:
raise InvalidPassphraseError(e) raise InvalidPassphraseError(e)
except UnsupportedAlgorithm as e: except UnsupportedAlgorithm as e:
@@ -596,7 +658,9 @@ def load_privatekey(path, passphrase, key_format):
return privatekey return privatekey
def load_publickey(path, key_format): def load_publickey(
path: str | os.PathLike, key_format: KeySerializationFormat
) -> AllPublicKeyTypes:
publickey_loaders = { publickey_loaders = {
"PEM": serialization.load_pem_public_key, "PEM": serialization.load_pem_public_key,
"DER": serialization.load_der_public_key, "DER": serialization.load_der_public_key,
@@ -628,20 +692,27 @@ def load_publickey(path, key_format):
return publickey return publickey
def compare_publickeys(pk1, pk2): def compare_publickeys(pk1: PublicKeyTypes, pk2: PublicKeyTypes) -> bool:
a = isinstance(pk1, Ed25519PublicKey) a = isinstance(pk1, Ed25519PublicKey)
b = isinstance(pk2, Ed25519PublicKey) b = isinstance(pk2, Ed25519PublicKey)
if a or b: if a or b:
if not a or not b: if not a or not b:
return False return False
a = pk1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) a_bytes = pk1.public_bytes(
b = pk2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) serialization.Encoding.Raw, serialization.PublicFormat.Raw
return a == b )
b_bytes = pk2.public_bytes(
serialization.Encoding.Raw, serialization.PublicFormat.Raw
)
return a_bytes == b_bytes
else: else:
return pk1.public_numbers() == pk2.public_numbers() return pk1.public_numbers() == pk2.public_numbers() # type: ignore
def compare_encryption_algorithms(ea1, ea2): def compare_encryption_algorithms(
ea1: serialization.KeySerializationEncryption,
ea2: serialization.KeySerializationEncryption,
) -> bool:
if isinstance(ea1, serialization.NoEncryption) and isinstance( if isinstance(ea1, serialization.NoEncryption) and isinstance(
ea2, serialization.NoEncryption ea2, serialization.NoEncryption
): ):
@@ -654,19 +725,21 @@ def compare_encryption_algorithms(ea1, ea2):
return False return False
def get_encryption_algorithm(passphrase): def get_encryption_algorithm(
passphrase: bytes,
) -> serialization.KeySerializationEncryption:
try: try:
return serialization.BestAvailableEncryption(passphrase) return serialization.BestAvailableEncryption(passphrase)
except ValueError as e: except ValueError as e:
raise InvalidPassphraseError(e) raise InvalidPassphraseError(e)
def validate_comment(comment): def validate_comment(comment: str) -> None:
if not hasattr(comment, "encode"): if not hasattr(comment, "encode"):
raise InvalidCommentError(f"{comment} cannot be encoded to text") raise InvalidCommentError(f"{comment} cannot be encoded to text")
def extract_comment(path): def extract_comment(path: str | os.PathLike) -> str:
if not os.path.exists(path): if not os.path.exists(path):
raise InvalidPublicKeyFileError(f"No file was found at {path}") raise InvalidPublicKeyFileError(f"No file was found at {path}")
@@ -684,7 +757,7 @@ def extract_comment(path):
return comment return comment
def calculate_fingerprint(openssh_publickey): def calculate_fingerprint(openssh_publickey: bytes) -> str:
digest = hashes.Hash(hashes.SHA256()) digest = hashes.Hash(hashes.SHA256())
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1]) decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
digest.update(decoded_pubkey) digest.update(decoded_pubkey)

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import os import os
import re import re
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from struct import Struct from struct import Struct
@@ -38,17 +39,20 @@ _UINT64 = Struct(b"!Q")
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF _UINT64_MAX = 0xFFFFFFFFFFFFFFFF
def any_in(sequence, *elements): _T = t.TypeVar("_T")
def any_in(sequence: t.Iterable[_T], *elements: _T) -> bool:
return any(e in sequence for e in elements) return any(e in sequence for e in elements)
def file_mode(path): def file_mode(path: str | os.PathLike) -> int:
if not os.path.exists(path): if not os.path.exists(path):
return 0o000 return 0o000
return os.stat(path).st_mode & 0o777 return os.stat(path).st_mode & 0o777
def parse_openssh_version(version_string): def parse_openssh_version(version_string: str) -> str | None:
"""Parse the version output of ssh -V and return version numbers that can be compared""" """Parse the version output of ssh -V and return version numbers that can be compared"""
parsed_result = re.match( parsed_result = re.match(
@@ -63,7 +67,7 @@ def parse_openssh_version(version_string):
@contextmanager @contextmanager
def secure_open(path, mode): def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode) fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
try: try:
yield fd yield fd
@@ -71,7 +75,7 @@ def secure_open(path, mode):
os.close(fd) os.close(fd)
def secure_write(path, mode, content): def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
with secure_open(path, mode) as fd: with secure_open(path, mode) as fd:
os.write(fd, content) os.write(fd, content)
@@ -84,35 +88,35 @@ class OpensshParser:
UINT32_OFFSET = 4 UINT32_OFFSET = 4
UINT64_OFFSET = 8 UINT64_OFFSET = 8
def __init__(self, data): def __init__(self, data: bytes | bytearray) -> None:
if not isinstance(data, (bytes, bytearray)): if not isinstance(data, (bytes, bytearray)):
raise TypeError(f"Data must be bytes-like not {type(data)}") raise TypeError(f"Data must be bytes-like not {type(data)}")
self._data = memoryview(data) self._data = memoryview(data)
self._pos = 0 self._pos = 0
def boolean(self): def boolean(self) -> bool:
next_pos = self._check_position(self.BOOLEAN_OFFSET) next_pos = self._check_position(self.BOOLEAN_OFFSET)
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0] value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos self._pos = next_pos
return value return value
def uint32(self): def uint32(self) -> int:
next_pos = self._check_position(self.UINT32_OFFSET) next_pos = self._check_position(self.UINT32_OFFSET)
value = _UINT32.unpack(self._data[self._pos : next_pos])[0] value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos self._pos = next_pos
return value return value
def uint64(self): def uint64(self) -> int:
next_pos = self._check_position(self.UINT64_OFFSET) next_pos = self._check_position(self.UINT64_OFFSET)
value = _UINT64.unpack(self._data[self._pos : next_pos])[0] value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos self._pos = next_pos
return value return value
def string(self): def string(self) -> bytes:
length = self.uint32() length = self.uint32()
next_pos = self._check_position(length) next_pos = self._check_position(length)
@@ -122,15 +126,15 @@ class OpensshParser:
# Cast to bytes is required as a memoryview slice is itself a memoryview # Cast to bytes is required as a memoryview slice is itself a memoryview
return bytes(value) return bytes(value)
def mpint(self): def mpint(self) -> int:
return self._big_int(self.string(), "big", signed=True) return self._big_int(self.string(), "big", signed=True)
def name_list(self): def name_list(self) -> list[str]:
raw_string = self.string() raw_string = self.string()
return raw_string.decode("ASCII").split(",") return raw_string.decode("ASCII").split(",")
# Convenience function, but not an official data type from SSH # Convenience function, but not an official data type from SSH
def string_list(self): def string_list(self) -> list[bytes]:
result = [] result = []
raw_string = self.string() raw_string = self.string()
@@ -142,7 +146,7 @@ class OpensshParser:
return result return result
# Convenience function, but not an official data type from SSH # Convenience function, but not an official data type from SSH
def option_list(self): def option_list(self) -> list[tuple[bytes, bytes]]:
result = [] result = []
raw_string = self.string() raw_string = self.string()
@@ -159,15 +163,15 @@ class OpensshParser:
return result return result
def seek(self, offset): def seek(self, offset: int) -> int:
self._pos = self._check_position(offset) self._pos = self._check_position(offset)
return self._pos return self._pos
def remaining_bytes(self): def remaining_bytes(self) -> int:
return len(self._data) - self._pos return len(self._data) - self._pos
def _check_position(self, offset): def _check_position(self, offset: int) -> int:
if self._pos + offset > len(self._data): if self._pos + offset > len(self._data):
raise ValueError(f"Insufficient data remaining at position: {self._pos}") raise ValueError(f"Insufficient data remaining at position: {self._pos}")
elif self._pos + offset < 0: elif self._pos + offset < 0:
@@ -176,8 +180,8 @@ class OpensshParser:
return self._pos + offset return self._pos + offset
@classmethod @classmethod
def signature_data(cls, signature_string): def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
signature_data = {} signature_data: dict[str, bytes | int] = {}
parser = cls(signature_string) parser = cls(signature_string)
signature_type = parser.string() signature_type = parser.string()
@@ -205,14 +209,19 @@ class OpensshParser:
signature_data["R"] = cls._big_int(signature_blob[:32], "little") signature_data["R"] = cls._big_int(signature_blob[:32], "little")
signature_data["S"] = cls._big_int(signature_blob[32:], "little") signature_data["S"] = cls._big_int(signature_blob[32:], "little")
else: else:
raise ValueError(f"{signature_type} is not a valid signature type") raise ValueError(f"{signature_type!r} is not a valid signature type")
signature_data["signature_type"] = signature_type signature_data["signature_type"] = signature_type
return signature_data return signature_data
@classmethod @classmethod
def _big_int(cls, raw_string, byte_order, signed=False): def _big_int(
cls,
raw_string: bytes,
byte_order: t.Literal["big", "little"],
signed: bool = False,
) -> int:
if byte_order not in ("big", "little"): if byte_order not in ("big", "little"):
raise ValueError( raise ValueError(
f"Byte_order must be one of (big, little) not {byte_order}" f"Byte_order must be one of (big, little) not {byte_order}"
@@ -230,18 +239,16 @@ class _OpensshWriter:
in validating parsed material. in validating parsed material.
""" """
def __init__(self, buffer=None): def __init__(self, buffer: bytearray | None = None):
if buffer is not None: if buffer is not None:
if not isinstance(buffer, (bytes, bytearray)): if not isinstance(buffer, bytearray):
raise TypeError( raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
f"Buffer must be a bytes-like object not {type(buffer)}"
)
else: else:
buffer = bytearray() buffer = bytearray()
self._buff = buffer self._buff: bytearray = buffer
def boolean(self, value): def boolean(self, value: bool) -> t.Self:
if not isinstance(value, bool): if not isinstance(value, bool):
raise TypeError(f"Value must be of type bool not {type(value)}") raise TypeError(f"Value must be of type bool not {type(value)}")
@@ -249,7 +256,7 @@ class _OpensshWriter:
return self return self
def uint32(self, value): def uint32(self, value: int) -> t.Self:
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}") raise TypeError(f"Value must be of type int not {type(value)}")
if value < 0 or value > _UINT32_MAX: if value < 0 or value > _UINT32_MAX:
@@ -261,7 +268,7 @@ class _OpensshWriter:
return self return self
def uint64(self, value): def uint64(self, value: int) -> t.Self:
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}") raise TypeError(f"Value must be of type int not {type(value)}")
if value < 0 or value > _UINT64_MAX: if value < 0 or value > _UINT64_MAX:
@@ -273,7 +280,7 @@ class _OpensshWriter:
return self return self
def string(self, value): def string(self, value: bytes | bytearray) -> t.Self:
if not isinstance(value, (bytes, bytearray)): if not isinstance(value, (bytes, bytearray)):
raise TypeError(f"Value must be bytes-like not {type(value)}") raise TypeError(f"Value must be bytes-like not {type(value)}")
self.uint32(len(value)) self.uint32(len(value))
@@ -281,7 +288,7 @@ class _OpensshWriter:
return self return self
def mpint(self, value): def mpint(self, value: int) -> t.Self:
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError(f"Value must be of type int not {type(value)}") raise TypeError(f"Value must be of type int not {type(value)}")
@@ -289,7 +296,7 @@ class _OpensshWriter:
return self return self
def name_list(self, value): def name_list(self, value: list[str]) -> t.Self:
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError(f"Value must be a list of byte strings not {type(value)}") raise TypeError(f"Value must be a list of byte strings not {type(value)}")
@@ -300,7 +307,7 @@ class _OpensshWriter:
return self return self
def string_list(self, value): def string_list(self, value: list[bytes]) -> t.Self:
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError(f"Value must be a list of byte string not {type(value)}") raise TypeError(f"Value must be a list of byte string not {type(value)}")
@@ -312,7 +319,7 @@ class _OpensshWriter:
return self return self
def option_list(self, value): def option_list(self, value: list[tuple[bytes, bytes]]) -> t.Self:
if not isinstance(value, list) or (value and not isinstance(value[0], tuple)): if not isinstance(value, list) or (value and not isinstance(value[0], tuple)):
raise TypeError("Value must be a list of tuples") raise TypeError("Value must be a list of tuples")
@@ -327,7 +334,7 @@ class _OpensshWriter:
return self return self
@staticmethod @staticmethod
def _int_to_mpint(num): def _int_to_mpint(num: int) -> bytes:
byte_length = (num.bit_length() + 7) // 8 byte_length = (num.bit_length() + 7) // 8
try: try:
return num.to_bytes(byte_length, "big", signed=True) return num.to_bytes(byte_length, "big", signed=True)
@@ -335,5 +342,5 @@ class _OpensshWriter:
except OverflowError: except OverflowError:
return num.to_bytes(byte_length + 1, "big", signed=True) return num.to_bytes(byte_length + 1, "big", signed=True)
def bytes(self): def bytes(self) -> bytes:
return bytes(self._buff) return bytes(self._buff)

View File

@@ -10,7 +10,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor
) )
def th(number): def th(number: int) -> str:
abs_number = abs(number) abs_number = abs(number)
mod_10 = abs_number % 10 mod_10 = abs_number % 10
mod_100 = abs_number % 100 mod_100 = abs_number % 100
@@ -24,13 +24,13 @@ def th(number):
return "th" return "th"
def parse_serial(value): def parse_serial(value: str | bytes) -> int:
""" """
Given a colon-separated string of hexadecimal byte values, converts it to an integer. Given a colon-separated string of hexadecimal byte values, converts it to an integer.
""" """
value = to_native(value) value_str = to_native(value)
result = 0 result = 0
for i, part in enumerate(value.split(":")): for i, part in enumerate(value_str.split(":")):
try: try:
part_value = int(part, 16) part_value = int(part, 16)
if part_value < 0 or part_value > 255: if part_value < 0 or part_value > 255:
@@ -43,11 +43,11 @@ def parse_serial(value):
return result return result
def to_serial(value): def to_serial(value: int) -> str:
""" """
Given an integer, converts its absolute value to a colon-separated string of hexadecimal byte values. Given an integer, converts its absolute value to a colon-separated string of hexadecimal byte values.
""" """
value = convert_int_to_hex(value).upper() value_str = convert_int_to_hex(value).upper()
if len(value) % 2 != 0: if len(value_str) % 2 != 0:
value = "0" + value value_str = f"0{value_str}"
return ":".join(value[i : i + 2] for i in range(0, len(value), 2)) return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2))

View File

@@ -13,37 +13,16 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.basic impo
) )
try: UTC = datetime.timezone.utc
UTC = datetime.timezone.utc
except AttributeError:
_DURATION_ZERO = datetime.timedelta(0)
class _UTCClass(datetime.tzinfo):
def utcoffset(self, dt):
return _DURATION_ZERO
def dst(self, dt):
return _DURATION_ZERO
def tzname(self, dt):
return "UTC"
def fromutc(self, dt):
return dt
def __repr__(self):
return "UTC"
UTC = _UTCClass()
def get_now_datetime(with_timezone): def get_now_datetime(with_timezone: bool) -> datetime.datetime:
if with_timezone: if with_timezone:
return datetime.datetime.now(tz=UTC) return datetime.datetime.now(tz=UTC)
return datetime.datetime.utcnow() return datetime.datetime.utcnow()
def ensure_utc_timezone(timestamp): def ensure_utc_timezone(timestamp: datetime.datetime) -> datetime.datetime:
if timestamp.tzinfo is UTC: if timestamp.tzinfo is UTC:
return timestamp return timestamp
if timestamp.tzinfo is None: if timestamp.tzinfo is None:
@@ -52,7 +31,7 @@ def ensure_utc_timezone(timestamp):
return timestamp.astimezone(UTC) return timestamp.astimezone(UTC)
def remove_timezone(timestamp): def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime:
# Convert to native datetime object # Convert to native datetime object
if timestamp.tzinfo is None: if timestamp.tzinfo is None:
return timestamp return timestamp
@@ -61,26 +40,34 @@ def remove_timezone(timestamp):
return timestamp.replace(tzinfo=None) return timestamp.replace(tzinfo=None)
def add_or_remove_timezone(timestamp, with_timezone): def add_or_remove_timezone(
timestamp: datetime.datetime, with_timezone: bool
) -> datetime.datetime:
return ( return (
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp) ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
) )
def get_epoch_seconds(timestamp): def get_epoch_seconds(timestamp: datetime.datetime) -> float:
if timestamp.tzinfo is None: if timestamp.tzinfo is None:
# timestamp.timestamp() is offset by the local timezone if timestamp has no timezone # timestamp.timestamp() is offset by the local timezone if timestamp has no timezone
timestamp = ensure_utc_timezone(timestamp) timestamp = ensure_utc_timezone(timestamp)
return timestamp.timestamp() return timestamp.timestamp()
def from_epoch_seconds(timestamp, with_timezone): def from_epoch_seconds(
timestamp: int | float, with_timezone: bool
) -> datetime.datetime:
if with_timezone: if with_timezone:
return datetime.datetime.fromtimestamp(timestamp, UTC) return datetime.datetime.fromtimestamp(timestamp, UTC)
return datetime.datetime.utcfromtimestamp(timestamp) return datetime.datetime.utcfromtimestamp(timestamp)
def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=None): def convert_relative_to_datetime(
relative_time_string: str,
with_timezone: bool = False,
now: datetime.datetime | None = None,
) -> datetime.datetime | None:
"""Get a datetime.datetime or None from a string in the time format described in sshd_config(5)""" """Get a datetime.datetime or None from a string in the time format described in sshd_config(5)"""
parsed_result = re.match( parsed_result = re.match(
@@ -115,7 +102,12 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
return now - offset return now - offset
def get_relative_time_option(input_string, input_name, with_timezone=False, now=None): def get_relative_time_option(
input_string: str,
input_name: str,
with_timezone: bool = False,
now: datetime.datetime | None = None,
) -> datetime.datetime:
""" """
Return an absolute timespec if a relative timespec or an ASN1 formatted Return an absolute timespec if a relative timespec or an ASN1 formatted
string is provided. string is provided.
@@ -129,9 +121,12 @@ def get_relative_time_option(input_string, input_name, with_timezone=False, now=
) )
# Relative time # Relative time
if result.startswith("+") or result.startswith("-"): if result.startswith("+") or result.startswith("-"):
return convert_relative_to_datetime( res = convert_relative_to_datetime(result, with_timezone=with_timezone, now=now)
result, with_timezone=with_timezone, now=now if res is None:
) raise OpenSSLObjectError(
f'The timespec "{input_string}" for {input_name} is invalid'
)
return res
# Absolute time # Absolute time
for date_fmt, length in [ for date_fmt, length in [
( (

View File

@@ -165,6 +165,7 @@ account_uri:
""" """
import base64 import base64
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount, ACMEAccount,
@@ -180,7 +181,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec() argument_spec = create_default_argspec()
argument_spec.update_argspec( argument_spec.update_argspec(
terms_agreed=dict(type="bool", default=False), terms_agreed=dict(type="bool", default=False),
@@ -204,24 +205,24 @@ def main():
), ),
) )
argument_spec.update( argument_spec.update(
mutually_exclusive=(["new_account_key_src", "new_account_key_content"],), mutually_exclusive=[("new_account_key_src", "new_account_key_content")],
required_if=( required_if=[
# Make sure that for state == changed_key, one of # Make sure that for state == changed_key, one of
# new_account_key_src and new_account_key_content are specified # new_account_key_src and new_account_key_content are specified
[ (
"state", "state",
"changed_key", "changed_key",
["new_account_key_src", "new_account_key_content"], ["new_account_key_src", "new_account_key_content"],
True, True,
], ),
), ],
) )
module = argument_spec.create_ansible_module(supports_check_mode=True) module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True) backend = create_backend(module, True)
if module.params["external_account_binding"]: if module.params["external_account_binding"]:
# Make sure padding is there # Make sure padding is there
key = module.params["external_account_binding"]["key"] key: str = module.params["external_account_binding"]["key"]
if len(key) % 4 != 0: if len(key) % 4 != 0:
key = key + ("=" * (4 - (len(key) % 4))) key = key + ("=" * (4 - (len(key) % 4)))
# Make sure key is Base64 encoded # Make sure key is Base64 encoded
@@ -237,24 +238,25 @@ def main():
client = ACMEClient(module, backend) client = ACMEClient(module, backend)
account = ACMEAccount(client) account = ACMEAccount(client)
changed = False changed = False
state = module.params.get("state") state: t.Literal["present", "absent", "changed_key"] = module.params["state"]
diff_before = {} diff_before: dict[str, t.Any] = {}
diff_after = {} diff_after: dict[str, t.Any] = {}
if state == "absent": if state == "absent":
created, account_data = account.setup_account(allow_creation=False) created, account_data = account.setup_account(allow_creation=False)
if account_data: if account_data:
diff_before = dict(account_data) diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"] if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
if created: if created:
raise AssertionError("Unwanted account creation") raise AssertionError("Unwanted account creation")
if account_data is not None: if account_data is not None:
# Account is not yet deactivated # Account is not yet deactivated
if not module.check_mode: if not module.check_mode:
# Deactivate it # Deactivate it
payload = {"status": "deactivated"} deactivate_payload = {"status": "deactivated"}
result, info = client.send_signed_request( result, info = client.send_signed_request(
client.account_uri, t.cast(str, client.account_uri),
payload, deactivate_payload,
error_msg="Failed to deactivate account", error_msg="Failed to deactivate account",
expected_status_codes=[200], expected_status_codes=[200],
) )
@@ -278,13 +280,15 @@ def main():
diff_before = {} diff_before = {}
else: else:
diff_before = dict(account_data) diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"] if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
updated = False updated = False
if not created: if not created:
updated, account_data = account.update_account(account_data, contact) updated, account_data = account.update_account(account_data, contact)
changed = created or updated changed = created or updated
diff_after = dict(account_data) diff_after = dict(account_data)
diff_after["public_account_key"] = client.account_key_data["jwk"] if client.account_key_data:
diff_after["public_account_key"] = client.account_key_data["jwk"]
elif state == "changed_key": elif state == "changed_key":
# Parse new account key # Parse new account key
try: try:
@@ -306,7 +310,8 @@ def main():
msg="Account does not exist or is deactivated." msg="Account does not exist or is deactivated."
) )
diff_before = dict(account_data) diff_before = dict(account_data)
diff_before["public_account_key"] = client.account_key_data["jwk"] if client.account_key_data:
diff_before["public_account_key"] = client.account_key_data["jwk"]
# Now we can start the account key rollover # Now we can start the account key rollover
if not module.check_mode: if not module.check_mode:
# Compose inner signed message # Compose inner signed message
@@ -317,12 +322,12 @@ def main():
"jwk": new_key_data["jwk"], "jwk": new_key_data["jwk"],
"url": url, "url": url,
} }
payload = { change_key_payload = {
"account": client.account_uri, "account": client.account_uri,
"newKey": new_key_data["jwk"], # specified in draft 12 and older "newKey": new_key_data["jwk"], # specified in draft 12 and older
"oldKey": client.account_jwk, # specified in draft 13 and newer "oldKey": client.account_jwk, # specified in draft 13 and newer
} }
data = client.sign_request(protected, payload, new_key_data) data = client.sign_request(protected, change_key_payload, new_key_data)
# Send request and verify result # Send request and verify result
result, info = client.send_signed_request( result, info = client.send_signed_request(
url, url,
@@ -332,8 +337,9 @@ def main():
) )
if module._diff: if module._diff:
client.account_key_data = new_key_data client.account_key_data = new_key_data
client.account_jws_header["alg"] = new_key_data["alg"] if client.account_jws_header:
diff_after = account.get_account_data() client.account_jws_header["alg"] = new_key_data["alg"]
diff_after = account.get_account_data() or {}
elif module._diff: elif module._diff:
# Kind of fake diff_after # Kind of fake diff_after
diff_after = dict(diff_before) diff_after = dict(diff_before)

View File

@@ -204,6 +204,8 @@ order_uris:
version_added: 1.5.0 version_added: 1.5.0
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount, ACMEAccount,
) )
@@ -220,48 +222,55 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
) )
def get_orders_list(module, client, orders_url): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def get_orders_list(
module: AnsibleModule, client: ACMEClient, orders_url: str
) -> list[str]:
""" """
Retrieves orders list (handles pagination). Retrieves order URL list (handles pagination).
""" """
orders = [] orders: list[str] = []
while orders_url: next_orders_url: str | None = orders_url
while next_orders_url:
# Get part of orders list # Get part of orders list
res, info = client.get_request( res, info = client.get_request(
orders_url, parse_json_result=True, fail_on_error=True next_orders_url, parse_json_result=True, fail_on_error=True
) )
if not res.get("orders"): if not res.get("orders"):
if orders: if orders:
module.warn( module.warn(
f"When retrieving orders list part {orders_url}, got empty result list" f"When retrieving orders list part {next_orders_url}, got empty result list"
) )
break break
# Add order URLs to result list # Add order URLs to result list
orders.extend(res["orders"]) orders.extend(res["orders"])
# Extract URL of next part of results list # Extract URL of next part of results list
new_orders_url = [] new_orders_url: list[str | None] = []
def f(link, relation): def f(link: str, relation: str) -> None:
if relation == "next": if relation == "next":
new_orders_url.append(link) new_orders_url.append(link)
process_links(info, f) process_links(info, f)
new_orders_url.append(None) new_orders_url.append(None)
previous_orders_url, orders_url = orders_url, new_orders_url.pop(0) previous_orders_url, next_orders_url = next_orders_url, new_orders_url.pop(0)
if orders_url == previous_orders_url: if next_orders_url == previous_orders_url:
# Prevent infinite loop # Prevent infinite loop
orders_url = None next_orders_url = None
return orders return orders
def get_order(client, order_url): def get_order(client: ACMEClient, order_url: str) -> dict[str, t.Any]:
""" """
Retrieve order data. Retrieve order data.
""" """
return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0] return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0]
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec() argument_spec = create_default_argspec()
argument_spec.update_argspec( argument_spec.update_argspec(
retrieve_orders=dict( retrieve_orders=dict(
@@ -282,16 +291,19 @@ def main():
) )
if created: if created:
raise AssertionError("Unwanted account creation") raise AssertionError("Unwanted account creation")
result = { result: dict[str, t.Any] = {
"changed": False, "changed": False,
"exists": client.account_uri is not None, "exists": False,
"account_uri": client.account_uri, "account_uri": None,
} }
if client.account_uri is not None: if client.account_uri is not None and account_data:
result["account_uri"] = client.account_uri
result["exists"] = True
# Make sure promised data is there # Make sure promised data is there
if "contact" not in account_data: if "contact" not in account_data:
account_data["contact"] = [] account_data["contact"] = []
account_data["public_account_key"] = client.account_key_data["jwk"] if client.account_key_data:
account_data["public_account_key"] = client.account_key_data["jwk"]
result["account"] = account_data result["account"] = account_data
# Retrieve orders list # Retrieve orders list
if ( if (

View File

@@ -94,6 +94,8 @@ renewal_info:
sample: '2024-04-29T01:17:10.236921+00:00' sample: '2024-04-29T01:17:10.236921+00:00'
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient, ACMEClient,
create_backend, create_backend,
@@ -104,15 +106,15 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_account=False) argument_spec = create_default_argspec(with_account=False)
argument_spec.update_argspec( argument_spec.update_argspec(
certificate_path=dict(type="path"), certificate_path=dict(type="path"),
certificate_content=dict(type="str"), certificate_content=dict(type="str"),
) )
argument_spec.update( argument_spec.update(
required_one_of=(["certificate_path", "certificate_content"],), required_one_of=[("certificate_path", "certificate_content")],
mutually_exclusive=(["certificate_path", "certificate_content"],), mutually_exclusive=[("certificate_path", "certificate_content")],
) )
module = argument_spec.create_ansible_module(supports_check_mode=True) module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True) backend = create_backend(module, True)

View File

@@ -562,6 +562,7 @@ all_chains:
""" """
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount, ACMEAccount,
@@ -592,6 +593,17 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
CertificateInformation,
CryptoBackend,
)
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
)
NO_CHALLENGE = "no challenge" NO_CHALLENGE = "no challenge"
@@ -602,7 +614,7 @@ class ACMECertificateClient:
certificates. certificates.
""" """
def __init__(self, module, backend): def __init__(self, module: AnsibleModule, backend: CryptoBackend):
self.module = module self.module = module
self.version = module.params["acme_version"] self.version = module.params["acme_version"]
self.challenge = module.params["challenge"] self.challenge = module.params["challenge"]
@@ -618,9 +630,9 @@ class ACMECertificateClient:
self.account = ACMEAccount(self.client) self.account = ACMEAccount(self.client)
self.directory = self.client.directory self.directory = self.client.directory
self.data = module.params["data"] self.data = module.params["data"]
self.authorizations = None self.authorizations: dict[str, Authorization] | None = None
self.cert_days = -1 self.cert_days = -1
self.order = None self.order: Order | None = None
self.order_uri = self.data.get("order_uri") if self.data else None self.order_uri = self.data.get("order_uri") if self.data else None
self.all_chains = None self.all_chains = None
self.select_chain_matcher = [] self.select_chain_matcher = []
@@ -662,7 +674,6 @@ class ACMECertificateClient:
contact.append("mailto:" + module.params["account_email"]) contact.append("mailto:" + module.params["account_email"])
created, account_data = self.account.setup_account( created, account_data = self.account.setup_account(
contact, contact,
agreement=module.params.get("agreement"),
terms_agreed=module.params.get("terms_agreed"), terms_agreed=module.params.get("terms_agreed"),
allow_creation=modify_account, allow_creation=modify_account,
) )
@@ -681,7 +692,7 @@ class ACMECertificateClient:
csr_filename=self.csr, csr_content=self.csr_content csr_filename=self.csr, csr_content=self.csr_content
) )
def is_first_step(self): def is_first_step(self) -> bool:
""" """
Return True if this is the first execution of this module, i.e. if a Return True if this is the first execution of this module, i.e. if a
sufficient data object from a first run has not been provided. sufficient data object from a first run has not been provided.
@@ -692,7 +703,7 @@ class ACMECertificateClient:
# stored in self.order_uri by the constructor). # stored in self.order_uri by the constructor).
return self.order_uri is None return self.order_uri is None
def _get_cert_info_or_none(self): def _get_cert_info_or_none(self) -> CertificateInformation | None:
if self.module.params.get("dest"): if self.module.params.get("dest"):
filename = self.module.params["dest"] filename = self.module.params["dest"]
else: else:
@@ -701,7 +712,7 @@ class ACMECertificateClient:
return None return None
return self.client.backend.get_cert_information(cert_filename=filename) return self.client.backend.get_cert_information(cert_filename=filename)
def start_challenges(self): def start_challenges(self) -> None:
""" """
Create new authorizations for all identifiers of the CSR, Create new authorizations for all identifiers of the CSR,
respectively start a new order for ACME v2. respectively start a new order for ACME v2.
@@ -733,13 +744,16 @@ class ACMECertificateClient:
self.authorizations.update(self.order.authorizations) self.authorizations.update(self.order.authorizations)
self.changed = True self.changed = True
def get_challenges_data(self, first_step): def get_challenges_data(
self, first_step: bool
) -> tuple[dict[str, t.Any], dict[str, list[str]]]:
""" """
Get challenge details for the chosen challenge type. Get challenge details for the chosen challenge type.
Return a tuple of generic challenge details, and specialized DNS challenge details. Return a tuple of generic challenge details, and specialized DNS challenge details.
""" """
# Get general challenge data assert self.authorizations is not None
data = {} data: dict[str, t.Any] = {}
data_dns: dict[str, list[str]] = {}
for type_identifier, authz in self.authorizations.items(): for type_identifier, authz in self.authorizations.items():
identifier_type, identifier = split_identifier(type_identifier) identifier_type, identifier = split_identifier(type_identifier)
# Skip valid authentications: their challenges are already valid # Skip valid authentications: their challenges are already valid
@@ -747,7 +761,9 @@ class ACMECertificateClient:
if authz.status == "valid": if authz.status == "valid":
continue continue
# We drop the type from the key to preserve backwards compatibility # We drop the type from the key to preserve backwards compatibility
data[authz.identifier] = authz.get_challenge_data(self.client) challenges = authz.get_challenge_data(self.client)
assert authz.identifier is not None
data[authz.identifier] = challenges
if ( if (
first_step first_step
and self.challenge is not None and self.challenge is not None
@@ -756,10 +772,7 @@ class ACMECertificateClient:
raise ModuleFailException( raise ModuleFailException(
f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!" f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!"
) )
# Get DNS challenge data if self.challenge == "dns-01":
data_dns = {}
if self.challenge == "dns-01":
for identifier, challenges in data.items():
if self.challenge in challenges: if self.challenge in challenges:
values = data_dns.get(challenges[self.challenge]["record"]) values = data_dns.get(challenges[self.challenge]["record"])
if values is None: if values is None:
@@ -768,7 +781,7 @@ class ACMECertificateClient:
values.append(challenges[self.challenge]["resource_value"]) values.append(challenges[self.challenge]["resource_value"])
return data, data_dns return data, data_dns
def finish_challenges(self): def finish_challenges(self) -> None:
""" """
Verify challenges for all identifiers of the CSR. Verify challenges for all identifiers of the CSR.
""" """
@@ -777,6 +790,7 @@ class ACMECertificateClient:
# Step 1: obtain challenge information # Step 1: obtain challenge information
# For ACME v2, we obtain the order object by fetching the # For ACME v2, we obtain the order object by fetching the
# order URI, and extract the information from there. # order URI, and extract the information from there.
assert self.order_uri is not None
self.order = Order.from_url(self.client, self.order_uri) self.order = Order.from_url(self.client, self.order_uri)
self.order.load_authorizations(self.client) self.order.load_authorizations(self.client)
self.authorizations.update(self.order.authorizations) self.authorizations.update(self.order.authorizations)
@@ -799,7 +813,9 @@ class ACMECertificateClient:
# Step 3: wait for authzs to validate # Step 3: wait for authzs to validate
wait_for_validation(authzs_to_wait_for, self.client) wait_for_validation(authzs_to_wait_for, self.client)
def download_alternate_chains(self, cert): def download_alternate_chains(
self, cert: CertificateChain
) -> list[CertificateChain]:
alternate_chains = [] alternate_chains = []
for alternate in cert.alternates: for alternate in cert.alternates:
try: try:
@@ -812,7 +828,9 @@ class ACMECertificateClient:
alternate_chains.append(alt_cert) alternate_chains.append(alt_cert)
return alternate_chains return alternate_chains
def find_matching_chain(self, chains): def find_matching_chain(
self, chains: t.Iterable[CertificateChain]
) -> CertificateChain | None:
for criterium_idx, matcher in enumerate(self.select_chain_matcher): for criterium_idx, matcher in enumerate(self.select_chain_matcher):
for chain in chains: for chain in chains:
if matcher.match(chain): if matcher.match(chain):
@@ -822,12 +840,13 @@ class ACMECertificateClient:
return chain return chain
return None return None
def get_certificate(self): def get_certificate(self) -> None:
""" """
Request a new certificate and write it to the destination file. Request a new certificate and write it to the destination file.
First verifies whether all authorizations are valid; if not, aborts First verifies whether all authorizations are valid; if not, aborts
with an error. with an error.
""" """
assert self.authorizations is not None
for identifier_type, identifier in self.identifiers: for identifier_type, identifier in self.identifiers:
authz = self.authorizations.get( authz = self.authorizations.get(
normalize_combined_identifier( normalize_combined_identifier(
@@ -844,7 +863,9 @@ class ACMECertificateClient:
module=self.module, module=self.module,
) )
assert self.order is not None
self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content)) self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
assert self.order.certificate_uri is not None
cert = CertificateChain.download(self.client, self.order.certificate_uri) cert = CertificateChain.download(self.client, self.order.certificate_uri)
if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher: if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher:
# Retrieve alternate chains # Retrieve alternate chains
@@ -887,12 +908,13 @@ class ACMECertificateClient:
): ):
self.changed = True self.changed = True
def deactivate_authzs(self): def deactivate_authzs(self) -> None:
""" """
Deactivates all valid authz's. Does not raise exceptions. Deactivates all valid authz's. Does not raise exceptions.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2 https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2 https://tools.ietf.org/html/rfc8555#section-7.5.2
""" """
assert self.authorizations is not None
for authz in self.authorizations.values(): for authz in self.authorizations.values():
try: try:
authz.deactivate(self.client) authz.deactivate(self.client)
@@ -905,7 +927,7 @@ class ACMECertificateClient:
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True) argument_spec = create_default_argspec(with_certificate=True)
argument_spec.argument_spec["csr"]["aliases"] = ["src"] argument_spec.argument_spec["csr"]["aliases"] = ["src"]
argument_spec.update_argspec( argument_spec.update_argspec(
@@ -981,7 +1003,7 @@ def main():
else: else:
client = ACMECertificateClient(module, backend) client = ACMECertificateClient(module, backend)
client.cert_days = cert_days client.cert_days = cert_days
other = dict() other: dict[str, t.Any] = {}
is_first_step = client.is_first_step() is_first_step = client.is_first_step()
if is_first_step: if is_first_step:
# First run: start challenges / start new order # First run: start challenges / start new order
@@ -998,6 +1020,7 @@ def main():
client.deactivate_authzs() client.deactivate_authzs()
data, data_dns = client.get_challenges_data(first_step=is_first_step) data, data_dns = client.get_challenges_data(first_step=is_first_step)
auths = dict() auths = dict()
assert client.authorizations is not None
for k, v in client.authorizations.items(): for k, v in client.authorizations.items():
# Remove "type:" from key # Remove "type:" from key
auths[v.identifier] = v.to_json() auths[v.identifier] = v.to_json()

View File

@@ -51,6 +51,8 @@ EXAMPLES = r"""
RETURN = """#""" RETURN = """#"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount, ACMEAccount,
) )
@@ -65,7 +67,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec() argument_spec = create_default_argspec()
argument_spec.update_argspec( argument_spec.update_argspec(
order_uri=dict(type="str", required=True), order_uri=dict(type="str", required=True),

View File

@@ -371,6 +371,8 @@ account_uri:
type: str type: str
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend, create_backend,
create_default_argspec, create_default_argspec,
@@ -383,7 +385,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True) argument_spec = create_default_argspec(with_certificate=True)
argument_spec.update_argspec( argument_spec.update_argspec(
deactivate_authzs=dict(type="bool", default=True), deactivate_authzs=dict(type="bool", default=True),

View File

@@ -317,6 +317,8 @@ selected_chain:
returned: always returned: always
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend, create_backend,
create_default_argspec, create_default_argspec,
@@ -329,7 +331,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.certificates import (
CertificateChain,
)
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=True) argument_spec = create_default_argspec(with_certificate=True)
argument_spec.update_argspec( argument_spec.update_argspec(
order_uri=dict(type="str", required=True), order_uri=dict(type="str", required=True),
@@ -375,6 +383,7 @@ def main():
or module.params["retrieve_all_alternates"] or module.params["retrieve_all_alternates"]
) )
changed = False changed = False
alternate_chains: list[CertificateChain] | None
if order.status == "valid": if order.status == "valid":
# Step 2 and 3: download certificate(s) and chain(s) # Step 2 and 3: download certificate(s) and chain(s)
cert, alternate_chains = client.download_certificate( cert, alternate_chains = client.download_certificate(

View File

@@ -357,6 +357,8 @@ authorizations_by_status:
returned: always returned: always
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend, create_backend,
create_default_argspec, create_default_argspec,
@@ -369,7 +371,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=False) argument_spec = create_default_argspec(with_certificate=False)
argument_spec.update_argspec( argument_spec.update_argspec(
order_uri=dict(type="str", required=True), order_uri=dict(type="str", required=True),
@@ -381,8 +383,8 @@ def main():
try: try:
client = ACMECertificateClient(module, backend) client = ACMECertificateClient(module, backend)
order = client.load_order() order = client.load_order()
authorizations_by_identifier = dict() authorizations_by_identifier: dict[str, dict[str, t.Any]] = {}
authorizations_by_status = { authorizations_by_status: dict[str, list[str]] = {
"pending": [], "pending": [],
"invalid": [], "invalid": [],
"valid": [], "valid": [],
@@ -392,7 +394,8 @@ def main():
} }
for identifier, authz in order.authorizations.items(): for identifier, authz in order.authorizations.items():
authorizations_by_identifier[identifier] = authz.to_json() authorizations_by_identifier[identifier] = authz.to_json()
authorizations_by_status[authz.status].append(identifier) if authz.status is not None:
authorizations_by_status[authz.status].append(identifier)
module.exit_json( module.exit_json(
changed=False, changed=False,
account_uri=client.client.account_uri, account_uri=client.client.account_uri,

View File

@@ -229,6 +229,8 @@ validating_challenges:
returned: always returned: always
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
create_backend, create_backend,
create_default_argspec, create_default_argspec,
@@ -241,7 +243,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
Authorization,
)
def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_certificate=False) argument_spec = create_default_argspec(with_certificate=False)
argument_spec.update_argspec( argument_spec.update_argspec(
order_uri=dict(type="str", required=True), order_uri=dict(type="str", required=True),
@@ -271,10 +279,12 @@ def main():
missing_challenge_authzs = [k for k, v in challenges.items() if v is None] missing_challenge_authzs = [k for k, v in challenges.items() if v is None]
if missing_challenge_authzs: if missing_challenge_authzs:
missing_challenge_authzs = ", ".join(sorted(missing_challenge_authzs)) missing_challenge_authzs_str = ", ".join(
sorted(missing_challenge_authzs)
)
raise ModuleFailException( raise ModuleFailException(
"The challenge parameter must be supplied if there are pending authorizations." "The challenge parameter must be supplied if there are pending authorizations."
f" The following authorizations are pending: {missing_challenge_authzs}" f" The following authorizations are pending: {missing_challenge_authzs_str}"
) )
bad_challenge_authzs = [ bad_challenge_authzs = [
@@ -293,11 +303,13 @@ def main():
f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}" f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}"
) )
def is_pending(authz: Authorization) -> bool:
challenge_name = challenges[authz.combined_identifier]
challenge_obj = authz.find_challenge(challenge_name)
return challenge_obj is not None and challenge_obj.status == "pending"
really_pending_authzs = [ really_pending_authzs = [
authz authz for authz in pending_authzs if is_pending(authz)
for authz in pending_authzs
if authz.find_challenge(challenges[authz.combined_identifier]).status
== "pending"
] ]
# Step 4: validate pending authorizations # Step 4: validate pending authorizations
@@ -320,7 +332,7 @@ def main():
identifier_type=authz.identifier_type, identifier_type=authz.identifier_type,
authz_url=authz.url, authz_url=authz.url,
challenge_type=challenge_type, challenge_type=challenge_type,
challenge_url=challenge.url, challenge_url=challenge.url if challenge else None,
) )
for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for
], ],

View File

@@ -160,6 +160,7 @@ cert_id:
import os import os
import random import random
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient, ACMEClient,
@@ -175,7 +176,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(with_account=False) argument_spec = create_default_argspec(with_account=False)
argument_spec.update_argspec( argument_spec.update_argspec(
certificate_path=dict(type="path"), certificate_path=dict(type="path"),
@@ -190,7 +191,7 @@ def main():
treat_parsing_error_as_non_existing=dict(type="bool", default=False), treat_parsing_error_as_non_existing=dict(type="bool", default=False),
) )
argument_spec.update( argument_spec.update(
mutually_exclusive=(["certificate_path", "certificate_content"],), mutually_exclusive=[("certificate_path", "certificate_content")],
) )
module = argument_spec.create_ansible_module(supports_check_mode=True) module = argument_spec.create_ansible_module(supports_check_mode=True)
backend = create_backend(module, True) backend = create_backend(module, True)
@@ -203,7 +204,7 @@ def main():
supports_ari=False, supports_ari=False,
) )
def complete(should_renew, **kwargs): def complete(should_renew: bool, **kwargs) -> t.NoReturn:
result["should_renew"] = should_renew result["should_renew"] = should_renew
result.update(kwargs) result.update(kwargs)
module.exit_json(**result) module.exit_json(**result)

View File

@@ -110,6 +110,8 @@ EXAMPLES = r"""
RETURN = """#""" RETURN = """#"""
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.acme.account import ( from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
ACMEAccount, ACMEAccount,
) )
@@ -129,7 +131,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(require_account_key=False) argument_spec = create_default_argspec(require_account_key=False)
argument_spec.update_argspec( argument_spec.update_argspec(
private_key_src=dict(type="path"), private_key_src=dict(type="path"),
@@ -139,22 +141,22 @@ def main():
revoke_reason=dict(type="int"), revoke_reason=dict(type="int"),
) )
argument_spec.update( argument_spec.update(
required_one_of=( required_one_of=[
[ (
"account_key_src", "account_key_src",
"account_key_content", "account_key_content",
"private_key_src", "private_key_src",
"private_key_content", "private_key_content",
], ),
), ],
mutually_exclusive=( mutually_exclusive=[
[ (
"account_key_src", "account_key_src",
"account_key_content", "account_key_content",
"private_key_src", "private_key_src",
"private_key_content", "private_key_content",
], ),
), ],
) )
module = argument_spec.create_ansible_module() module = argument_spec.create_ansible_module()
backend = create_backend(module, False) backend = create_backend(module, False)
@@ -164,9 +166,9 @@ def main():
account = ACMEAccount(client) account = ACMEAccount(client)
# Load certificate # Load certificate
certificate = pem_to_der(module.params.get("certificate")) certificate = pem_to_der(module.params.get("certificate"))
certificate = nopad_b64(certificate) certificate_b64 = nopad_b64(certificate)
# Construct payload # Construct payload
payload = {"certificate": certificate} payload = {"certificate": certificate_b64}
if module.params.get("revoke_reason") is not None: if module.params.get("revoke_reason") is not None:
payload["reason"] = module.params.get("revoke_reason") payload["reason"] = module.params.get("revoke_reason")
endpoint = client.directory["revokeCert"] endpoint = client.directory["revokeCert"]

View File

@@ -149,6 +149,7 @@ regular_certificate:
import base64 import base64
import datetime import datetime
import ipaddress import ipaddress
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -173,10 +174,13 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
try: try:
import cryptography import cryptography
import cryptography.hazmat.backends import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.dh
import cryptography.hazmat.primitives.asymmetric.ec import cryptography.hazmat.primitives.asymmetric.ec
import cryptography.hazmat.primitives.asymmetric.padding import cryptography.hazmat.primitives.asymmetric.padding
import cryptography.hazmat.primitives.asymmetric.rsa import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.asymmetric.x448
import cryptography.hazmat.primitives.asymmetric.x25519
import cryptography.hazmat.primitives.hashes import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.serialization import cryptography.hazmat.primitives.serialization
import cryptography.x509 import cryptography.x509
@@ -186,7 +190,7 @@ except ImportError:
# Convert byte string to ASN1 encoded octet string # Convert byte string to ASN1 encoded octet string
def encode_octet_string(octet_string): def encode_octet_string(octet_string: bytes) -> bytes:
if len(octet_string) >= 128: if len(octet_string) >= 128:
raise ModuleFailException( raise ModuleFailException(
"Cannot handle octet strings with more than 128 bytes" "Cannot handle octet strings with more than 128 bytes"
@@ -194,7 +198,7 @@ def encode_octet_string(octet_string):
return bytes([0x4, len(octet_string)]) + octet_string return bytes([0x4, len(octet_string)]) + octet_string
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
challenge=dict(type="str", required=True, choices=["tls-alpn-01"]), challenge=dict(type="str", required=True, choices=["tls-alpn-01"]),
@@ -213,16 +217,16 @@ def main():
try: try:
# Get parameters # Get parameters
challenge = module.params["challenge"] challenge: t.Literal["tls-alpn-01"] = module.params["challenge"]
challenge_data = module.params["challenge_data"] challenge_data: dict[str, t.Any] = module.params["challenge_data"]
# Get hold of private key # Get hold of private key
private_key_content = module.params.get("private_key_content") private_key_content_str: str | None = module.params["private_key_content"]
private_key_passphrase = module.params.get("private_key_passphrase") private_key_passphrase: str | None = module.params["private_key_passphrase"]
if private_key_content is None: if private_key_content_str is None:
private_key_content = read_file(module.params["private_key_src"]) private_key_content = read_file(module.params["private_key_src"])
else: else:
private_key_content = to_bytes(private_key_content) private_key_content = to_bytes(private_key_content_str)
try: try:
private_key = ( private_key = (
cryptography.hazmat.primitives.serialization.load_pem_private_key( cryptography.hazmat.primitives.serialization.load_pem_private_key(
@@ -236,6 +240,17 @@ def main():
) )
except Exception as e: except Exception as e:
raise ModuleFailException(f"Error while loading private key: {e}") raise ModuleFailException(f"Error while loading private key: {e}")
if isinstance(
private_key,
(
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
),
):
raise ModuleFailException(
f"Cannot use private key type {type(private_key)}"
)
# Some common attributes # Some common attributes
domain = to_text(challenge_data["resource"]) domain = to_text(challenge_data["resource"])
@@ -246,6 +261,7 @@ def main():
now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE) now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
not_valid_before = now not_valid_before = now
not_valid_after = now + datetime.timedelta(days=10) not_valid_after = now + datetime.timedelta(days=10)
san: cryptography.x509.GeneralName
if identifier_type == "dns": if identifier_type == "dns":
san = cryptography.x509.DNSName(identifier) san = cryptography.x509.DNSName(identifier)
elif identifier_type == "ip": elif identifier_type == "ip":

View File

@@ -223,6 +223,8 @@ output_json:
- '...' - '...'
""" """
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import ( from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
ACMEClient, ACMEClient,
@@ -235,7 +237,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
) )
def main(): def main() -> t.NoReturn:
argument_spec = create_default_argspec(require_account_key=False) argument_spec = create_default_argspec(require_account_key=False)
argument_spec.update_argspec( argument_spec.update_argspec(
url=dict(type="str"), url=dict(type="str"),
@@ -246,17 +248,17 @@ def main():
fail_on_acme_error=dict(type="bool", default=True), fail_on_acme_error=dict(type="bool", default=True),
) )
argument_spec.update( argument_spec.update(
required_if=( required_if=[
["method", "get", ["url"]], ("method", "get", ["url"]),
["method", "post", ["url", "content"]], ("method", "post", ["url", "content"]),
["method", "get", ["account_key_src", "account_key_content"], True], ("method", "get", ["account_key_src", "account_key_content"], True),
["method", "post", ["account_key_src", "account_key_content"], True], ("method", "post", ["account_key_src", "account_key_content"], True),
), ],
) )
module = argument_spec.create_ansible_module() module = argument_spec.create_ansible_module()
backend = create_backend(module, False) backend = create_backend(module, False)
result = dict() result: dict[str, t.Any] = {}
changed = False changed = False
try: try:
# Get hold of ACMEClient and ACMEAccount objects (includes directory) # Get hold of ACMEClient and ACMEAccount objects (includes directory)

View File

@@ -121,6 +121,7 @@ complete_chain:
""" """
import os import os
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
@@ -153,14 +154,18 @@ class Certificate:
Stores PEM with parsed certificate. Stores PEM with parsed certificate.
""" """
def __init__(self, pem, cert): def __init__(self, pem: str, cert: cryptography.x509.Certificate) -> None:
if not (pem.endswith("\n") or pem.endswith("\r")): if not (pem.endswith("\n") or pem.endswith("\r")):
pem = pem + "\n" pem = pem + "\n"
self.pem = pem self.pem = pem
self.cert = cert self.cert = cert
def is_parent(module, cert, potential_parent): def is_parent(
module: AnsibleModule,
cert: Certificate,
potential_parent: Certificate,
) -> bool:
""" """
Tests whether the given certificate has been issued by the potential parent certificate. Tests whether the given certificate has been issued by the potential parent certificate.
""" """
@@ -173,6 +178,10 @@ def is_parent(module, cert, potential_parent):
if isinstance( if isinstance(
public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey
): ):
if cert.cert.signature_hash_algorithm is None:
raise AssertionError(
"signature_hash_algorithm should be present for RSA certificates"
)
public_key.verify( public_key.verify(
cert.cert.signature, cert.cert.signature,
cert.cert.tbs_certificate_bytes, cert.cert.tbs_certificate_bytes,
@@ -183,6 +192,10 @@ def is_parent(module, cert, potential_parent):
public_key, public_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
): ):
if cert.cert.signature_hash_algorithm is None:
raise AssertionError(
"signature_hash_algorithm should be present for EC certificates"
)
public_key.verify( public_key.verify(
cert.cert.signature, cert.cert.signature,
cert.cert.tbs_certificate_bytes, cert.cert.tbs_certificate_bytes,
@@ -213,11 +226,16 @@ def is_parent(module, cert, potential_parent):
module.fail_json(msg=f"Unknown error on signature validation: {e}") module.fail_json(msg=f"Unknown error on signature validation: {e}")
def parse_PEM_list(module, text, source, fail_on_error=True): def parse_PEM_list(
module: AnsibleModule,
text: str,
source: str | os.PathLike,
fail_on_error: bool = True,
) -> list[Certificate]:
""" """
Parse concatenated PEM certificates. Return list of ``Certificate`` objects. Parse concatenated PEM certificates. Return list of ``Certificate`` objects.
""" """
result = [] result: list[Certificate] = []
for cert_pem in split_pem_list(text): for cert_pem in split_pem_list(text):
# Try to load PEM certificate # Try to load PEM certificate
try: try:
@@ -232,7 +250,9 @@ def parse_PEM_list(module, text, source, fail_on_error=True):
return result return result
def load_PEM_list(module, path, fail_on_error=True): def load_PEM_list(
module: AnsibleModule, path: str | os.PathLike, fail_on_error: bool = True
) -> list[Certificate]:
""" """
Load concatenated PEM certificates from file. Return list of ``Certificate`` objects. Load concatenated PEM certificates from file. Return list of ``Certificate`` objects.
""" """
@@ -258,13 +278,15 @@ class CertificateSet:
Stores a set of certificates. Allows to search for parent (issuer of a certificate). Stores a set of certificates. Allows to search for parent (issuer of a certificate).
""" """
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self.module = module self.module = module
self.certificates = set() self.certificates: set[Certificate] = set()
self.certificates_by_issuer = dict() self.certificates_by_issuer: dict[cryptography.x509.Name, list[Certificate]] = (
self.certificate_by_cert = dict() {}
)
self.certificate_by_cert: dict[cryptography.x509.Certificate, Certificate] = {}
def _load_file(self, path): def _load_file(self, path: str | os.PathLike) -> None:
certs = load_PEM_list(self.module, path, fail_on_error=False) certs = load_PEM_list(self.module, path, fail_on_error=False)
for cert in certs: for cert in certs:
self.certificates.add(cert) self.certificates.add(cert)
@@ -273,7 +295,7 @@ class CertificateSet:
self.certificates_by_issuer[cert.cert.subject].append(cert) self.certificates_by_issuer[cert.cert.subject].append(cert)
self.certificate_by_cert[cert.cert] = cert self.certificate_by_cert[cert.cert] = cert
def load(self, path): def load(self, path: str | os.PathLike) -> None:
""" """
Load lists of PEM certificates from a file or a directory. Load lists of PEM certificates from a file or a directory.
""" """
@@ -285,7 +307,7 @@ class CertificateSet:
else: else:
self._load_file(b_path) self._load_file(b_path)
def find_parent(self, cert): def find_parent(self, cert: Certificate) -> Certificate | None:
""" """
Search for the parent (issuer) of a certificate. Return ``None`` if none was found. Search for the parent (issuer) of a certificate. Return ``None`` if none was found.
""" """
@@ -296,14 +318,18 @@ class CertificateSet:
return None return None
def format_cert(cert): def format_cert(cert: Certificate) -> str:
""" """
Return human readable representation of certificate for error messages. Return human readable representation of certificate for error messages.
""" """
return str(cert.cert) return str(cert.cert)
def check_cycle(module, occured_certificates, next): def check_cycle(
module: AnsibleModule,
occured_certificates: set[cryptography.x509.Certificate],
next: Certificate,
) -> None:
""" """
Make sure that next is not in occured_certificates so far, and add it. Make sure that next is not in occured_certificates so far, and add it.
""" """
@@ -313,7 +339,7 @@ def check_cycle(module, occured_certificates, next):
occured_certificates.add(next_cert) occured_certificates.add(next_cert)
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
input_chain=dict(type="str", required=True), input_chain=dict(type="str", required=True),
@@ -354,10 +380,10 @@ def main():
roots.load(path) roots.load(path)
# Try to complete chain # Try to complete chain
current = chain[-1] current: Certificate | None = chain[-1]
completed = [] completed = []
occured_certificates = set([cert.cert for cert in chain]) occured_certificates = set([cert.cert for cert in chain])
if current.cert in roots.certificate_by_cert: if current and current.cert in roots.certificate_by_cert:
# Do not try to complete the chain when it is already ending with a root certificate # Do not try to complete the chain when it is already ending with a root certificate
current = None current = None
while current: while current:

View File

@@ -152,10 +152,13 @@ openssl:
""" """
import traceback import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
CRYPTOGRAPHY_VERSION: str | None
CRYPTOGRAPHY_IMP_ERR: str | None
try: try:
import cryptography import cryptography
from cryptography.exceptions import UnsupportedAlgorithm from cryptography.exceptions import UnsupportedAlgorithm
@@ -165,10 +168,10 @@ try:
# only got added in 0.2, so let's guard the import # only got added in 0.2, so let's guard the import
from cryptography.exceptions import InternalError as CryptographyInternalError from cryptography.exceptions import InternalError as CryptographyInternalError
except ImportError: except ImportError:
CryptographyInternalError = Exception CryptographyInternalError = Exception # type: ignore
except ImportError: except ImportError:
UnsupportedAlgorithm = Exception UnsupportedAlgorithm = Exception # type: ignore
CryptographyInternalError = Exception CryptographyInternalError = Exception # type: ignore
HAS_CRYPTOGRAPHY = False HAS_CRYPTOGRAPHY = False
CRYPTOGRAPHY_VERSION = None CRYPTOGRAPHY_VERSION = None
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -201,8 +204,8 @@ CURVES = (
) )
def add_crypto_information(module): def add_crypto_information(module: AnsibleModule) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY
if not HAS_CRYPTOGRAPHY: if not HAS_CRYPTOGRAPHY:
result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR
@@ -397,9 +400,9 @@ def add_crypto_information(module):
return result return result
def add_openssl_information(module): def add_openssl_information(module: AnsibleModule) -> dict[str, t.Any]:
openssl_binary = module.get_bin_path("openssl") openssl_binary = module.get_bin_path("openssl")
result = { result: dict[str, t.Any] = {
"openssl_present": openssl_binary is not None, "openssl_present": openssl_binary is not None,
} }
if openssl_binary is None: if openssl_binary is None:
@@ -426,9 +429,9 @@ INFO_FUNCTIONS = (
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule(argument_spec={}, supports_check_mode=True) module = AnsibleModule(argument_spec={}, supports_check_mode=True)
result = {} result: dict[str, t.Any] = {}
for fn in INFO_FUNCTIONS: for fn in INFO_FUNCTIONS:
result.update(fn(module)) result.update(fn(module))
module.exit_json(**result) module.exit_json(**result)

View File

@@ -550,6 +550,7 @@ import datetime
import os import os
import re import re
import time import time
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
@@ -572,7 +573,7 @@ from ansible_collections.community.crypto.plugins.module_utils.io import write_f
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
def validate_cert_expiry(cert_expiry): def validate_cert_expiry(cert_expiry: str) -> bool:
search_string_partial = re.compile( search_string_partial = re.compile(
r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z" r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z"
) )
@@ -587,7 +588,7 @@ def validate_cert_expiry(cert_expiry):
return False return False
def calculate_cert_days(expires_after): def calculate_cert_days(expires_after: str | None) -> int:
cert_days = 0 cert_days = 0
if expires_after: if expires_after:
expires_after_datetime = datetime.datetime.strptime( expires_after_datetime = datetime.datetime.strptime(
@@ -600,7 +601,9 @@ def calculate_cert_days(expires_after):
# Populate the value of body[dict_param_name] with the JSON equivalent of # Populate the value of body[dict_param_name] with the JSON equivalent of
# module parameter of param_name if that parameter is present, otherwise leave field # module parameter of param_name if that parameter is present, otherwise leave field
# out of resulting dict # out of resulting dict
def convert_module_param_to_json_bool(module, dict_param_name, param_name): def convert_module_param_to_json_bool(
module: AnsibleModule, dict_param_name: str, param_name: str
) -> dict[str, str]:
body = {} body = {}
if module.params[param_name] is not None: if module.params[param_name] is not None:
if module.params[param_name]: if module.params[param_name]:
@@ -886,7 +889,7 @@ class EcsCertificate:
return result return result
def custom_fields_spec(): def custom_fields_spec() -> dict[str, dict[str, str]]:
return dict( return dict(
text1=dict(type="str"), text1=dict(type="str"),
text2=dict(type="str"), text2=dict(type="str"),
@@ -926,7 +929,7 @@ def custom_fields_spec():
) )
def ecs_certificate_argument_spec(): def ecs_certificate_argument_spec() -> dict[str, dict[str, t.Any]]:
return dict( return dict(
backup=dict(type="bool", default=False), backup=dict(type="bool", default=False),
force=dict(type="bool", default=False), force=dict(type="bool", default=False),
@@ -979,7 +982,7 @@ def ecs_certificate_argument_spec():
) )
def main(): def main() -> t.NoReturn:
ecs_argument_spec = ecs_client_argument_spec() ecs_argument_spec = ecs_client_argument_spec()
ecs_argument_spec.update(ecs_certificate_argument_spec()) ecs_argument_spec.update(ecs_certificate_argument_spec())
module = AnsibleModule( module = AnsibleModule(

View File

@@ -218,6 +218,7 @@ ev_days_remaining:
import datetime import datetime
import time import time
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.ecs.api import ( from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
@@ -228,7 +229,7 @@ from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
) )
def calculate_days_remaining(expiry_date): def calculate_days_remaining(expiry_date: str | None) -> int | None:
days_remaining = None days_remaining = None
if expiry_date: if expiry_date:
expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ") expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ")
@@ -403,8 +404,8 @@ class EcsDomain:
msg=f"Failed to request domain validation from Entrust (ECS) {e.message}" msg=f"Failed to request domain validation from Entrust (ECS) {e.message}"
) )
def dump(self): def dump(self) -> dict[str, t.Any]:
result = { result: dict[str, t.Any] = {
"changed": self.changed, "changed": self.changed,
"client_id": self.client_id, "client_id": self.client_id,
"domain_status": self.domain_status, "domain_status": self.domain_status,
@@ -436,7 +437,7 @@ class EcsDomain:
return result return result
def ecs_domain_argument_spec(): def ecs_domain_argument_spec() -> dict[str, dict[str, t.Any]]:
return dict( return dict(
client_id=dict(type="int", default=1), client_id=dict(type="int", default=1),
domain_name=dict(type="str", required=True), domain_name=dict(type="str", required=True),
@@ -447,7 +448,7 @@ def ecs_domain_argument_spec():
) )
def main(): def main() -> t.NoReturn:
ecs_argument_spec = ecs_client_argument_spec() ecs_argument_spec = ecs_client_argument_spec()
ecs_argument_spec.update(ecs_domain_argument_spec()) ecs_argument_spec.update(ecs_domain_argument_spec())
module = AnsibleModule( module = AnsibleModule(

View File

@@ -268,6 +268,7 @@ import atexit
import base64 import base64
import ssl import ssl
import sys import sys
import typing as t
from os.path import isfile from os.path import isfile
from socket import create_connection, setdefaulttimeout, socket from socket import create_connection, setdefaulttimeout, socket
from ssl import ( from ssl import (
@@ -305,7 +306,7 @@ except ImportError:
pass pass
def send_starttls_packet(sock, server_type): def send_starttls_packet(sock: socket, server_type: t.Literal["mysql"]) -> None:
if server_type == "mysql": if server_type == "mysql":
ssl_request_packet = ( ssl_request_packet = (
b"\x20\x00\x00\x01\x85\xae\x7f\x00" b"\x20\x00\x00\x01\x85\xae\x7f\x00"
@@ -321,7 +322,7 @@ def send_starttls_packet(sock, server_type):
sock.send(ssl_request_packet) sock.send(ssl_request_packet)
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
ca_cert=dict(type="path"), ca_cert=dict(type="path"),
@@ -342,18 +343,18 @@ def main():
), ),
) )
ca_cert = module.params.get("ca_cert") ca_cert: str | None = module.params.get("ca_cert")
host = module.params.get("host") host: str = module.params.get("host")
port = module.params.get("port") port: int = module.params.get("port")
proxy_host = module.params.get("proxy_host") proxy_host: str | None = module.params.get("proxy_host")
proxy_port = module.params.get("proxy_port") proxy_port: int | None = module.params.get("proxy_port")
timeout = module.params.get("timeout") timeout: int = module.params.get("timeout")
server_name = module.params.get("server_name") server_name: str | None = module.params.get("server_name")
start_tls_server_type = module.params.get("starttls") start_tls_server_type: t.Literal["mysql"] | None = module.params.get("starttls")
ciphers = module.params.get("ciphers") ciphers: list[str] | None = module.params.get("ciphers")
asn1_base64 = module.params["asn1_base64"] asn1_base64: bool = module.params["asn1_base64"]
tls_ctx_options = module.params["tls_ctx_options"] tls_ctx_options: list[str | bytes | int] | None = module.params["tls_ctx_options"]
get_certificate_chain = module.params["get_certificate_chain"] get_certificate_chain: bool = module.params["get_certificate_chain"]
if get_certificate_chain and sys.version_info < (3, 10): if get_certificate_chain and sys.version_info < (3, 10):
module.fail_json( module.fail_json(
@@ -365,9 +366,9 @@ def main():
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )
result = dict( result: dict[str, t.Any] = {
changed=False, "changed": False,
) }
if timeout: if timeout:
setdefaulttimeout(timeout) setdefaulttimeout(timeout)
@@ -409,7 +410,7 @@ def main():
if tls_ctx_options is not None: if tls_ctx_options is not None:
# Clear default ctx options # Clear default ctx options
ctx.options = 0 ctx.options = 0 # type: ignore
# For each item in the tls_ctx_options list # For each item in the tls_ctx_options list
for tls_ctx_option in tls_ctx_options: for tls_ctx_option in tls_ctx_options:
@@ -450,8 +451,10 @@ def main():
) )
tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host) tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host)
cert = tls_sock.getpeercert(True) cert_der = tls_sock.getpeercert(True)
cert = DER_cert_to_PEM_cert(cert) if cert_der is None:
raise Exception("Unexpected error: no peer certificate has been returned")
cert: str = DER_cert_to_PEM_cert(cert_der)
if get_certificate_chain: if get_certificate_chain:
if sys.version_info < (3, 13): if sys.version_info < (3, 13):
@@ -474,7 +477,7 @@ def main():
# Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to # Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to
# be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves # be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves
# if they are not byte strings to work around this. # if they are not byte strings to work around this.
def _convert_chain(chain): def _convert_chain(chain: list[bytes]) -> list[bytes]:
return [ return [
( (
c c
@@ -514,13 +517,13 @@ def main():
result["extensions"] = [] result["extensions"] = []
for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items(): for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items():
oid = cryptography.x509.oid.ObjectIdentifier(dotted_number) oid = cryptography.x509.oid.ObjectIdentifier(dotted_number)
ext = { ext: dict[str, t.Any] = {
"critical": entry["critical"], "critical": entry["critical"],
"asn1_data": entry["value"], "asn1_data": entry["value"],
"name": cryptography_oid_to_name(oid, short=True), "name": cryptography_oid_to_name(oid, short=True),
} }
if not asn1_base64: if not asn1_base64:
ext["asn1_data"] = base64.b64decode(ext["asn1_data"]) ext["asn1_data"] = base64.b64decode(entry["value"]) # type: ignore
result["extensions"].append(ext) result["extensions"].append(ext)
result["issuer"] = {} result["issuer"] = {}

View File

@@ -420,16 +420,13 @@ name:
import os import os
import re import re
import stat import stat
import typing as t
from base64 import b64decode from base64 import b64decode
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
RETURN_CODE = 0
STDOUT = 1
STDERR = 2
# used to get <luks-name> out of lsblk output in format 'crypt <luks-name>' # used to get <luks-name> out of lsblk output in format 'crypt <luks-name>'
# regex takes care of any possible blank characters # regex takes care of any possible blank characters
LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$") LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$")
@@ -456,7 +453,7 @@ LUKS2_HEADER_OFFSETS = [
LUKS2_HEADER2 = b"SKUL\xba\xbe" LUKS2_HEADER2 = b"SKUL\xba\xbe"
def wipe_luks_headers(device): def wipe_luks_headers(device: str) -> None:
wipe_offsets = [] wipe_offsets = []
with open(device, "rb") as f: with open(device, "rb") as f:
# f.seek(0) # f.seek(0)
@@ -478,12 +475,12 @@ def wipe_luks_headers(device):
class Handler: class Handler:
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self._module = module self._module = module
self._lsblk_bin = self._module.get_bin_path("lsblk", True) self._lsblk_bin = self._module.get_bin_path("lsblk", True)
self._passphrase_encoding = module.params["passphrase_encoding"] self._passphrase_encoding = module.params["passphrase_encoding"]
def get_passphrase_from_module_params(self, parameter_name): def get_passphrase_from_module_params(self, parameter_name: str) -> bytes | None:
passphrase = self._module.params[parameter_name] passphrase = self._module.params[parameter_name]
if passphrase is None: if passphrase is None:
return None return None
@@ -496,88 +493,91 @@ class Handler:
f"Error while base64-decoding '{parameter_name}': {exc}" f"Error while base64-decoding '{parameter_name}': {exc}"
) )
def _run_command(self, command, data=None): def _run_command(
self, command: list[str], data: bytes | None = None
) -> tuple[int, str, str]:
return self._module.run_command(command, data=data, binary_data=True) return self._module.run_command(command, data=data, binary_data=True)
def get_device_by_uuid(self, uuid): def get_device_by_uuid(self, uuid: str | None) -> str | None:
"""Returns the device that holds UUID passed by user""" """Returns the device that holds UUID passed by user"""
self._blkid_bin = self._module.get_bin_path("blkid", True) self._blkid_bin = self._module.get_bin_path("blkid", True)
uuid = self._module.params["uuid"]
if uuid is None: if uuid is None:
return None return None
result = self._run_command([self._blkid_bin, "--uuid", uuid]) rc, stdout, dummy = self._run_command([self._blkid_bin, "--uuid", uuid])
if result[RETURN_CODE] != 0: if rc != 0:
return None return None
return result[STDOUT].strip() return stdout.strip()
def get_device_by_label(self, label): def get_device_by_label(self, label: str) -> str | None:
"""Returns the device that holds label passed by user""" """Returns the device that holds label passed by user"""
self._blkid_bin = self._module.get_bin_path("blkid", True) self._blkid_bin = self._module.get_bin_path("blkid", True)
label = self._module.params["label"] label = self._module.params["label"]
if label is None: if label is None:
return None return None
result = self._run_command([self._blkid_bin, "--label", label]) rc, stdout, dummy = self._run_command([self._blkid_bin, "--label", label])
if result[RETURN_CODE] != 0: if rc != 0:
return None return None
return result[STDOUT].strip() return stdout.strip()
def generate_luks_name(self, device): def generate_luks_name(self, device: str) -> str:
"""Generate name for luks based on device UUID ('luks-<UUID>'). """Generate name for luks based on device UUID ('luks-<UUID>').
Raises ValueError when obtaining of UUID fails. Raises ValueError when obtaining of UUID fails.
""" """
result = self._run_command([self._lsblk_bin, "-n", device, "-o", "UUID"]) rc, stdout, stderr = self._run_command(
[self._lsblk_bin, "-n", device, "-o", "UUID"]
)
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError( raise ValueError(f"Error while generating LUKS name for {device}: {stderr}")
f"Error while generating LUKS name for {device}: {result[STDERR]}" dev_uuid = stdout.strip()
)
dev_uuid = result[STDOUT].strip()
return f"luks-{dev_uuid}" return f"luks-{dev_uuid}"
class CryptHandler(Handler): class CryptHandler(Handler):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(CryptHandler, self).__init__(module) super(CryptHandler, self).__init__(module)
self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True) self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True)
def get_container_name_by_device(self, device): def get_container_name_by_device(self, device: str) -> str | None:
"""obtain LUKS container name based on the device where it is located """obtain LUKS container name based on the device where it is located
return None if not found return None if not found
raise ValueError if lsblk command fails raise ValueError if lsblk command fails
""" """
result = self._run_command([self._lsblk_bin, device, "-nlo", "type,name"]) rc, stdout, stderr = self._run_command(
if result[RETURN_CODE] != 0: [self._lsblk_bin, device, "-nlo", "type,name"]
raise ValueError( )
f"Error while obtaining LUKS name for {device}: {result[STDERR]}" if rc != 0:
) raise ValueError(f"Error while obtaining LUKS name for {device}: {stderr}")
for line in result[STDOUT].splitlines(False): for line in stdout.splitlines(False):
m = LUKS_NAME_REGEX.match(line) m = LUKS_NAME_REGEX.match(line)
if m: if m:
return m.group(1) return m.group(1)
return None return None
def get_container_device_by_name(self, name): def get_container_device_by_name(self, name: str) -> str | None:
"""obtain device name based on the LUKS container name """obtain device name based on the LUKS container name
return None if not found return None if not found
raise ValueError if lsblk command fails raise ValueError if lsblk command fails
""" """
# apparently each device can have only one LUKS container on it # apparently each device can have only one LUKS container on it
result = self._run_command([self._cryptsetup_bin, "status", name]) rc, stdout, dummy = self._run_command([self._cryptsetup_bin, "status", name])
if result[RETURN_CODE] != 0: if rc != 0:
return None return None
m = LUKS_DEVICE_REGEX.search(result[STDOUT]) m = LUKS_DEVICE_REGEX.search(stdout)
if not m:
return None
device = m.group(1) device = m.group(1)
return device return device
def is_luks(self, device): def is_luks(self, device: str) -> bool:
"""check if the LUKS container does exist""" """check if the LUKS container does exist"""
result = self._run_command([self._cryptsetup_bin, "isLuks", device]) rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "isLuks", device])
return result[RETURN_CODE] == 0 return rc == 0
def get_luks_type(self, device): def get_luks_type(self, device: str) -> t.Literal["luks1", "luks2"] | None:
"""get the luks type of a device""" """get the luks type of a device"""
if self.is_luks(device): if self.is_luks(device):
with open(device, "rb") as f: with open(device, "rb") as f:
@@ -589,16 +589,18 @@ class CryptHandler(Handler):
return "luks1" return "luks1"
return None return None
def is_luks_slot_set(self, device, keyslot): def is_luks_slot_set(self, device: str, keyslot: int) -> bool:
"""check if a keyslot is set""" """check if a keyslot is set"""
result = self._run_command([self._cryptsetup_bin, "luksDump", device]) rc, stdout, dummy = self._run_command(
if result[RETURN_CODE] != 0: [self._cryptsetup_bin, "luksDump", device]
)
if rc != 0:
raise ValueError(f"Error while dumping LUKS header from {device}") raise ValueError(f"Error while dumping LUKS header from {device}")
result_luks1 = f"Key Slot {keyslot}: ENABLED" in result[STDOUT] result_luks1 = f"Key Slot {keyslot}: ENABLED" in stdout
result_luks2 = f" {keyslot}: luks2" in result[STDOUT] result_luks2 = f" {keyslot}: luks2" in stdout
return result_luks1 or result_luks2 return result_luks1 or result_luks2
def _add_pbkdf_options(self, options, pbkdf): def _add_pbkdf_options(self, options: list[str], pbkdf: dict[str, t.Any]) -> None:
if pbkdf["iteration_time"] is not None: if pbkdf["iteration_time"] is not None:
options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))]) options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))])
if pbkdf["iteration_count"] is not None: if pbkdf["iteration_count"] is not None:
@@ -612,16 +614,16 @@ class CryptHandler(Handler):
def run_luks_create( def run_luks_create(
self, self,
device, device: str,
keyfile, keyfile: str | None,
passphrase, passphrase: bytes | None,
keyslot, keyslot: int | None,
keysize, keysize: int | None,
cipher, cipher: str | None,
hash_, hash_: str | None,
sector_size, sector_size: str | None,
pbkdf, pbkdf: dict[str, t.Any] | None,
): ) -> None:
# create a new luks container; use batch mode to auto confirm # create a new luks container; use batch mode to auto confirm
luks_type = self._module.params["type"] luks_type = self._module.params["type"]
label = self._module.params["label"] label = self._module.params["label"]
@@ -653,23 +655,23 @@ class CryptHandler(Handler):
else: else:
args.append("-") args.append("-")
result = self._run_command(args, data=passphrase) rc, dummy, stderr = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError(f"Error while creating LUKS on {device}: {result[STDERR]}") raise ValueError(f"Error while creating LUKS on {device}: {stderr}")
def run_luks_open( def run_luks_open(
self, self,
device, device: str,
keyfile, keyfile: str | None,
passphrase, passphrase: bytes | None,
perf_same_cpu_crypt, perf_same_cpu_crypt: bool,
perf_submit_from_crypt_cpus, perf_submit_from_crypt_cpus: bool,
perf_no_read_workqueue, perf_no_read_workqueue: bool,
perf_no_write_workqueue, perf_no_write_workqueue: bool,
persistent, persistent: bool,
allow_discards, allow_discards: bool,
name, name: str,
): ) -> None:
args = [self._cryptsetup_bin] args = [self._cryptsetup_bin]
if keyfile: if keyfile:
args.extend(["--key-file", keyfile]) args.extend(["--key-file", keyfile])
@@ -689,27 +691,27 @@ class CryptHandler(Handler):
args.extend(["--allow-discards"]) args.extend(["--allow-discards"])
args.extend(["open", "--type", "luks", device, name]) args.extend(["open", "--type", "luks", device, name])
result = self._run_command(args, data=passphrase) rc, dummy, stderr = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError( raise ValueError(
f"Error while opening LUKS container on {device}: {result[STDERR]}" f"Error while opening LUKS container on {device}: {stderr}"
) )
def run_luks_close(self, name): def run_luks_close(self, name: str) -> None:
result = self._run_command([self._cryptsetup_bin, "close", name]) rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "close", name])
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError(f"Error while closing LUKS container {name}") raise ValueError(f"Error while closing LUKS container {name}")
def run_luks_remove(self, device): def run_luks_remove(self, device: str) -> None:
wipefs_bin = self._module.get_bin_path("wipefs", True) wipefs_bin = self._module.get_bin_path("wipefs", True)
name = self.get_container_name_by_device(device) name = self.get_container_name_by_device(device)
if name is not None: if name is not None:
self.run_luks_close(name) self.run_luks_close(name)
result = self._run_command([wipefs_bin, "--all", device]) rc, dummy, stderr = self._run_command([wipefs_bin, "--all", device])
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError( raise ValueError(
f"Error while wiping LUKS container signatures for {device}: {result[STDERR]}" f"Error while wiping LUKS container signatures for {device}: {stderr}"
) )
# For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not** # For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not**
@@ -724,14 +726,14 @@ class CryptHandler(Handler):
def run_luks_add_key( def run_luks_add_key(
self, self,
device, device: str,
keyfile, keyfile: str | None,
passphrase, passphrase: bytes | None,
new_keyfile, new_keyfile: str | None,
new_passphrase, new_passphrase: bytes | None,
new_keyslot, new_keyslot: int | None,
pbkdf, pbkdf: dict[str, t.Any] | None,
): ) -> None:
"""Add new key from a keyfile or passphrase to given 'device'; """Add new key from a keyfile or passphrase to given 'device';
authentication done using 'keyfile' or 'passphrase'. authentication done using 'keyfile' or 'passphrase'.
Raises ValueError when command fails. Raises ValueError when command fails.
@@ -746,36 +748,47 @@ class CryptHandler(Handler):
if keyfile: if keyfile:
args.extend(["--key-file", keyfile]) args.extend(["--key-file", keyfile])
else: elif passphrase is not None:
args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))]) args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))])
data.append(passphrase) data.append(passphrase)
else:
raise ValueError("Need passphrase or keyfile")
if new_keyfile: if new_keyfile:
args.append(new_keyfile) args.append(new_keyfile)
else: elif new_passphrase is not None:
args.append("-") args.append("-")
data.append(new_passphrase) data.append(new_passphrase)
else:
raise ValueError("Need new passphrase or new keyfile")
result = self._run_command(args, data=b"".join(data) or None) rc, dummy, stderr = self._run_command(args, data=b"".join(data) or None)
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError( raise ValueError(
f"Error while adding new LUKS keyslot to {device}: {result[STDERR]}" f"Error while adding new LUKS keyslot to {device}: {stderr}"
) )
def run_luks_remove_key( def run_luks_remove_key(
self, device, keyfile, passphrase, keyslot, force_remove_last_key=False self,
): device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None,
force_remove_last_key: bool = False,
) -> None:
"""Remove key from given device """Remove key from given device
Raises ValueError when command fails Raises ValueError when command fails
""" """
if not force_remove_last_key: if not force_remove_last_key:
result = self._run_command([self._cryptsetup_bin, "luksDump", device]) rc, stdout, dummy = self._run_command(
if result[RETURN_CODE] != 0: [self._cryptsetup_bin, "luksDump", device]
)
if rc != 0:
raise ValueError(f"Error while dumping LUKS header from {device}") raise ValueError(f"Error while dumping LUKS header from {device}")
keyslot_count = 0 keyslot_count = 0
keyslot_area = False keyslot_area = False
keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED") keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED")
for line in result[STDOUT].splitlines(): for line in stdout.splitlines():
if line.startswith("Keyslots:"): if line.startswith("Keyslots:"):
keyslot_area = True keyslot_area = True
elif line.startswith(" "): elif line.startswith(" "):
@@ -808,13 +821,17 @@ class CryptHandler(Handler):
# Since we supply -q no passphrase is needed # Since we supply -q no passphrase is needed
args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)] args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)]
passphrase = None passphrase = None
result = self._run_command(args, data=passphrase) rc, dummy, stderr = self._run_command(args, data=passphrase)
if result[RETURN_CODE] != 0: if rc != 0:
raise ValueError( raise ValueError(f"Error while removing LUKS key from {device}: {stderr}")
f"Error while removing LUKS key from {device}: {result[STDERR]}"
)
def luks_test_key(self, device, keyfile, passphrase, keyslot=None): def luks_test_key(
self,
device: str,
keyfile: str | None,
passphrase: bytes | None,
keyslot: int | None = None,
) -> bool:
"""Check whether the keyfile or passphrase works. """Check whether the keyfile or passphrase works.
Raises ValueError when command fails. Raises ValueError when command fails.
""" """
@@ -830,42 +847,37 @@ class CryptHandler(Handler):
if keyslot is not None: if keyslot is not None:
args.extend(["--key-slot", str(keyslot)]) args.extend(["--key-slot", str(keyslot)])
result = self._run_command(args, data=data) rc, stdout, stderr = self._run_command(args, data=data)
if result[RETURN_CODE] == 0: if rc == 0:
return True return True
for output in (STDOUT, STDERR): for output in (stdout, stderr):
if "No key available with this passphrase" in result[output]: if "No key available with this passphrase" in output:
return False return False
if "No usable keyslot is available." in result[output]: if "No usable keyslot is available." in output:
return False return False
# This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available' # This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available'
# when using the --key-slot parameter in combination with --test-passphrase # when using the --key-slot parameter in combination with --test-passphrase
if ( if rc == 1 and keyslot is not None and stdout == "" and stderr == "":
result[RETURN_CODE] == 1
and keyslot is not None
and result[STDOUT] == ""
and result[STDERR] == ""
):
return False return False
raise ValueError( raise ValueError(
f"Error while testing whether keyslot exists on {device}: {result[STDERR]}" f"Error while testing whether keyslot exists on {device}: {stderr}"
) )
class ConditionsHandler(Handler): class ConditionsHandler(Handler):
def __init__(self, module, crypthandler): def __init__(self, module: AnsibleModule, crypthandler: CryptHandler) -> None:
super(ConditionsHandler, self).__init__(module) super(ConditionsHandler, self).__init__(module)
self._crypthandler = crypthandler self._crypthandler = crypthandler
self.device = self.get_device_name() self.device = self.get_device_name()
def get_device_name(self): def get_device_name(self) -> str | None:
device = self._module.params.get("device") device: str | None = self._module.params.get("device")
label = self._module.params.get("label") label: str | None = self._module.params.get("label")
uuid = self._module.params.get("uuid") uuid: str | None = self._module.params.get("uuid")
name = self._module.params.get("name") name: str | None = self._module.params.get("name")
if device is None and label is not None: if device is None and label is not None:
device = self.get_device_by_label(label) device = self.get_device_by_label(label)
@@ -876,7 +888,7 @@ class ConditionsHandler(Handler):
return device return device
def luks_create(self): def luks_create(self) -> bool:
return ( return (
self.device is not None self.device is not None
and ( and (
@@ -887,7 +899,7 @@ class ConditionsHandler(Handler):
and not self._crypthandler.is_luks(self.device) and not self._crypthandler.is_luks(self.device)
) )
def opened_luks_name(self): def opened_luks_name(self, device: str) -> str | None:
"""If luks is already opened, return its name. """If luks is already opened, return its name.
If 'name' parameter is specified and differs If 'name' parameter is specified and differs
from obtained value, fail. from obtained value, fail.
@@ -897,7 +909,7 @@ class ConditionsHandler(Handler):
return None return None
# try to obtain luks name - it may be already opened # try to obtain luks name - it may be already opened
name = self._crypthandler.get_container_name_by_device(self.device) name = self._crypthandler.get_container_name_by_device(device)
if name is None: if name is None:
# container is not open # container is not open
@@ -917,7 +929,7 @@ class ConditionsHandler(Handler):
# container is opened and the names match # container is opened and the names match
return name return name
def luks_open(self): def luks_open(self) -> bool:
if ( if (
( (
self._module.params["keyfile"] is None self._module.params["keyfile"] is None
@@ -929,13 +941,13 @@ class ConditionsHandler(Handler):
# conditions for open not fulfilled # conditions for open not fulfilled
return False return False
name = self.opened_luks_name() name = self.opened_luks_name(self.device)
if name is None: if name is None:
return True return True
return False return False
def luks_close(self): def luks_close(self) -> bool:
if ( if (
self._module.params["name"] is None and self.device is None self._module.params["name"] is None and self.device is None
) or self._module.params["state"] != "closed": ) or self._module.params["state"] != "closed":
@@ -948,15 +960,17 @@ class ConditionsHandler(Handler):
luks_is_open = name is not None luks_is_open = name is not None
if self._module.params["name"] is not None: if self._module.params["name"] is not None:
self.device = self._crypthandler.get_container_device_by_name( device = self._crypthandler.get_container_device_by_name(
self._module.params["name"] self._module.params["name"]
) )
# successfully getting device based on name means that luks is open # successfully getting device based on name means that luks is open
luks_is_open = self.device is not None luks_is_open = device is not None
if device is not None:
self.device = device
return luks_is_open return luks_is_open
def luks_add_key(self): def luks_add_key(self) -> bool:
if ( if (
self.device is None self.device is None
or ( or (
@@ -995,7 +1009,7 @@ class ConditionsHandler(Handler):
return not key_present return not key_present
def luks_remove_key(self): def luks_remove_key(self) -> bool:
if self.device is None or ( if self.device is None or (
self._module.params["remove_keyfile"] is None self._module.params["remove_keyfile"] is None
and self._module.params["remove_passphrase"] is None and self._module.params["remove_passphrase"] is None
@@ -1037,14 +1051,16 @@ class ConditionsHandler(Handler):
self.get_passphrase_from_module_params("remove_passphrase"), self.get_passphrase_from_module_params("remove_passphrase"),
) )
def luks_remove(self): def luks_remove(self) -> bool:
return ( return (
self.device is not None self.device is not None
and self._module.params["state"] == "absent" and self._module.params["state"] == "absent"
and self._crypthandler.is_luks(self.device) and self._crypthandler.is_luks(self.device)
) )
def validate_keyslot(self, param, luks_type): def validate_keyslot(
self, param: str, luks_type: t.Literal["luks1", "luks2"] | None
) -> None:
if self._module.params[param] is not None: if self._module.params[param] is not None:
if luks_type is None and param == "keyslot": if luks_type is None and param == "keyslot":
if 8 <= self._module.params[param] <= 31: if 8 <= self._module.params[param] <= 31:
@@ -1066,7 +1082,7 @@ class ConditionsHandler(Handler):
) )
def run_module(): def run_module() -> t.NoReturn:
# available arguments/parameters that a user can pass # available arguments/parameters that a user can pass
module_args = dict( module_args = dict(
state=dict( state=dict(
@@ -1122,7 +1138,7 @@ def run_module():
] ]
# seed the result dict in the object # seed the result dict in the object
result = dict(changed=False, name=None) result: dict[str, t.Any] = {"changed": False, "name": None}
module = AnsibleModule( module = AnsibleModule(
argument_spec=module_args, argument_spec=module_args,
@@ -1142,19 +1158,26 @@ def run_module():
except Exception as e: except Exception as e:
module.fail_json(msg=str(e)) module.fail_json(msg=str(e))
crypt = CryptHandler(module)
conditions = ConditionsHandler(module, crypt)
# conditions not allowed to run # conditions not allowed to run
if module.params["label"] is not None and module.params["type"] == "luks1": if module.params["label"] is not None and module.params["type"] == "luks1":
module.fail_json(msg="You cannot combine type luks1 with the label option.") module.fail_json(msg="You cannot combine type luks1 with the label option.")
crypt = CryptHandler(module)
try:
conditions = ConditionsHandler(module, crypt)
except ValueError as exc:
module.fail_json(msg=str(exc))
if ( if (
module.params["keyslot"] is not None module.params["keyslot"] is not None
or module.params["new_keyslot"] is not None or module.params["new_keyslot"] is not None
or module.params["remove_keyslot"] is not None or module.params["remove_keyslot"] is not None
): ):
luks_type = crypt.get_luks_type(conditions.get_device_name()) luks_type = (
crypt.get_luks_type(conditions.device)
if conditions.device is not None
else None
)
if luks_type is None and module.params["type"] is not None: if luks_type is None and module.params["type"] is not None:
luks_type = module.params["type"] luks_type = module.params["type"]
for param in ["keyslot", "new_keyslot", "remove_keyslot"]: for param in ["keyslot", "new_keyslot", "remove_keyslot"]:
@@ -1175,6 +1198,7 @@ def run_module():
# luks create # luks create
if conditions.luks_create(): if conditions.luks_create():
assert conditions.device # ensured in conditions.luks_create()
if not module.check_mode: if not module.check_mode:
try: try:
crypt.run_luks_create( crypt.run_luks_create(
@@ -1196,11 +1220,13 @@ def run_module():
# luks open # luks open
name = conditions.opened_luks_name() if conditions.device is not None:
if name is not None: name = conditions.opened_luks_name(conditions.device)
result["name"] = name if name is not None:
result["name"] = name
if conditions.luks_open(): if conditions.luks_open():
assert conditions.device # ensured in conditions.luks_open()
name = module.params["name"] name = module.params["name"]
if name is None: if name is None:
try: try:
@@ -1237,6 +1263,8 @@ def run_module():
module.fail_json(msg=f"luks_device error: {e}") module.fail_json(msg=f"luks_device error: {e}")
else: else:
name = module.params["name"] name = module.params["name"]
if name is None:
module.fail_json(msg="Cannot determine name to close device")
if not module.check_mode: if not module.check_mode:
try: try:
crypt.run_luks_close(name) crypt.run_luks_close(name)
@@ -1249,6 +1277,7 @@ def run_module():
# luks add key # luks add key
if conditions.luks_add_key(): if conditions.luks_add_key():
assert conditions.device # ensured in conditions.luks_add_key()
if not module.check_mode: if not module.check_mode:
try: try:
crypt.run_luks_add_key( crypt.run_luks_add_key(
@@ -1268,6 +1297,7 @@ def run_module():
# luks remove key # luks remove key
if conditions.luks_remove_key(): if conditions.luks_remove_key():
assert conditions.device # ensured in conditions.luks_remove_key()
if not module.check_mode: if not module.check_mode:
try: try:
last_key = module.params["force_remove_last_key"] last_key = module.params["force_remove_last_key"]
@@ -1286,6 +1316,7 @@ def run_module():
# luks remove # luks remove
if conditions.luks_remove(): if conditions.luks_remove():
assert conditions.device # ensured in conditions.luks_remove()
if not module.check_mode: if not module.check_mode:
try: try:
crypt.run_luks_remove(conditions.device) crypt.run_luks_remove(conditions.device)
@@ -1299,7 +1330,7 @@ def run_module():
module.exit_json(**result) module.exit_json(**result)
def main(): def main() -> t.NoReturn:
run_module() run_module()

View File

@@ -284,6 +284,7 @@ info:
""" """
import os import os
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import ( from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
class Certificate(OpensshModule): class Certificate(OpensshModule):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(Certificate, self).__init__(module) super(Certificate, self).__init__(module)
self.ssh_keygen = KeygenCommand(self.module) self.ssh_keygen = KeygenCommand(self.module)
self.identifier = self.module.params["identifier"] or "" self.identifier: str = self.module.params["identifier"] or ""
self.options = self.module.params["options"] or [] self.options: list[str] = self.module.params["options"] or []
self.path = self.module.params["path"] self.path: str = self.module.params["path"]
self.pkcs11_provider = self.module.params["pkcs11_provider"] self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
self.principals = self.module.params["principals"] or [] self.principals: list[str] = self.module.params["principals"] or []
self.public_key = self.module.params["public_key"] self.public_key: str | None = self.module.params["public_key"]
self.regenerate = ( self.regenerate: t.Literal[
"never",
"fail",
"partial_idempotence",
"full_idempotence",
"always",
] = (
self.module.params["regenerate"] self.module.params["regenerate"]
if not self.module.params["force"] if not self.module.params["force"]
else "always" else "always"
) )
self.serial_number = self.module.params["serial_number"] self.serial_number: int | None = self.module.params["serial_number"]
self.signature_algorithm = self.module.params["signature_algorithm"] self.signature_algorithm: (
self.signing_key = self.module.params["signing_key"] t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
self.state = self.module.params["state"] ) = self.module.params["signature_algorithm"]
self.type = self.module.params["type"] self.signing_key: str | None = self.module.params["signing_key"]
self.use_agent = self.module.params["use_agent"] self.state: t.Literal["absent", "present"] = self.module.params["state"]
self.valid_at = self.module.params["valid_at"] self.type: t.Literal["host", "user"] | None = self.module.params["type"]
self.ignore_timestamps = self.module.params["ignore_timestamps"] self.use_agent: bool = self.module.params["use_agent"]
self.valid_at: str | None = self.module.params["valid_at"]
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
self._check_if_base_dir(self.path) self._check_if_base_dir(self.path)
if self.state == "present": if self.state == "present":
self._validate_parameters() self._validate_parameters()
self.data = None self.data: OpensshCertificate | None = None
self.original_data = None self.original_data: OpensshCertificate | None = None
if self._exists(): if self._exists():
self._load_certificate() self._load_certificate()
self.time_parameters = None self.time_parameters: OpensshCertificateTimeParameters | None = None
if self.state == "present": if self.state == "present":
self._set_time_parameters() self._set_time_parameters()
def _validate_parameters(self): def _validate_parameters(self) -> None:
for path in (self.public_key, self.signing_key): for path in (self.public_key, self.signing_key):
self._check_if_base_dir(path) if (
path is not None
): # should never be None, but the type checker doesn't know
self._check_if_base_dir(path)
if self.options and self.type == "host": if self.options and self.type == "host":
self.module.fail_json( self.module.fail_json(
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
if self.use_agent: if self.use_agent:
self._use_agent_available() self._use_agent_available()
def _use_agent_available(self): def _use_agent_available(self) -> None:
ssh_version = self._get_ssh_version() ssh_version = self._get_ssh_version()
if not ssh_version: if not ssh_version:
self.module.fail_json(msg="Failed to determine ssh version") self.module.fail_json(msg="Failed to determine ssh version")
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
+ f" Your version is: {ssh_version}" + f" Your version is: {ssh_version}"
) )
def _exists(self): def _exists(self) -> bool:
return os.path.exists(self.path) return os.path.exists(self.path)
def _load_certificate(self): def _load_certificate(self) -> None:
try: try:
self.original_data = OpensshCertificate.load(self.path) self.original_data = OpensshCertificate.load(self.path)
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
self.module.fail_json(msg=f"Unable to read existing certificate: {e}") self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
self.module.warn(f"Unable to read existing certificate: {e}") self.module.warn(f"Unable to read existing certificate: {e}")
def _set_time_parameters(self): def _set_time_parameters(self) -> None:
try: try:
self.time_parameters = OpensshCertificateTimeParameters( self.time_parameters = OpensshCertificateTimeParameters(
valid_from=self.module.params["valid_from"], valid_from=self.module.params["valid_from"],
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
except ValueError as e: except ValueError as e:
self.module.fail_json(msg=str(e)) self.module.fail_json(msg=str(e))
def _execute(self): def _execute(self) -> None:
if self.state == "present": if self.state == "present":
if self._should_generate(): if self._should_generate():
self._generate() self._generate()
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
if self._exists(): if self._exists():
self._remove() self._remove()
def _should_generate(self): def _should_generate(self) -> bool:
if self.regenerate == "never": if self.regenerate == "never":
return self.original_data is None return self.original_data is None
elif self.regenerate == "fail": elif self.regenerate == "fail":
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
else: else:
return True return True
def _is_fully_valid(self): def _is_fully_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
return self._is_partially_valid() and all( return self._is_partially_valid() and all(
[ [
self._compare_options() if self.original_data.type == "user" else True, self._compare_options() if self.original_data.type == "user" else True,
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
] ]
) )
def _is_partially_valid(self): def _is_partially_valid(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
return all( return all(
[ [
set(self.original_data.principals) == set(self.principals), set(self.original_data.principals) == set(self.principals),
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
] ]
) )
def _compare_time_parameters(self): def _compare_time_parameters(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try: try:
original_time_parameters = OpensshCertificateTimeParameters( original_time_parameters = OpensshCertificateTimeParameters(
valid_from=self.original_data.valid_after, valid_from=self.original_data.valid_after,
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
] ]
) )
def _compare_options(self): def _compare_options(self) -> bool:
if self.original_data is None:
raise AssertionError("Contract violation original_data not provided")
try: try:
critical_options, extensions = parse_option_list(self.options) critical_options, extensions = parse_option_list(self.options)
except ValueError as e: except ValueError as e:
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
] ]
) )
def _get_key_fingerprint(self, path): def _get_key_fingerprint(self, path: str) -> str:
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1] private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
return PrivateKey.from_string(private_key_content).fingerprint return PrivateKey.from_string(private_key_content).fingerprint
@OpensshModule.trigger_change @OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode @OpensshModule.skip_if_check_mode
def _generate(self): def _generate(self) -> None:
try: try:
temp_certificate = self._generate_temp_certificate() temp_certificate = self._generate_temp_certificate()
self._safe_secure_move([(temp_certificate, self.path)]) self._safe_secure_move([(temp_certificate, self.path)])
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
self.module.fail_json(msg=f"Unable to read new certificate: {e}") self.module.fail_json(msg=f"Unable to read new certificate: {e}")
def _generate_temp_certificate(self): def _generate_temp_certificate(self) -> str:
if self.public_key is None:
raise AssertionError("Contract violation public_key not provided")
if self.signing_key is None:
raise AssertionError("Contract violation signing_key not provided")
if self.time_parameters is None:
raise AssertionError("Contract violation time_parameters not provided")
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key)) key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
try: try:
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
@OpensshModule.trigger_change @OpensshModule.trigger_change
@OpensshModule.skip_if_check_mode @OpensshModule.skip_if_check_mode
def _remove(self): def _remove(self) -> None:
try: try:
os.remove(self.path) os.remove(self.path)
except OSError as e: except OSError as e:
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}") self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
@property @property
def _result(self): def _result(self) -> dict[str, t.Any]:
if self.state != "present": if self.state != "present":
return {} return {}
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
} }
@property @property
def diff(self): def diff(self) -> dict[str, t.Any]:
return { return {
"before": get_cert_dict(self.original_data), "before": get_cert_dict(self.original_data),
"after": get_cert_dict(self.data), "after": get_cert_dict(self.data),
} }
def format_cert_info(cert_info): def format_cert_info(cert_info: str) -> list[str]:
result = [] result = []
string = "" string = ""
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
return result return result
def get_cert_dict(data): def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
if data is None: if data is None:
return {} return {}
@@ -590,7 +621,7 @@ def get_cert_dict(data):
return result return result
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
force=dict(type="bool", default=False), force=dict(type="bool", default=False),

View File

@@ -198,13 +198,15 @@ comment:
sample: test@comment sample: test@comment
""" """
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import ( from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import (
select_backend, select_backend,
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(

View File

@@ -239,6 +239,7 @@ csr:
""" """
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
@@ -256,9 +257,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
CertificateSigningRequestBackend,
)
class CertificateSigningRequestModule(OpenSSLObject): class CertificateSigningRequestModule(OpenSSLObject):
def __init__(self, module, module_backend): def __init__(
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
) -> None:
super(CertificateSigningRequestModule, self).__init__( super(CertificateSigningRequestModule, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
@@ -269,11 +279,11 @@ class CertificateSigningRequestModule(OpenSSLObject):
self.return_content = module.params["return_content"] self.return_content = module.params["return_content"]
self.backup = module.params["backup"] self.backup = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
self.module_backend.set_existing(load_file_if_exists(self.path, module)) self.module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Generate the certificate signing request.""" """Generate the certificate signing request."""
if self.force or self.module_backend.needs_regeneration(): if self.force or self.module_backend.needs_regeneration():
if not self.check_mode: if not self.check_mode:
@@ -292,13 +302,13 @@ class CertificateSigningRequestModule(OpenSSLObject):
file_args, self.changed file_args, self.changed
) )
def remove(self, module): def remove(self, module: AnsibleModule) -> None:
self.module_backend.set_existing(None) self.module_backend.set_existing(None)
if self.backup and not self.check_mode: if self.backup and not self.check_mode:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
super(CertificateSigningRequestModule, self).remove(module) super(CertificateSigningRequestModule, self).remove(module)
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = self.module_backend.dump(include_csr=self.return_content) result = self.module_backend.dump(include_csr=self.return_content)
result.update( result.update(
@@ -312,7 +322,7 @@ class CertificateSigningRequestModule(OpenSSLObject):
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = get_csr_argument_spec() argument_spec = get_csr_argument_spec()
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(

View File

@@ -308,6 +308,7 @@ authority_cert_serial_number:
sample: 12345 sample: 12345
""" """
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -318,7 +319,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
path=dict(type="path"), path=dict(type="path"),
@@ -335,11 +336,15 @@ def main():
supports_check_mode=True, supports_check_mode=True,
) )
if module.params["content"] is not None: content: str | None = module.params["content"]
data = module.params["content"].encode("utf-8") path: str | None = module.params["path"]
if content is not None:
data = content.encode("utf-8")
else: else:
if path is None:
module.fail_json(msg="One of content and path must be provided")
try: try:
with open(module.params["path"], "rb") as f: with open(path, "rb") as f:
data = f.read() data = f.read()
except (IOError, OSError) as e: except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading CSR file from disk: {e}") module.fail_json(msg=f"Error while reading CSR file from disk: {e}")

View File

@@ -127,6 +127,8 @@ csr:
type: str type: str
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
) )
@@ -136,8 +138,17 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
CertificateSigningRequestBackend,
)
class CertificateSigningRequestModule: class CertificateSigningRequestModule:
def __init__(self, module, module_backend): def __init__(
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
) -> None:
self.check_mode = module.check_mode self.check_mode = module.check_mode
self.module = module self.module = module
self.module_backend = module_backend self.module_backend = module_backend
@@ -145,13 +156,13 @@ class CertificateSigningRequestModule:
if module.params["content"] is not None: if module.params["content"] is not None:
self.module_backend.set_existing(module.params["content"].encode("utf-8")) self.module_backend.set_existing(module.params["content"].encode("utf-8"))
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Generate the certificate signing request.""" """Generate the certificate signing request."""
if self.module_backend.needs_regeneration(): if self.module_backend.needs_regeneration():
self.module_backend.generate_csr() self.module_backend.generate_csr()
self.changed = True self.changed = True
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = self.module_backend.dump(include_csr=True) result = self.module_backend.dump(include_csr=True)
result.update( result.update(
@@ -162,7 +173,7 @@ class CertificateSigningRequestModule:
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = get_csr_argument_spec() argument_spec = get_csr_argument_spec()
argument_spec.argument_spec.update( argument_spec.argument_spec.update(
dict( dict(

View File

@@ -132,6 +132,7 @@ import abc
import os import os
import re import re
import tempfile import tempfile
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
@@ -173,23 +174,23 @@ class DHParameterError(Exception):
class DHParameterBase: class DHParameterBase:
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
self.state = module.params["state"] self.state: t.Literal["absent", "present"] = module.params["state"]
self.path = module.params["path"] self.path: str = module.params["path"]
self.size = module.params["size"] self.size: int = module.params["size"]
self.force = module.params["force"] self.force: bool = module.params["force"]
self.changed = False self.changed = False
self.return_content = module.params["return_content"] self.return_content: bool = module.params["return_content"]
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
@abc.abstractmethod @abc.abstractmethod
def _do_generate(self, module): def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params.""" """Actually generate the DH params."""
pass pass
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Generate DH params.""" """Generate DH params."""
changed = False changed = False
@@ -206,7 +207,7 @@ class DHParameterBase:
self.changed = changed self.changed = changed
def remove(self, module): def remove(self, module: AnsibleModule) -> None:
if self.backup: if self.backup:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
try: try:
@@ -215,28 +216,27 @@ class DHParameterBase:
except OSError as exc: except OSError as exc:
module.fail_json(msg=str(exc)) module.fail_json(msg=str(exc))
def check(self, module): def check(self, module: AnsibleModule) -> bool:
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
if self.force: if self.force:
return False return False
return self._check_params_valid(module) and self._check_fs_attributes(module) return self._check_params_valid(module) and self._check_fs_attributes(module)
@abc.abstractmethod @abc.abstractmethod
def _check_params_valid(self, module): def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state""" """Check if the params are in the correct state"""
pass
def _check_fs_attributes(self, module): def _check_fs_attributes(self, module: AnsibleModule) -> bool:
"""Checks (and changes if not in check mode!) fs attributes""" """Checks (and changes if not in check mode!) fs attributes"""
file_args = module.load_file_common_arguments(module.params) file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args["path"]): if module.check_file_absent_if_check_mode(file_args["path"]):
return False return False
return not module.set_fs_attributes_if_different(file_args, False) return not module.set_fs_attributes_if_different(file_args, False)
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = { result: dict[str, t.Any] = {
"size": self.size, "size": self.size,
"filename": self.path, "filename": self.path,
"changed": self.changed, "changed": self.changed,
@@ -252,25 +252,24 @@ class DHParameterBase:
class DHParameterAbsent(DHParameterBase): class DHParameterAbsent(DHParameterBase):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(DHParameterAbsent, self).__init__(module) super(DHParameterAbsent, self).__init__(module)
def _do_generate(self, module): def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params.""" """Actually generate the DH params."""
pass
def _check_params_valid(self, module): def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state""" """Check if the params are in the correct state"""
pass return False
class DHParameterOpenSSL(DHParameterBase): class DHParameterOpenSSL(DHParameterBase):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(DHParameterOpenSSL, self).__init__(module) super(DHParameterOpenSSL, self).__init__(module)
self.openssl_bin = module.get_bin_path("openssl", True) self.openssl_bin = module.get_bin_path("openssl", True)
def _do_generate(self, module): def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params.""" """Actually generate the DH params."""
# create a tempfile # create a tempfile
fd, tmpsrc = tempfile.mkstemp() fd, tmpsrc = tempfile.mkstemp()
@@ -288,7 +287,7 @@ class DHParameterOpenSSL(DHParameterBase):
except Exception as e: except Exception as e:
module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}") module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}")
def _check_params_valid(self, module): def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state""" """Check if the params are in the correct state"""
command = [ command = [
self.openssl_bin, self.openssl_bin,
@@ -321,10 +320,10 @@ class DHParameterOpenSSL(DHParameterBase):
class DHParameterCryptography(DHParameterBase): class DHParameterCryptography(DHParameterBase):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(DHParameterCryptography, self).__init__(module) super(DHParameterCryptography, self).__init__(module)
def _do_generate(self, module): def _do_generate(self, module: AnsibleModule) -> None:
"""Actually generate the DH params.""" """Actually generate the DH params."""
# Generate parameters # Generate parameters
params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters( params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters(
@@ -341,7 +340,7 @@ class DHParameterCryptography(DHParameterBase):
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
write_file(module, result) write_file(module, result)
def _check_params_valid(self, module): def _check_params_valid(self, module: AnsibleModule) -> bool:
"""Check if the params are in the correct state""" """Check if the params are in the correct state"""
# Load parameters # Load parameters
try: try:
@@ -357,7 +356,7 @@ class DHParameterCryptography(DHParameterBase):
return bits == self.size return bits == self.size
def main(): def main() -> t.NoReturn:
"""Main function""" """Main function"""
module = AnsibleModule( module = AnsibleModule(
@@ -383,6 +382,7 @@ def main():
msg=f"The directory '{base_dir}' does not exist or the file is not a directory", msg=f"The directory '{base_dir}' does not exist or the file is not a directory",
) )
dhparam: DHParameterOpenSSL | DHParameterCryptography | DHParameterAbsent
if module.params["state"] == "present": if module.params["state"] == "present":
backend = module.params["select_crypto_backend"] backend = module.params["select_crypto_backend"]
if backend == "auto": if backend == "auto":

View File

@@ -280,6 +280,7 @@ import itertools
import os import os
import stat import stat
import traceback import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.common.text.converters import to_bytes, to_native
@@ -296,7 +297,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
OpenSSLObject, OpenSSLObject,
load_certificate, load_certificate,
load_privatekey, load_certificate_issuer_privatekey,
) )
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -320,6 +321,7 @@ except ImportError:
CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None
try: try:
import cryptography.x509
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization.pkcs12 import PBES from cryptography.hazmat.primitives.serialization.pkcs12 import PBES
@@ -333,8 +335,22 @@ except Exception:
else: else:
CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True
if t.TYPE_CHECKING:
from ..module_utils.crypto.cryptography_support import (
CertificateIssuerPrivateKeyTypes,
)
def load_certificate_set(filename): PKCS12 = tuple[
t.Union[CertificateIssuerPrivateKeyTypes, None],
t.Union[cryptography.x509.Certificate, None],
list[cryptography.x509.Certificate],
t.Union[bytes, None],
]
def load_certificate_set(
filename: str | os.PathLike,
) -> list[cryptography.x509.Certificate]:
""" """
Load list of concatenated PEM files, and return a list of parsed certificates. Load list of concatenated PEM files, and return a list of parsed certificates.
""" """
@@ -351,70 +367,80 @@ class PkcsError(OpenSSLObjectError):
class Pkcs(OpenSSLObject): class Pkcs(OpenSSLObject):
def __init__(self, module, iter_size_default=2048): path: str
def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None:
super(Pkcs, self).__init__( super(Pkcs, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
module.params["force"], module.params["force"],
module.check_mode, module.check_mode,
) )
self.action = module.params["action"] self.action: t.Literal["export", "parse"] = module.params["action"]
self.other_certificates = module.params["other_certificates"] self.other_certificates: list[cryptography.x509.Certificate] = []
self.other_certificates_parse_all = module.params[ self.other_certificates_str: list[str] | None = module.params[
"other_certificates"
]
self.other_certificates_parse_all: bool = module.params[
"other_certificates_parse_all" "other_certificates_parse_all"
] ]
self.other_certificates_content = module.params["other_certificates_content"] self.other_certificates_content: list[str] | None = module.params[
self.certificate_path = module.params["certificate_path"] "other_certificates_content"
self.certificate_content = module.params["certificate_content"] ]
self.friendly_name = module.params["friendly_name"] self.certificate_path: str | None = module.params["certificate_path"]
self.iter_size = module.params["iter_size"] or iter_size_default certificate_content: str | None = module.params["certificate_content"]
self.maciter_size = module.params["maciter_size"] or 1 self.friendly_name: str | None = module.params["friendly_name"]
self.encryption_level = module.params["encryption_level"] self.iter_size: int = module.params["iter_size"] or iter_size_default
self.maciter_size: int = module.params["maciter_size"] or 1
self.encryption_level: t.Literal["auto", "compatibility2022"] = module.params[
"encryption_level"
]
self.passphrase = module.params["passphrase"] self.passphrase = module.params["passphrase"]
self.pkcs12 = None self.pkcs12: PKCS12 | None = None
self.privatekey_passphrase = module.params["privatekey_passphrase"] self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.privatekey_path = module.params["privatekey_path"] self.privatekey_path: str | None = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"] privatekey_content: str | None = module.params["privatekey_content"]
self.pkcs12_bytes = None self.pkcs12_bytes: bytes | None = None
self.return_content = module.params["return_content"] self.return_content: bool = module.params["return_content"]
self.src = module.params["src"] self.src: str | None = module.params["src"]
if module.params["mode"] is None: if module.params["mode"] is None:
module.params["mode"] = "0400" module.params["mode"] = "0400"
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
self.certificate_content: bytes | None = None
if self.certificate_path is not None: if self.certificate_path is not None:
try: try:
with open(self.certificate_path, "rb") as fh: with open(self.certificate_path, "rb") as fh:
self.certificate_content = fh.read() self.certificate_content = fh.read()
except (IOError, OSError) as exc: except (IOError, OSError) as exc:
raise PkcsError(exc) raise PkcsError(exc)
elif self.certificate_content is not None: elif certificate_content is not None:
self.certificate_content = to_bytes(self.certificate_content) self.certificate_content = to_bytes(certificate_content)
self.privatekey_content: bytes | None = None
if self.privatekey_path is not None: if self.privatekey_path is not None:
try: try:
with open(self.privatekey_path, "rb") as fh: with open(self.privatekey_path, "rb") as fh:
self.privatekey_content = fh.read() self.privatekey_content = fh.read()
except (IOError, OSError) as exc: except (IOError, OSError) as exc:
raise PkcsError(exc) raise PkcsError(exc)
elif self.privatekey_content is not None: elif privatekey_content is not None:
self.privatekey_content = to_bytes(self.privatekey_content) self.privatekey_content = to_bytes(privatekey_content)
if self.other_certificates: if self.other_certificates_str:
if self.other_certificates_parse_all: if self.other_certificates_parse_all:
filenames = list(self.other_certificates)
self.other_certificates = [] self.other_certificates = []
for other_cert_bundle in filenames: for other_cert_bundle in self.other_certificates_str:
self.other_certificates.extend( self.other_certificates.extend(
load_certificate_set(other_cert_bundle) load_certificate_set(other_cert_bundle)
) )
else: else:
self.other_certificates = [ self.other_certificates = [
load_certificate(other_cert) load_certificate(other_cert)
for other_cert in self.other_certificates for other_cert in self.other_certificates_str
] ]
elif self.other_certificates_content: elif self.other_certificates_content:
certs = self.other_certificates_content certs = self.other_certificates_content
@@ -430,40 +456,42 @@ class Pkcs(OpenSSLObject):
] ]
@abc.abstractmethod @abc.abstractmethod
def generate_bytes(self, module): def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive.""" """Generate PKCS#12 file archive."""
@abc.abstractmethod
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def parse_bytes(self, pkcs12_content): def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _dump_privatekey(self, pkcs12): def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _dump_certificate(self, pkcs12): def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def _dump_other_certificates(self, pkcs12): def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
pass pass
@abc.abstractmethod def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
def _get_friendly_name(self, pkcs12):
pass
def check(self, module, perms_required=True):
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
state_and_perms = super(Pkcs, self).check(module, perms_required) state_and_perms = super(Pkcs, self).check(module, perms_required)
def _check_pkey_passphrase(): def _check_pkey_passphrase() -> bool:
if self.privatekey_passphrase: if self.privatekey_passphrase:
try: try:
load_privatekey( load_certificate_issuer_privatekey(
None,
content=self.privatekey_content, content=self.privatekey_content,
passphrase=self.privatekey_passphrase, passphrase=self.privatekey_passphrase,
) )
@@ -476,6 +504,7 @@ class Pkcs(OpenSSLObject):
if os.path.exists(self.path) and module.params["action"] == "export": if os.path.exists(self.path) and module.params["action"] == "export":
self.generate_bytes(module) # ignore result self.generate_bytes(module) # ignore result
assert self.pkcs12 is not None
self.src = self.path self.src = self.path
try: try:
( (
@@ -524,7 +553,7 @@ class Pkcs(OpenSSLObject):
return False return False
elif ( elif (
module.params["action"] == "parse" module.params["action"] == "parse"
and os.path.exists(self.src) and os.path.exists(self.src or "")
and os.path.exists(self.path) and os.path.exists(self.path)
): ):
try: try:
@@ -548,10 +577,10 @@ class Pkcs(OpenSSLObject):
return _check_pkey_passphrase() return _check_pkey_passphrase()
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = { result: dict[str, t.Any] = {
"filename": self.path, "filename": self.path,
} }
if self.privatekey_path: if self.privatekey_path:
@@ -567,13 +596,20 @@ class Pkcs(OpenSSLObject):
return result return result
def remove(self, module): def remove(self, module: AnsibleModule) -> None:
if self.backup: if self.backup:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
super(Pkcs, self).remove(module) super(Pkcs, self).remove(module)
def parse(self): def parse(self) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
"""Read PKCS#12 file.""" """Read PKCS#12 file."""
if self.src is None:
raise AssertionError("Contract violation: src is None")
try: try:
with open(self.src, "rb") as pkcs12_fh: with open(self.src, "rb") as pkcs12_fh:
@@ -582,10 +618,13 @@ class Pkcs(OpenSSLObject):
except IOError as exc: except IOError as exc:
raise PkcsError(exc) raise PkcsError(exc)
def generate(self): def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this
pass pass
def write(self, module, content, mode=None): def write(
self, module: AnsibleModule, content: bytes, mode: int | str | None = None
) -> None:
"""Write the PKCS#12 file.""" """Write the PKCS#12 file."""
if self.backup: if self.backup:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
@@ -595,7 +634,7 @@ class Pkcs(OpenSSLObject):
class PkcsCryptography(Pkcs): class PkcsCryptography(Pkcs):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(PkcsCryptography, self).__init__(module, iter_size_default=50000) super(PkcsCryptography, self).__init__(module, iter_size_default=50000)
if ( if (
self.encryption_level == "compatibility2022" self.encryption_level == "compatibility2022"
@@ -607,13 +646,12 @@ class PkcsCryptography(Pkcs):
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR, exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
) )
def generate_bytes(self, module): def generate_bytes(self, module: AnsibleModule) -> bytes:
"""Generate PKCS#12 file archive.""" """Generate PKCS#12 file archive."""
pkey = None pkey = None
if self.privatekey_content: if self.privatekey_content:
try: try:
pkey = load_privatekey( pkey = load_certificate_issuer_privatekey(
None,
content=self.privatekey_content, content=self.privatekey_content,
passphrase=self.privatekey_passphrase, passphrase=self.privatekey_passphrase,
) )
@@ -631,6 +669,7 @@ class PkcsCryptography(Pkcs):
# Store fake object which can be used to retrieve the components back # Store fake object which can be used to retrieve the components back
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name) self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
encryption: serialization.KeySerializationEncryption
if not self.passphrase: if not self.passphrase:
encryption = serialization.NoEncryption() encryption = serialization.NoEncryption()
elif self.encryption_level == "compatibility2022": elif self.encryption_level == "compatibility2022":
@@ -654,7 +693,12 @@ class PkcsCryptography(Pkcs):
encryption, encryption,
) )
def parse_bytes(self, pkcs12_content): def parse_bytes(self, pkcs12_content: bytes) -> tuple[
bytes | None,
bytes | None,
list[bytes],
bytes | None,
]:
try: try:
private_key, certificate, additional_certificates, friendly_name = ( private_key, certificate, additional_certificates, friendly_name = (
parse_pkcs12(pkcs12_content, self.passphrase) parse_pkcs12(pkcs12_content, self.passphrase)
@@ -683,11 +727,7 @@ class PkcsCryptography(Pkcs):
except ValueError as exc: except ValueError as exc:
raise PkcsError(exc) raise PkcsError(exc)
# The following methods will get self.pkcs12 passed, which is computed as: def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
#
# self.pkcs12 = (pkey, cert, self.other_certificates, self.friendly_name)
def _dump_privatekey(self, pkcs12):
return ( return (
pkcs12[0].private_bytes( pkcs12[0].private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
@@ -698,27 +738,27 @@ class PkcsCryptography(Pkcs):
else None else None
) )
def _dump_certificate(self, pkcs12): def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
def _dump_other_certificates(self, pkcs12): def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
return [ return [
other_cert.public_bytes(serialization.Encoding.PEM) other_cert.public_bytes(serialization.Encoding.PEM)
for other_cert in pkcs12[2] for other_cert in pkcs12[2]
] ]
def _get_friendly_name(self, pkcs12): def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
return pkcs12[3] return pkcs12[3]
def select_backend(module): def select_backend(module: AnsibleModule) -> Pkcs:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )
return PkcsCryptography(module) return PkcsCryptography(module)
def main(): def main() -> t.NoReturn:
argument_spec = dict( argument_spec = dict(
action=dict(type="str", default="export", choices=["export", "parse"]), action=dict(type="str", default="export", choices=["export", "parse"]),
other_certificates=dict( other_certificates=dict(

View File

@@ -155,6 +155,7 @@ privatekey:
""" """
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
@@ -172,9 +173,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
PrivateKeyBackend,
)
class PrivateKeyModule(OpenSSLObject): class PrivateKeyModule(OpenSSLObject):
def __init__(self, module, module_backend): def __init__(
self, module: AnsibleModule, module_backend: PrivateKeyBackend
) -> None:
super(PrivateKeyModule, self).__init__( super(PrivateKeyModule, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
@@ -182,19 +192,19 @@ class PrivateKeyModule(OpenSSLObject):
module.check_mode, module.check_mode,
) )
self.module_backend = module_backend self.module_backend = module_backend
self.return_content = module.params["return_content"] self.return_content: bool = module.params["return_content"]
if self.force: if self.force:
module_backend.regenerate = "always" module_backend.regenerate = "always"
self.backup = module.params["backup"] self.backup: str | None = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
if module.params["mode"] is None: if module.params["mode"] is None:
module.params["mode"] = "0600" module.params["mode"] = "0600"
module_backend.set_existing(load_file_if_exists(self.path, module)) module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Generate a keypair.""" """Generate a keypair."""
if self.module_backend.needs_regeneration(): if self.module_backend.needs_regeneration():
@@ -228,13 +238,13 @@ class PrivateKeyModule(OpenSSLObject):
file_args, self.changed file_args, self.changed
) )
def remove(self, module): def remove(self, module: AnsibleModule) -> None:
self.module_backend.set_existing(None) self.module_backend.set_existing(None)
if self.backup and not self.check_mode: if self.backup and not self.check_mode:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
super(PrivateKeyModule, self).remove(module) super(PrivateKeyModule, self).remove(module)
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = self.module_backend.dump(include_key=self.return_content) result = self.module_backend.dump(include_key=self.return_content)
@@ -246,7 +256,7 @@ class PrivateKeyModule(OpenSSLObject):
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = get_privatekey_argument_spec() argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update( argument_spec.argument_spec.update(

View File

@@ -60,6 +60,7 @@ backup_file:
""" """
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
@@ -77,8 +78,17 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey_convert import (
PrivateKeyConvertBackend,
)
class PrivateKeyConvertModule(OpenSSLObject): class PrivateKeyConvertModule(OpenSSLObject):
def __init__(self, module, module_backend): def __init__(
self, module: AnsibleModule, module_backend: PrivateKeyConvertBackend
) -> None:
super(PrivateKeyConvertModule, self).__init__( super(PrivateKeyConvertModule, self).__init__(
module.params["dest_path"], module.params["dest_path"],
"present", "present",
@@ -87,8 +97,8 @@ class PrivateKeyConvertModule(OpenSSLObject):
) )
self.module_backend = module_backend self.module_backend = module_backend
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
module.params["path"] = module.params["dest_path"] module.params["path"] = module.params["dest_path"]
if module.params["mode"] is None: if module.params["mode"] is None:
@@ -96,12 +106,14 @@ class PrivateKeyConvertModule(OpenSSLObject):
module_backend.set_existing_destination(load_file_if_exists(self.path, module)) module_backend.set_existing_destination(load_file_if_exists(self.path, module))
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Do conversion.""" """Do conversion."""
if self.module_backend.needs_conversion(): if self.module_backend.needs_conversion():
# Convert # Convert
privatekey_data = self.module_backend.get_private_key_data() privatekey_data = self.module_backend.get_private_key_data()
if privatekey_data is None:
raise AssertionError("Contract violation: privatekey_data is None")
if not self.check_mode: if not self.check_mode:
if self.backup: if self.backup:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
@@ -116,7 +128,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
file_args, self.changed file_args, self.changed
) )
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = self.module_backend.dump() result = self.module_backend.dump()
@@ -127,7 +139,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = get_privatekey_argument_spec() argument_spec = get_privatekey_argument_spec()
argument_spec.argument_spec.update( argument_spec.argument_spec.update(

View File

@@ -200,6 +200,7 @@ private_data:
type: dict type: dict
""" """
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -212,7 +213,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
path=dict(type="path"), path=dict(type="path"),
@@ -243,7 +244,7 @@ def main():
data = f.read() data = f.read()
except (IOError, OSError) as e: except (IOError, OSError) as e:
module.fail_json( module.fail_json(
msg=f"Error while reading private key file from disk: {e}", **result msg=f"Error while reading private key file from disk: {e}", **result # type: ignore
) )
result["can_load_key"] = True result["can_load_key"] = True
@@ -261,10 +262,10 @@ def main():
module.exit_json(**result) module.exit_json(**result)
except PrivateKeyParseError as exc: except PrivateKeyParseError as exc:
result.update(exc.result) result.update(exc.result)
module.fail_json(msg=exc.error_message, **result) module.fail_json(msg=exc.error_message, **result) # type: ignore
except PrivateKeyConsistencyError as exc: except PrivateKeyConsistencyError as exc:
result.update(exc.result) result.update(exc.result)
module.fail_json(msg=exc.error_message, **result) module.fail_json(msg=exc.error_message, **result) # type: ignore
except OpenSSLObjectError as exc: except OpenSSLObjectError as exc:
module.fail_json(msg=str(exc)) module.fail_json(msg=str(exc))

View File

@@ -186,6 +186,7 @@ publickey:
""" """
import os import os
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -218,6 +219,12 @@ try:
except ImportError: except ImportError:
pass pass
if t.TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)
class PublicKeyError(OpenSSLObjectError): class PublicKeyError(OpenSSLObjectError):
pass pass
@@ -225,7 +232,7 @@ class PublicKeyError(OpenSSLObjectError):
class PublicKey(OpenSSLObject): class PublicKey(OpenSSLObject):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(PublicKey, self).__init__( super(PublicKey, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
@@ -233,27 +240,29 @@ class PublicKey(OpenSSLObject):
module.check_mode, module.check_mode,
) )
self.module = module self.module = module
self.format = module.params["format"] self.format: t.Literal["OpenSSH", "PEM"] = module.params["format"]
self.privatekey_path = module.params["privatekey_path"] self.privatekey_path: str | None = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"] privatekey_content: str | None = module.params["privatekey_content"]
if self.privatekey_content is not None: if privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8") self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"] else:
self.privatekey = None self.privatekey_content = None
self.publickey_bytes = None self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.return_content = module.params["return_content"] self.privatekey: PrivateKeyTypes | None = None
self.fingerprint = {} self.publickey_bytes: bytes | None = None
self.return_content: bool = module.params["return_content"]
self.fingerprint: dict[str, str] = {}
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
self.diff_before = self._get_info(None) self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None) self.diff_after = self._get_info(None)
def _get_info(self, data): def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None: if data is None:
return dict() return {}
result = dict(can_parse_key=False) result = {"can_parse_key": False}
try: try:
result.update( result.update(
get_publickey_info( get_publickey_info(
@@ -267,7 +276,7 @@ class PublicKey(OpenSSLObject):
pass pass
return result return result
def _create_publickey(self, module): def _create_publickey(self, module: AnsibleModule) -> bytes:
self.privatekey = load_privatekey( self.privatekey = load_privatekey(
path=self.privatekey_path, path=self.privatekey_path,
content=self.privatekey_content, content=self.privatekey_content,
@@ -284,10 +293,12 @@ class PublicKey(OpenSSLObject):
crypto_serialization.PublicFormat.SubjectPublicKeyInfo, crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
) )
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Generate the public key.""" """Generate the public key."""
if self.privatekey_content is None and not os.path.exists(self.privatekey_path): if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
raise PublicKeyError( raise PublicKeyError(
f"The private key {self.privatekey_path} does not exist" f"The private key {self.privatekey_path} does not exist"
) )
@@ -320,17 +331,18 @@ class PublicKey(OpenSSLObject):
elif module.set_fs_attributes_if_different(file_args, False): elif module.set_fs_attributes_if_different(file_args, False):
self.changed = True self.changed = True
def check(self, module, perms_required=True): def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
state_and_perms = super(PublicKey, self).check(module, perms_required) state_and_perms = super(PublicKey, self).check(module, perms_required)
def _check_privatekey(): def _check_privatekey() -> bool:
if self.privatekey_content is None and not os.path.exists( if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path self.privatekey_path
): ):
return False return False
current_publickey: PublicKeyTypes
try: try:
with open(self.path, "rb") as public_key_fh: with open(self.path, "rb") as public_key_fh:
publickey_content = public_key_fh.read() publickey_content = public_key_fh.read()
@@ -369,15 +381,15 @@ class PublicKey(OpenSSLObject):
return _check_privatekey() return _check_privatekey()
def remove(self, module): def remove(self, module: AnsibleModule) -> None:
if self.backup: if self.backup:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
super(PublicKey, self).remove(module) super(PublicKey, self).remove(module)
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = { result: dict[str, t.Any] = {
"privatekey": self.privatekey_path, "privatekey": self.privatekey_path,
"filename": self.path, "filename": self.path,
"format": self.format, "format": self.format,
@@ -403,7 +415,7 @@ class PublicKey(OpenSSLObject):
return result return result
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(

View File

@@ -152,6 +152,7 @@ public_data:
returned: When RV(type=DSA) or RV(type=ECC) returned: When RV(type=DSA) or RV(type=ECC)
""" """
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -163,7 +164,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
path=dict(type="path"), path=dict(type="path"),
@@ -191,7 +192,7 @@ def main():
data = f.read() data = f.read()
except (IOError, OSError) as e: except (IOError, OSError) as e:
module.fail_json( module.fail_json(
msg=f"Error while reading public key file from disk: {e}", **result msg=f"Error while reading public key file from disk: {e}", **result # type: ignore
) )
module_backend = select_backend(module, data) module_backend = select_backend(module, data)
@@ -201,7 +202,7 @@ def main():
module.exit_json(**result) module.exit_json(**result)
except PublicKeyParseError as exc: except PublicKeyParseError as exc:
result.update(exc.result) result.update(exc.result)
module.fail_json(msg=exc.error_message, **result) module.fail_json(msg=exc.error_message, **result) # type: ignore
except OpenSSLObjectError as exc: except OpenSSLObjectError as exc:
module.fail_json(msg=str(exc)) module.fail_json(msg=str(exc))

View File

@@ -99,6 +99,7 @@ signature:
import base64 import base64
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -132,7 +133,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
class SignatureBase(OpenSSLObject): class SignatureBase(OpenSSLObject):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(SignatureBase, self).__init__( super(SignatureBase, self).__init__(
path=module.params["path"], path=module.params["path"],
state="present", state="present",
@@ -140,32 +141,35 @@ class SignatureBase(OpenSSLObject):
check_mode=module.check_mode, check_mode=module.check_mode,
) )
self.privatekey_path = module.params["privatekey_path"] self.module = module
self.privatekey_content = module.params["privatekey_content"] self.privatekey_path: str | None = module.params["privatekey_path"]
if self.privatekey_content is not None: privatekey_content: str | None = module.params["privatekey_content"]
self.privatekey_content = self.privatekey_content.encode("utf-8") if privatekey_content is not None:
self.privatekey_passphrase = module.params["privatekey_passphrase"] self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
def generate(self): def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this # Empty method because OpenSSLObject wants this
pass pass
def dump(self): def dump(self) -> dict[str, t.Any]:
# Empty method because OpenSSLObject wants this # Empty method because OpenSSLObject wants this
pass return {}
# Implementation with using cryptography # Implementation with using cryptography
class SignatureCryptography(SignatureBase): class SignatureCryptography(SignatureBase):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(SignatureCryptography, self).__init__(module) super(SignatureCryptography, self).__init__(module)
def run(self): def run(self) -> dict[str, t.Any]:
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15() _padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
_hash = cryptography.hazmat.primitives.hashes.SHA256() _hash = cryptography.hazmat.primitives.hashes.SHA256()
result = dict() result: dict[str, t.Any] = {}
try: try:
with open(self.path, "rb") as f: with open(self.path, "rb") as f:
@@ -223,7 +227,7 @@ class SignatureCryptography(SignatureBase):
raise OpenSSLObjectError(e) raise OpenSSLObjectError(e)
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
privatekey_path=dict(type="path"), privatekey_path=dict(type="path"),

View File

@@ -88,6 +88,7 @@ valid:
import base64 import base64
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import ( from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION, COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
@@ -121,7 +122,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
class SignatureInfoBase(OpenSSLObject): class SignatureInfoBase(OpenSSLObject):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(SignatureInfoBase, self).__init__( super(SignatureInfoBase, self).__init__(
path=module.params["path"], path=module.params["path"],
state="present", state="present",
@@ -129,32 +130,35 @@ class SignatureInfoBase(OpenSSLObject):
check_mode=module.check_mode, check_mode=module.check_mode,
) )
self.signature = module.params["signature"] self.module = module
self.certificate_path = module.params["certificate_path"] self.signature: str = module.params["signature"]
self.certificate_content = module.params["certificate_content"] self.certificate_path: str | None = module.params["certificate_path"]
if self.certificate_content is not None: certificate_content: str | None = module.params["certificate_content"]
self.certificate_content = self.certificate_content.encode("utf-8") if certificate_content is not None:
self.certificate_content: bytes | None = certificate_content.encode("utf-8")
else:
self.certificate_content = None
def generate(self): def generate(self, module: AnsibleModule) -> None:
# Empty method because OpenSSLObject wants this # Empty method because OpenSSLObject wants this
pass pass
def dump(self): def dump(self) -> dict[str, t.Any]:
# Empty method because OpenSSLObject wants this # Empty method because OpenSSLObject wants this
pass return {}
# Implementation with using cryptography # Implementation with using cryptography
class SignatureInfoCryptography(SignatureInfoBase): class SignatureInfoCryptography(SignatureInfoBase):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(SignatureInfoCryptography, self).__init__(module) super(SignatureInfoCryptography, self).__init__(module)
def run(self): def run(self) -> dict[str, t.Any]:
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15() _padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
_hash = cryptography.hazmat.primitives.hashes.SHA256() _hash = cryptography.hazmat.primitives.hashes.SHA256()
result = dict() result: dict[str, t.Any] = {}
try: try:
with open(self.path, "rb") as f: with open(self.path, "rb") as f:
@@ -228,7 +232,7 @@ class SignatureInfoCryptography(SignatureInfoBase):
raise OpenSSLObjectError(e) raise OpenSSLObjectError(e)
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
certificate_path=dict(type="path"), certificate_path=dict(type="path"),

View File

@@ -224,6 +224,7 @@ certificate:
import os import os
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
@@ -257,8 +258,15 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
CertificateBackend,
)
class CertificateAbsent(OpenSSLObject): class CertificateAbsent(OpenSSLObject):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(CertificateAbsent, self).__init__( super(CertificateAbsent, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
@@ -266,19 +274,19 @@ class CertificateAbsent(OpenSSLObject):
module.check_mode, module.check_mode,
) )
self.module = module self.module = module
self.return_content = module.params["return_content"] self.return_content: bool = module.params["return_content"]
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
pass pass
def remove(self, module): def remove(self, module: AnsibleModule) -> None:
if self.backup: if self.backup:
self.backup_file = module.backup_local(self.path) self.backup_file = module.backup_local(self.path)
super(CertificateAbsent, self).remove(module) super(CertificateAbsent, self).remove(module)
def dump(self, check_mode=False): def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = { result = {
"changed": self.changed, "changed": self.changed,
"filename": self.path, "filename": self.path,
@@ -296,7 +304,7 @@ class CertificateAbsent(OpenSSLObject):
class GenericCertificate(OpenSSLObject): class GenericCertificate(OpenSSLObject):
"""Retrieve a certificate using the given module backend.""" """Retrieve a certificate using the given module backend."""
def __init__(self, module, module_backend): def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
super(GenericCertificate, self).__init__( super(GenericCertificate, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
@@ -311,7 +319,7 @@ class GenericCertificate(OpenSSLObject):
self.module_backend = module_backend self.module_backend = module_backend
self.module_backend.set_existing(load_file_if_exists(self.path, module)) self.module_backend.set_existing(load_file_if_exists(self.path, module))
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
if self.module_backend.needs_regeneration(): if self.module_backend.needs_regeneration():
if not self.check_mode: if not self.check_mode:
self.module_backend.generate_certificate() self.module_backend.generate_certificate()
@@ -329,14 +337,14 @@ class GenericCertificate(OpenSSLObject):
file_args, self.changed file_args, self.changed
) )
def check(self, module, perms_required=True): def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
return ( return (
super(GenericCertificate, self).check(module, perms_required) super(GenericCertificate, self).check(module, perms_required)
and not self.module_backend.needs_regeneration() and not self.module_backend.needs_regeneration()
) )
def dump(self, check_mode=False): def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = self.module_backend.dump(include_certificate=self.return_content) result = self.module_backend.dump(include_certificate=self.return_content)
result.update( result.update(
{ {
@@ -349,7 +357,7 @@ class GenericCertificate(OpenSSLObject):
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = get_certificate_argument_spec() argument_spec = get_certificate_argument_spec()
add_acme_provider_to_argument_spec(argument_spec) add_acme_provider_to_argument_spec(argument_spec)
add_entrust_provider_to_argument_spec(argument_spec) add_entrust_provider_to_argument_spec(argument_spec)
@@ -363,13 +371,14 @@ def main():
return_content=dict(type="bool", default=False), return_content=dict(type="bool", default=False),
) )
) )
argument_spec.required_if.append(["state", "present", ["provider"]]) argument_spec.required_if.append(("state", "present", ["provider"]))
module = argument_spec.create_ansible_module( module = argument_spec.create_ansible_module(
add_file_common_args=True, add_file_common_args=True,
supports_check_mode=True, supports_check_mode=True,
) )
try: try:
certificate: GenericCertificate | CertificateAbsent
if module.params["state"] == "absent": if module.params["state"] == "absent":
certificate = CertificateAbsent(module) certificate = CertificateAbsent(module)
@@ -389,7 +398,13 @@ def main():
) )
provider = module.params["provider"] provider = module.params["provider"]
provider_map = { provider_map: dict[
str,
type[AcmeCertificateProvider]
| type[EntrustCertificateProvider]
| type[OwnCACertificateProvider]
| type[SelfSignedCertificateProvider],
] = {
"acme": AcmeCertificateProvider, "acme": AcmeCertificateProvider,
"entrust": EntrustCertificateProvider, "entrust": EntrustCertificateProvider,
"ownca": OwnCACertificateProvider, "ownca": OwnCACertificateProvider,

View File

@@ -106,6 +106,7 @@ backup_file:
import base64 import base64
import os import os
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -142,8 +143,12 @@ except ImportError:
pass pass
def parse_certificate(input, strict=False): def parse_certificate(
input_format = "pem" if identify_pem_format(input) else "der" input: bytes, strict: bool = False
) -> tuple[bytes, t.Literal["pem", "der"], str | None]:
input_format: t.Literal["pem", "der"] = (
"pem" if identify_pem_format(input) else "der"
)
if input_format == "pem": if input_format == "pem":
pems = split_pem_list(to_text(input)) pems = split_pem_list(to_text(input))
if len(pems) > 1 and strict: if len(pems) > 1 and strict:
@@ -162,7 +167,7 @@ def parse_certificate(input, strict=False):
class X509CertificateConvertModule(OpenSSLObject): class X509CertificateConvertModule(OpenSSLObject):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(X509CertificateConvertModule, self).__init__( super(X509CertificateConvertModule, self).__init__(
module.params["dest_path"], module.params["dest_path"],
"present", "present",
@@ -170,9 +175,9 @@ class X509CertificateConvertModule(OpenSSLObject):
module.check_mode, module.check_mode,
) )
self.src_path = module.params["src_path"] self.src_path: str | None = module.params["src_path"]
self.src_content = module.params["src_content"] self.src_content: str | None = module.params["src_content"]
self.src_content_base64 = module.params["src_content_base64"] self.src_content_base64: bool = module.params["src_content_base64"]
if self.src_content is not None: if self.src_content is not None:
self.input = to_bytes(self.src_content) self.input = to_bytes(self.src_content)
if self.src_content_base64: if self.src_content_base64:
@@ -181,6 +186,8 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception as exc: except Exception as exc:
module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}") module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}")
else: else:
if self.src_path is None:
module.fail_json(msg="One of src_path and src_content must be provided")
try: try:
with open(self.src_path, "rb") as f: with open(self.src_path, "rb") as f:
self.input = f.read() self.input = f.read()
@@ -189,8 +196,8 @@ class X509CertificateConvertModule(OpenSSLObject):
msg=f"Failure while reading file {self.src_path}: {exc}" msg=f"Failure while reading file {self.src_path}: {exc}"
) )
self.format = module.params["format"] self.format: t.Literal["pem", "der"] = module.params["format"]
self.strict = module.params["strict"] self.strict: bool = module.params["strict"]
self.wanted_pem_type = "CERTIFICATE" self.wanted_pem_type = "CERTIFICATE"
try: try:
@@ -203,8 +210,8 @@ class X509CertificateConvertModule(OpenSSLObject):
if module.params["verify_cert_parsable"]: if module.params["verify_cert_parsable"]:
self.verify_cert_parsable(module) self.verify_cert_parsable(module)
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
module.params["path"] = self.path module.params["path"] = self.path
@@ -221,7 +228,7 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception: except Exception:
pass pass
def verify_cert_parsable(self, module): def verify_cert_parsable(self, module: AnsibleModule) -> None:
assert_required_cryptography_version( assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
) )
@@ -230,7 +237,7 @@ class X509CertificateConvertModule(OpenSSLObject):
except Exception as exc: except Exception as exc:
module.fail_json(msg=f"Error while parsing certificate: {exc}") module.fail_json(msg=f"Error while parsing certificate: {exc}")
def needs_conversion(self): def needs_conversion(self) -> bool:
if self.dest_content is None or self.dest_content_format is None: if self.dest_content is None or self.dest_content_format is None:
return True return True
if self.dest_content_format != self.format: if self.dest_content_format != self.format:
@@ -241,7 +248,7 @@ class X509CertificateConvertModule(OpenSSLObject):
return True return True
return False return False
def get_dest_certificate(self): def get_dest_certificate(self) -> bytes:
if self.format == "der": if self.format == "der":
return self.input return self.input
data = to_bytes(base64.b64encode(self.input)) data = to_bytes(base64.b64encode(self.input))
@@ -250,7 +257,7 @@ class X509CertificateConvertModule(OpenSSLObject):
lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n")) lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n"))
return b"\n".join(lines) return b"\n".join(lines)
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
"""Do conversion.""" """Do conversion."""
if self.needs_conversion(): if self.needs_conversion():
# Convert # Convert
@@ -269,18 +276,18 @@ class X509CertificateConvertModule(OpenSSLObject):
file_args, self.changed file_args, self.changed
) )
def dump(self): def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary.""" """Serialize the object into a dictionary."""
result = dict( result: dict[str, t.Any] = {
changed=self.changed, "changed": self.changed,
) }
if self.backup_file: if self.backup_file:
result["backup_file"] = self.backup_file result["backup_file"] = self.backup_file
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = dict( argument_spec = dict(
src_path=dict(type="path"), src_path=dict(type="path"),
src_content=dict(type="str"), src_content=dict(type="str"),

View File

@@ -390,8 +390,10 @@ issuer_uri:
version_added: 2.9.0 version_added: 2.9.0
""" """
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
) )
@@ -406,7 +408,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
path=dict(type="path"), path=dict(type="path"),
@@ -424,18 +426,22 @@ def main():
supports_check_mode=True, supports_check_mode=True,
) )
if module.params["content"] is not None: content: str | None = module.params["content"]
data = module.params["content"].encode("utf-8") path: str | None = module.params["path"]
if content is not None:
data = content.encode("utf-8")
else: else:
if path is None:
module.fail_json(msg="One of path and content must be provided")
try: try:
with open(module.params["path"], "rb") as f: with open(path, "rb") as f:
data = f.read() data = f.read()
except (IOError, OSError) as e: except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading certificate file from disk: {e}") module.fail_json(msg=f"Error while reading certificate file from disk: {e}")
module_backend = select_backend(module, data) module_backend = select_backend(module, data)
valid_at = module.params["valid_at"] valid_at: dict[str, t.Any] = module.params["valid_at"]
if valid_at: if valid_at:
for k, v in valid_at.items(): for k, v in valid_at.items():
if not isinstance(v, (str, bytes)): if not isinstance(v, (str, bytes)):
@@ -443,13 +449,11 @@ def main():
msg=f"The value for valid_at.{k} must be of type string (got {type(v)})" msg=f"The value for valid_at.{k} must be of type string (got {type(v)})"
) )
valid_at[k] = get_relative_time_option( valid_at[k] = get_relative_time_option(
v, f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE to_text(v), f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
) )
try: try:
result = module_backend.get_info( result = module_backend.get_info(der_support_enabled=content is None)
der_support_enabled=module.params["content"] is None
)
not_before = module_backend.get_not_before() not_before = module_backend.get_not_before()
not_after = module_backend.get_not_after() not_after = module_backend.get_not_after()

View File

@@ -118,6 +118,8 @@ certificate:
type: str type: str
""" """
import typing as t
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
OpenSSLObjectError, OpenSSLObjectError,
) )
@@ -139,23 +141,31 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
) )
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
CertificateBackend,
)
class GenericCertificate: class GenericCertificate:
"""Retrieve a certificate using the given module backend.""" """Retrieve a certificate using the given module backend."""
def __init__(self, module, module_backend): def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
self.check_mode = module.check_mode self.check_mode = module.check_mode
self.module = module self.module = module
self.module_backend = module_backend self.module_backend = module_backend
self.changed = False self.changed = False
if module.params["content"] is not None: content: str | None = module.params["content"]
self.module_backend.set_existing(module.params["content"].encode("utf-8")) if content is not None:
self.module_backend.set_existing(content.encode("utf-8"))
def generate(self, module): def generate(self, module: AnsibleModule) -> None:
if self.module_backend.needs_regeneration(): if self.module_backend.needs_regeneration():
self.module_backend.generate_certificate() self.module_backend.generate_certificate()
self.changed = True self.changed = True
def dump(self, check_mode=False): def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = self.module_backend.dump(include_certificate=True) result = self.module_backend.dump(include_certificate=True)
result.update( result.update(
{ {
@@ -165,7 +175,7 @@ class GenericCertificate:
return result return result
def main(): def main() -> t.NoReturn:
argument_spec = get_certificate_argument_spec() argument_spec = get_certificate_argument_spec()
argument_spec.argument_spec["provider"]["required"] = True argument_spec.argument_spec["provider"]["required"] = True
add_entrust_provider_to_argument_spec(argument_spec) add_entrust_provider_to_argument_spec(argument_spec)
@@ -182,7 +192,12 @@ def main():
try: try:
provider = module.params["provider"] provider = module.params["provider"]
provider_map = { provider_map: dict[
str,
type[EntrustCertificateProvider]
| type[OwnCACertificateProvider]
| type[SelfSignedCertificateProvider],
] = {
"entrust": EntrustCertificateProvider, "entrust": EntrustCertificateProvider,
"ownca": OwnCACertificateProvider, "ownca": OwnCACertificateProvider,
"selfsigned": SelfSignedCertificateProvider, "selfsigned": SelfSignedCertificateProvider,

View File

@@ -425,6 +425,7 @@ crl:
import base64 import base64
import os import os
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@@ -463,7 +464,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
OpenSSLObject, OpenSSLObject,
load_certificate, load_certificate,
load_privatekey, load_certificate_issuer_privatekey,
parse_name_field, parse_name_field,
parse_ordered_name_field, parse_ordered_name_field,
select_message_digest, select_message_digest,
@@ -495,6 +496,9 @@ try:
except ImportError: except ImportError:
pass pass
if t.TYPE_CHECKING:
import datetime
class CRLError(OpenSSLObjectError): class CRLError(OpenSSLObjectError):
pass pass
@@ -502,7 +506,7 @@ class CRLError(OpenSSLObjectError):
class CRL(OpenSSLObject): class CRL(OpenSSLObject):
def __init__(self, module): def __init__(self, module: AnsibleModule) -> None:
super(CRL, self).__init__( super(CRL, self).__init__(
module.params["path"], module.params["path"],
module.params["state"], module.params["state"],
@@ -510,53 +514,69 @@ class CRL(OpenSSLObject):
module.check_mode, module.check_mode,
) )
self.format = module.params["format"] self.format: t.Literal["pem", "der"] = module.params["format"]
self.update = module.params["crl_mode"] == "update" self.update: bool = module.params["crl_mode"] == "update"
self.ignore_timestamps = module.params["ignore_timestamps"] self.ignore_timestamps: bool = module.params["ignore_timestamps"]
self.return_content = module.params["return_content"] self.return_content: bool = module.params["return_content"]
self.name_encoding = module.params["name_encoding"] self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params[
self.serial_numbers_format = module.params["serial_numbers"] "name_encoding"
self.crl_content = None ]
self.serial_numbers_format: t.Literal["integer", "hex-octets"] = module.params[
"serial_numbers"
]
self.crl_content: bytes | None = None
self.privatekey_path = module.params["privatekey_path"] self.privatekey_path: str | None = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"] privatekey_content: str | None = module.params["privatekey_content"]
if self.privatekey_content is not None: if privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode("utf-8") self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"] else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
try: try:
if module.params["issuer_ordered"]: issuer_ordered: list[dict[str, t.Any]] | None = module.params[
"issuer_ordered"
]
issuer: dict[str, list[str | bytes] | bytes | str] | None = module.params[
"issuer"
]
if issuer_ordered:
self.issuer_ordered = True self.issuer_ordered = True
self.issuer = parse_ordered_name_field( self.issuer = parse_ordered_name_field(issuer_ordered, "issuer_ordered")
module.params["issuer_ordered"], "issuer_ordered"
)
else: else:
self.issuer_ordered = False self.issuer_ordered = False
self.issuer = parse_name_field(module.params["issuer"], "issuer") self.issuer = (
parse_name_field(issuer, "issuer") if issuer is not None else []
)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
module.fail_json(msg=str(exc)) module.fail_json(msg=str(exc))
self.last_update = get_relative_time_option( self.last_update: datetime.datetime = get_relative_time_option(
module.params["last_update"], module.params["last_update"],
"last_update", "last_update",
with_timezone=CRYPTOGRAPHY_TIMEZONE, with_timezone=CRYPTOGRAPHY_TIMEZONE,
) )
self.next_update = get_relative_time_option( self.next_update: datetime.datetime | None = get_relative_time_option(
module.params["next_update"], module.params["next_update"],
"next_update", "next_update",
with_timezone=CRYPTOGRAPHY_TIMEZONE, with_timezone=CRYPTOGRAPHY_TIMEZONE,
) )
self.digest = select_message_digest(module.params["digest"]) digest = select_message_digest(module.params["digest"])
if self.digest is None: if digest is None:
raise CRLError(f'The digest "{module.params["digest"]}" is not supported') raise CRLError(f'The digest "{module.params["digest"]}" is not supported')
self.digest = digest
self.module = module self.module = module
self.revoked_certificates = [] self.revoked_certificates = []
for i, rc in enumerate(module.params["revoked_certificates"]): revoked_certificates: list[dict[str, t.Any]] = module.params[
result = { "revoked_certificates"
]
for i, rc in enumerate(revoked_certificates):
result: dict[str, t.Any] = {
"serial_number": None, "serial_number": None,
"revocation_date": None, "revocation_date": None,
"issuer": None, "issuer": None,
@@ -567,21 +587,25 @@ class CRL(OpenSSLObject):
"invalidity_date_critical": False, "invalidity_date_critical": False,
} }
path_prefix = f"revoked_certificates[{i}]." path_prefix = f"revoked_certificates[{i}]."
if rc["path"] is not None or rc["content"] is not None: path: str | None = rc["path"]
content_str: str | None = rc["content"]
if path is not None or content_str is not None:
# Load certificate from file or content # Load certificate from file or content
try: try:
if rc["content"] is not None: content: bytes | None = None
rc["content"] = rc["content"].encode("utf-8") if content_str is not None:
cert = load_certificate(rc["path"], content=rc["content"]) content = content_str.encode("utf-8")
rc["content"] = content
cert = load_certificate(path, content=content)
result["serial_number"] = cert.serial_number result["serial_number"] = cert.serial_number
except OpenSSLObjectError as e: except OpenSSLObjectError as e:
if rc["content"] is not None: if content_str is not None:
module.fail_json( module.fail_json(
msg=f"Cannot parse certificate from {path_prefix}content: {e}" msg=f"Cannot parse certificate from {path_prefix}content: {e}"
) )
else: else:
module.fail_json( module.fail_json(
msg=f'Cannot read certificate "{rc["path"]}" from {path_prefix}path: {e}' msg=f'Cannot read certificate "{path}" from {path_prefix}path: {e}'
) )
else: else:
# Specify serial_number (and potentially issuer) directly # Specify serial_number (and potentially issuer) directly
@@ -611,11 +635,11 @@ class CRL(OpenSSLObject):
result["invalidity_date_critical"] = rc["invalidity_date_critical"] result["invalidity_date_critical"] = rc["invalidity_date_critical"]
self.revoked_certificates.append(result) self.revoked_certificates.append(result)
self.backup = module.params["backup"] self.backup: bool = module.params["backup"]
self.backup_file = None self.backup_file: str | None = None
try: try:
self.privatekey = load_privatekey( self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path, path=self.privatekey_path,
content=self.privatekey_content, content=self.privatekey_content,
passphrase=self.privatekey_passphrase, passphrase=self.privatekey_passphrase,
@@ -643,7 +667,7 @@ class CRL(OpenSSLObject):
self.diff_after = self.diff_before = self._get_info(data) self.diff_after = self.diff_before = self._get_info(data)
def _parse_serial_number(self, value, index): def _parse_serial_number(self, value: t.Any, index: int) -> int:
if self.serial_numbers_format == "integer": if self.serial_numbers_format == "integer":
try: try:
return check_type_int(value) return check_type_int(value)
@@ -662,22 +686,42 @@ class CRL(OpenSSLObject):
f"Unexpected value {self.serial_numbers_format} of serial_numbers" f"Unexpected value {self.serial_numbers_format} of serial_numbers"
) )
def _get_info(self, data): def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None: if data is None:
return dict() return {}
try: try:
result = get_crl_info(self.module, data) result = get_crl_info(self.module, data)
result["can_parse_crl"] = True result["can_parse_crl"] = True
return result return result
except Exception: except Exception:
return dict(can_parse_crl=False) return {"can_parse_crl": False}
def remove(self): def remove(self, module: AnsibleModule) -> None:
if self.backup: if self.backup:
self.backup_file = self.module.backup_local(self.path) self.backup_file = self.module.backup_local(self.path)
super(CRL, self).remove(self.module) super(CRL, self).remove(self.module)
def _compress_entry(self, entry): def _compress_entry(self, entry: dict[str, t.Any]) -> (
tuple[
int | None,
tuple[str, ...] | None,
bool,
int | None,
bool,
datetime.datetime | None,
bool,
]
| tuple[
int | None,
datetime.datetime | None,
tuple[str, ...] | None,
bool,
int | None,
bool,
datetime.datetime | None,
bool,
]
):
issuer = None issuer = None
if entry["issuer"] is not None: if entry["issuer"] is not None:
# Normalize to IDNA. If this is used-provided, it was already converted to # Normalize to IDNA. If this is used-provided, it was already converted to
@@ -713,7 +757,12 @@ class CRL(OpenSSLObject):
entry["invalidity_date_critical"], entry["invalidity_date_critical"],
) )
def check(self, module, perms_required=True, ignore_conversion=True): def check(
self,
module: AnsibleModule,
perms_required: bool = True,
ignore_conversion: bool = True,
) -> bool:
"""Ensure the resource is in its desired state.""" """Ensure the resource is in its desired state."""
state_and_perms = super(CRL, self).check(self.module, perms_required) state_and_perms = super(CRL, self).check(self.module, perms_required)
@@ -743,10 +792,13 @@ class CRL(OpenSSLObject):
] ]
is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer] is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer]
if not self.issuer_ordered: if not self.issuer_ordered:
want_issuer = set(want_issuer) want_issuer_set = set(want_issuer)
is_issuer = set(is_issuer) is_issuer_set = set(is_issuer)
if want_issuer != is_issuer: if want_issuer_set != is_issuer_set:
return False return False
else:
if want_issuer != is_issuer:
return False
old_entries = [ old_entries = [
self._compress_entry(cryptography_decode_revoked_certificate(cert)) self._compress_entry(cryptography_decode_revoked_certificate(cert))
@@ -769,7 +821,7 @@ class CRL(OpenSSLObject):
return True return True
def _generate_crl(self): def _generate_crl(self) -> bytes:
crl = CertificateRevocationListBuilder() crl = CertificateRevocationListBuilder()
try: try:
@@ -787,7 +839,8 @@ class CRL(OpenSSLObject):
raise CRLError(e) raise CRLError(e)
crl = set_last_update(crl, self.last_update) crl = set_last_update(crl, self.last_update)
crl = set_next_update(crl, self.next_update) if self.next_update is not None:
crl = set_next_update(crl, self.next_update)
if self.update and self.crl: if self.update and self.crl:
new_entries = set( new_entries = set(
@@ -799,22 +852,26 @@ class CRL(OpenSSLObject):
) )
if decoded_entry not in new_entries: if decoded_entry not in new_entries:
crl = crl.add_revoked_certificate(entry) crl = crl.add_revoked_certificate(entry)
for entry in self.revoked_certificates: for revoked_entry in self.revoked_certificates:
revoked_cert = RevokedCertificateBuilder() revoked_cert = RevokedCertificateBuilder()
revoked_cert = revoked_cert.serial_number(entry["serial_number"]) revoked_cert = revoked_cert.serial_number(revoked_entry["serial_number"])
revoked_cert = set_revocation_date(revoked_cert, entry["revocation_date"]) revoked_cert = set_revocation_date(
if entry["issuer"] is not None: revoked_cert, revoked_entry["revocation_date"]
)
if revoked_entry["issuer"] is not None:
revoked_cert = revoked_cert.add_extension( revoked_cert = revoked_cert.add_extension(
x509.CertificateIssuer(entry["issuer"]), entry["issuer_critical"] x509.CertificateIssuer(revoked_entry["issuer"]),
revoked_entry["issuer_critical"],
) )
if entry["reason"] is not None: if revoked_entry["reason"] is not None:
revoked_cert = revoked_cert.add_extension( revoked_cert = revoked_cert.add_extension(
x509.CRLReason(entry["reason"]), entry["reason_critical"] x509.CRLReason(revoked_entry["reason"]),
revoked_entry["reason_critical"],
) )
if entry["invalidity_date"] is not None: if revoked_entry["invalidity_date"] is not None:
revoked_cert = revoked_cert.add_extension( revoked_cert = revoked_cert.add_extension(
x509.InvalidityDate(entry["invalidity_date"]), x509.InvalidityDate(revoked_entry["invalidity_date"]),
entry["invalidity_date_critical"], revoked_entry["invalidity_date_critical"],
) )
crl = crl.add_revoked_certificate(revoked_cert.build()) crl = crl.add_revoked_certificate(revoked_cert.build())
@@ -827,7 +884,7 @@ class CRL(OpenSSLObject):
else: else:
return self.crl.public_bytes(Encoding.DER) return self.crl.public_bytes(Encoding.DER)
def generate(self): def generate(self, module: AnsibleModule) -> None:
result = None result = None
if ( if (
not self.check(self.module, perms_required=False, ignore_conversion=True) not self.check(self.module, perms_required=False, ignore_conversion=True)
@@ -861,7 +918,7 @@ class CRL(OpenSSLObject):
elif self.module.set_fs_attributes_if_different(file_args, False): elif self.module.set_fs_attributes_if_different(file_args, False):
self.changed = True self.changed = True
def dump(self, check_mode=False): def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
result = { result = {
"changed": self.changed, "changed": self.changed,
"filename": self.path, "filename": self.path,
@@ -879,37 +936,53 @@ class CRL(OpenSSLObject):
if check_mode: if check_mode:
result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT) result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = self.next_update.strftime(TIMESTAMP_FORMAT) result["next_update"] = (
self.next_update.strftime(TIMESTAMP_FORMAT)
if self.next_update is not None
else None
)
# result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid) # result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid)
result["digest"] = self.module.params["digest"] result["digest"] = self.module.params["digest"]
result["issuer_ordered"] = self.issuer result["issuer_ordered"] = self.issuer
result["issuer"] = {} issuer: dict[str, str | bytes] = {}
result["issuer"] = issuer
for k, v in self.issuer: for k, v in self.issuer:
result["issuer"][k] = v issuer[k] = v
result["revoked_certificates"] = [] revoked_certificates: list[dict[str, t.Any]] = []
result["revoked_certificates"] = revoked_certificates
for entry in self.revoked_certificates: for entry in self.revoked_certificates:
result["revoked_certificates"].append( revoked_certificates.append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding) cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
) )
elif self.crl: elif self.crl:
result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT) result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT)
result["next_update"] = get_next_update(self.crl).strftime(TIMESTAMP_FORMAT) next_update = get_next_update(self.crl)
result["next_update"] = (
next_update.strftime(TIMESTAMP_FORMAT)
if next_update is not None
else None
)
result["digest"] = cryptography_oid_to_name( result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl) cryptography_get_signature_algorithm_oid_from_crl(self.crl)
) )
issuer = [] issuer_list: list[list[str]] = []
for attribute in self.crl.issuer: for attribute in self.crl.issuer:
issuer.append( issuer_list.append(
[cryptography_oid_to_name(attribute.oid), attribute.value] [
cryptography_oid_to_name(attribute.oid),
to_text(attribute.value),
]
) )
result["issuer_ordered"] = issuer result["issuer_ordered"] = issuer_list
result["issuer"] = {} issuer = {}
for k, v in issuer: result["issuer"] = issuer
result["issuer"][k] = v for k, v in issuer_list:
result["revoked_certificates"] = [] issuer[k] = v
revoked_certificates = []
result["revoked_certificates"] = revoked_certificates
for cert in self.crl: for cert in self.crl:
entry = cryptography_decode_revoked_certificate(cert) entry = cryptography_decode_revoked_certificate(cert)
result["revoked_certificates"].append( revoked_certificates.append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding) cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
) )
@@ -923,7 +996,7 @@ class CRL(OpenSSLObject):
return result return result
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
state=dict(type="str", default="present", choices=["present", "absent"]), state=dict(type="str", default="present", choices=["present", "absent"]),
@@ -1015,14 +1088,14 @@ def main():
) )
module.exit_json(**result) module.exit_json(**result)
crl.generate() crl.generate(module)
else: else:
if module.check_mode: if module.check_mode:
result = crl.dump(check_mode=True) result = crl.dump(check_mode=True)
result["changed"] = os.path.exists(module.params["path"]) result["changed"] = os.path.exists(module.params["path"])
module.exit_json(**result) module.exit_json(**result)
crl.remove() crl.remove(module)
result = crl.dump() result = crl.dump()
module.exit_json(**result) module.exit_json(**result)

View File

@@ -100,7 +100,9 @@ last_update:
type: str type: str
sample: '20190413202428Z' sample: '20190413202428Z'
next_update: next_update:
description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. description:
- The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
- Will be C(none) if no such timestamp is present.
returned: success returned: success
type: str type: str
sample: '20190413202428Z' sample: '20190413202428Z'
@@ -172,6 +174,7 @@ revoked_certificates:
import base64 import base64
import binascii import binascii
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
@@ -185,7 +188,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
) )
def main(): def main() -> t.NoReturn:
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
path=dict(type="path"), path=dict(type="path"),
@@ -200,25 +203,30 @@ def main():
supports_check_mode=True, supports_check_mode=True,
) )
if module.params["content"] is None: content: str | None = module.params["content"]
path: str | None = module.params["path"]
if content is None:
if path is None:
module.fail_json(msg="One of content and path must be provided")
try: try:
with open(module.params["path"], "rb") as f: with open(path, "rb") as f:
data = f.read() data = f.read()
except (IOError, OSError) as e: except (IOError, OSError) as e:
module.fail_json(msg=f"Error while reading CRL file from disk: {e}") module.fail_json(msg=f"Error while reading CRL file from disk: {e}")
else: else:
data = module.params["content"].encode("utf-8") data = content.encode("utf-8")
if not identify_pem_format(data): if not identify_pem_format(data):
try: try:
data = base64.b64decode(module.params["content"]) data = base64.b64decode(content)
except (binascii.Error, TypeError) as e: except (binascii.Error, TypeError) as e:
module.fail_json(msg=f"Error while Base64 decoding content: {e}") module.fail_json(msg=f"Error while Base64 decoding content: {e}")
list_revoked_certificates: bool = module.params["list_revoked_certificates"]
try: try:
result = get_crl_info( result = get_crl_info(
module, module,
data, data,
list_revoked_certificates=module.params["list_revoked_certificates"], list_revoked_certificates=list_revoked_certificates,
) )
module.exit_json(**result) module.exit_json(**result)
except OpenSSLObjectError as e: except OpenSSLObjectError as e:

View File

@@ -15,20 +15,24 @@ from __future__ import annotations
import abc import abc
import copy import copy
import traceback import traceback
import typing as t
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.module_utils.basic import SEQUENCETYPE, remove_values from ansible.module_utils.basic import SEQUENCETYPE, remove_values
from ansible.module_utils.common._collections_compat import Mapping from ansible.module_utils.common._collections_compat import Mapping
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.common.validation import (
safe_eval,
)
from ansible.module_utils.errors import UnsupportedError from ansible.module_utils.errors import UnsupportedError
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
ArgumentSpec,
)
class _ModuleExitException(Exception): class _ModuleExitException(Exception):
def __init__(self, result): def __init__(self, result: dict[str, t.Any]) -> None:
super(_ModuleExitException, self).__init__() super(_ModuleExitException, self).__init__()
self.result = result self.result = result
@@ -36,20 +40,21 @@ class _ModuleExitException(Exception):
class AnsibleActionModule: class AnsibleActionModule:
def __init__( def __init__(
self, self,
action_plugin, action_plugin: ActionModuleBase,
argument_spec, argument_spec: dict[str, t.Any],
bypass_checks=False, *,
mutually_exclusive=None, bypass_checks: bool = False,
required_together=None, supports_check_mode: bool = False,
required_one_of=None, mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
supports_check_mode=False, required_together: list[list[str] | tuple[str, ...]] | None = None,
required_if=None, required_one_of: list[list[str] | tuple[str, ...]] | None = None,
required_by=None, required_if: list[tuple[str, t.Any, list[str] | tuple[str, ...]]] | None = None,
): required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
) -> None:
# Internal data # Internal data
self.__action_plugin = action_plugin self.__action_plugin = action_plugin
self.__warnings = [] self.__warnings: list[str] = []
self.__deprecations = [] self.__deprecations: list[dict[str, str | None]] = []
# AnsibleModule data # AnsibleModule data
self._name = self.__action_plugin._task.action self._name = self.__action_plugin._task.action
@@ -67,10 +72,6 @@ class AnsibleActionModule:
self._diff = self.__action_plugin._play_context.diff self._diff = self.__action_plugin._play_context.diff
self._verbosity = self.__action_plugin._display.verbosity self._verbosity = self.__action_plugin._display.verbosity
self.aliases = {}
self._legal_inputs = []
self._options_context = list()
self.params = copy.deepcopy(self.__action_plugin._task.args) self.params = copy.deepcopy(self.__action_plugin._task.args)
self.no_log_values = set() self.no_log_values = set()
self._validator = ArgumentSpecValidator( self._validator = ArgumentSpecValidator(
@@ -122,38 +123,41 @@ class AnsibleActionModule:
self.fail_json(msg=msg) self.fail_json(msg=msg)
def safe_eval(self, value, locals=None, include_exceptions=False): def warn(self, warning: str) -> None:
return safe_eval(value, locals, include_exceptions)
def warn(self, warning):
# Copied from ansible.module_utils.common.warnings: # Copied from ansible.module_utils.common.warnings:
if isinstance(warning, (str, bytes)): if isinstance(warning, str):
self.__warnings.append(warning) self.__warnings.append(warning)
else: else:
raise TypeError(f"warn requires a string not a {type(warning)}") raise TypeError(f"warn requires a string not a {type(warning)}")
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
if version is not None and date is not None: if version is not None and date is not None:
raise AssertionError( raise AssertionError(
"implementation error -- version and date must not both be set" "implementation error -- version and date must not both be set"
) )
# Copied from ansible.module_utils.common.warnings: # Copied from ansible.module_utils.common.warnings:
if isinstance(msg, (str, bytes)): if not isinstance(msg, str):
# For compatibility, we accept that neither version nor date is set,
# and treat that the same as if version would haven been set
if date is not None:
self.__deprecations.append(
{"msg": msg, "date": date, "collection_name": collection_name}
)
else:
self.__deprecations.append(
{"msg": msg, "version": version, "collection_name": collection_name}
)
else:
raise TypeError(f"deprecate requires a string not a {type(msg)}") raise TypeError(f"deprecate requires a string not a {type(msg)}")
def _return_formatted(self, kwargs): # For compatibility, we accept that neither version nor date is set,
# and treat that the same as if version would haven been set
if date is not None:
self.__deprecations.append(
{"msg": msg, "date": date, "collection_name": collection_name}
)
else:
self.__deprecations.append(
{"msg": msg, "version": version, "collection_name": collection_name}
)
def _return_formatted(self, kwargs: dict[str, t.Any]) -> t.NoReturn:
if "invocation" not in kwargs: if "invocation" not in kwargs:
kwargs["invocation"] = {"module_args": self.params} kwargs["invocation"] = {"module_args": self.params}
@@ -194,13 +198,13 @@ class AnsibleActionModule:
kwargs = remove_values(kwargs, self.no_log_values) kwargs = remove_values(kwargs, self.no_log_values)
raise _ModuleExitException(kwargs) raise _ModuleExitException(kwargs)
def exit_json(self, **kwargs): def exit_json(self, **kwargs) -> t.NoReturn:
result = dict(kwargs) result = dict(kwargs)
if "failed" not in result: if "failed" not in result:
result["failed"] = False result["failed"] = False
self._return_formatted(result) self._return_formatted(result)
def fail_json(self, msg, **kwargs): def fail_json(self, msg: str, **kwargs) -> t.NoReturn:
result = dict(kwargs) result = dict(kwargs)
result["failed"] = True result["failed"] = True
result["msg"] = msg result["msg"] = msg
@@ -209,16 +213,15 @@ class AnsibleActionModule:
class ActionModuleBase(ActionBase, metaclass=abc.ABCMeta): class ActionModuleBase(ActionBase, metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def setup_module(self): def setup_module(self) -> tuple[ArgumentSpec, dict[str, t.Any]]:
"""Return pair (ArgumentSpec, kwargs).""" """Return pair (ArgumentSpec, kwargs)."""
pass
@abc.abstractmethod @abc.abstractmethod
def run_module(self, module): def run_module(self, module: AnsibleActionModule) -> None:
"""Run module code""" """Run module code"""
module.fail_json(msg="Not implemented.") module.fail_json(msg="Not implemented.")
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None) -> dict[str, t.Any]:
if task_vars is None: if task_vars is None:
task_vars = dict() task_vars = dict()

Some files were not shown because too many files have changed in this diff Show More