mirror of
https://github.com/ansible-collections/community.crypto.git
synced 2026-03-26 21:33:25 +00:00
Add type hints and type checking (#885)
* Enable basic type checking. * Fix first errors. * Add changelog fragment. * Add types to module_utils and plugin_utils (without module backends). * Add typing hints for acme_* modules. * Add typing to X.509 certificate modules, and add more helpers. * Add typing to remaining module backends. * Add typing for action, filter, and lookup plugins. * Bump ansible-core 2.19 beta requirement for typing. * Add more typing definitions. * Add typing to some unit tests.
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -19,25 +20,41 @@ from ansible_collections.community.crypto.plugins.plugin_utils.action_module imp
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
ArgumentSpec,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
|
||||
PrivateKeyBackend,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.plugin_utils.action_module import (
|
||||
AnsibleActionModule,
|
||||
)
|
||||
|
||||
|
||||
class PrivateKeyModule:
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleActionModule, module_backend: PrivateKeyBackend
|
||||
) -> None:
|
||||
self.module = module
|
||||
self.module_backend = module_backend
|
||||
self.check_mode = module.check_mode
|
||||
self.changed = False
|
||||
self.return_current_key = module.params["return_current_key"]
|
||||
self.return_current_key: bool = module.params["return_current_key"]
|
||||
|
||||
if module.params["content"] is not None:
|
||||
if module.params["content_base64"]:
|
||||
content: str | None = module.params["content"]
|
||||
content_base64: bool = module.params["content_base64"]
|
||||
if content is not None:
|
||||
if content_base64:
|
||||
try:
|
||||
data = base64.b64decode(module.params["content"])
|
||||
data = base64.b64decode(content)
|
||||
except Exception as e:
|
||||
module.fail_json(msg=f"Cannot decode Base64 encoded data: {e}")
|
||||
else:
|
||||
data = to_bytes(module.params["content"])
|
||||
data = to_bytes(content)
|
||||
module_backend.set_existing(data)
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleActionModule) -> None:
|
||||
"""Generate a keypair."""
|
||||
|
||||
if self.module_backend.needs_regeneration():
|
||||
@@ -53,7 +70,7 @@ class PrivateKeyModule:
|
||||
self.privatekey_bytes = privatekey_data
|
||||
self.changed = True
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = self.module_backend.dump(
|
||||
include_key=self.changed or self.return_current_key
|
||||
@@ -64,7 +81,7 @@ class PrivateKeyModule:
|
||||
|
||||
class ActionModule(ActionModuleBase):
|
||||
@staticmethod
|
||||
def setup_module():
|
||||
def setup_module() -> tuple[ArgumentSpec, dict[str, t.Any]]:
|
||||
argument_spec = get_privatekey_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
@@ -78,7 +95,7 @@ class ActionModule(ActionModuleBase):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def run_module(module):
|
||||
def run_module(module: AnsibleActionModule) -> None:
|
||||
module_backend = select_backend(module=module)
|
||||
|
||||
try:
|
||||
|
||||
@@ -39,6 +39,8 @@ _value:
|
||||
type: string
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import (
|
||||
@@ -50,7 +52,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import (
|
||||
)
|
||||
|
||||
|
||||
def gpg_fingerprint(input):
|
||||
def gpg_fingerprint(input: str | bytes) -> str:
|
||||
if not isinstance(input, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
f"The input for the community.crypto.gpg_fingerprint filter must be a string; got {type(input)} instead"
|
||||
@@ -65,7 +67,7 @@ def gpg_fingerprint(input):
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"gpg_fingerprint": gpg_fingerprint,
|
||||
}
|
||||
|
||||
@@ -274,6 +274,8 @@ _value:
|
||||
sample: 12345
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -287,7 +289,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
|
||||
)
|
||||
|
||||
|
||||
def openssl_csr_info_filter(data, name_encoding="ignore"):
|
||||
def openssl_csr_info_filter(
|
||||
data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore"
|
||||
) -> dict[str, t.Any]:
|
||||
"""Extract information from X.509 PEM certificate."""
|
||||
if not isinstance(data, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
@@ -313,7 +317,7 @@ def openssl_csr_info_filter(data, name_encoding="ignore"):
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"openssl_csr_info": openssl_csr_info_filter,
|
||||
}
|
||||
|
||||
@@ -146,8 +146,10 @@ _value:
|
||||
type: dict
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -161,8 +163,10 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
|
||||
|
||||
|
||||
def openssl_privatekey_info_filter(
|
||||
data, passphrase=None, return_private_key_data=False
|
||||
):
|
||||
data: str | bytes,
|
||||
passphrase: str | bytes | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Extract information from X.509 PEM certificate."""
|
||||
if not isinstance(data, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
@@ -182,7 +186,7 @@ def openssl_privatekey_info_filter(
|
||||
result = get_privatekey_info(
|
||||
module,
|
||||
content=to_bytes(data),
|
||||
passphrase=passphrase,
|
||||
passphrase=to_text(passphrase) if passphrase is not None else None,
|
||||
return_private_key_data=return_private_key_data,
|
||||
)
|
||||
result.pop("can_parse_key", None)
|
||||
@@ -197,7 +201,7 @@ def openssl_privatekey_info_filter(
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"openssl_privatekey_info": openssl_privatekey_info_filter,
|
||||
}
|
||||
|
||||
@@ -123,6 +123,8 @@ _value:
|
||||
returned: When RV(_value.type=DSA) or RV(_value.type=ECC)
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -137,7 +139,7 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
|
||||
)
|
||||
|
||||
|
||||
def openssl_publickey_info_filter(data):
|
||||
def openssl_publickey_info_filter(data: str | bytes) -> dict[str, t.Any]:
|
||||
"""Extract information from OpenSSL PEM public key."""
|
||||
if not isinstance(data, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
@@ -156,7 +158,7 @@ def openssl_publickey_info_filter(data):
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"openssl_publickey_info": openssl_publickey_info_filter,
|
||||
}
|
||||
|
||||
@@ -39,6 +39,8 @@ _value:
|
||||
type: int
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.serial import (
|
||||
@@ -46,7 +48,7 @@ from ansible_collections.community.crypto.plugins.module_utils.serial import (
|
||||
)
|
||||
|
||||
|
||||
def parse_serial_filter(input):
|
||||
def parse_serial_filter(input: str | bytes) -> int:
|
||||
if not isinstance(input, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
f"The input for the community.crypto.parse_serial filter must be a string; got {type(input)} instead"
|
||||
@@ -60,7 +62,7 @@ def parse_serial_filter(input):
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"parse_serial": parse_serial_filter,
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ _value:
|
||||
elements: string
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
|
||||
@@ -45,21 +47,20 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
)
|
||||
|
||||
|
||||
def split_pem_filter(data):
|
||||
def split_pem_filter(data: str | bytes) -> list[str]:
|
||||
"""Split PEM file."""
|
||||
if not isinstance(data, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
f"The community.crypto.split_pem input must be a text type, not {type(data)}"
|
||||
)
|
||||
|
||||
data = to_text(data)
|
||||
return split_pem_list(data)
|
||||
return split_pem_list(to_text(data))
|
||||
|
||||
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"split_pem": split_pem_filter,
|
||||
}
|
||||
|
||||
@@ -39,11 +39,13 @@ _value:
|
||||
type: string
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible_collections.community.crypto.plugins.module_utils.serial import to_serial
|
||||
|
||||
|
||||
def to_serial_filter(input):
|
||||
def to_serial_filter(input: int) -> str:
|
||||
if not isinstance(input, int):
|
||||
raise AnsibleFilterError(
|
||||
f"The input for the community.crypto.to_serial filter must be an integer; got {type(input)} instead"
|
||||
@@ -61,7 +63,7 @@ def to_serial_filter(input):
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"to_serial": to_serial_filter,
|
||||
}
|
||||
|
||||
@@ -308,6 +308,8 @@ _value:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -321,7 +323,9 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
|
||||
)
|
||||
|
||||
|
||||
def x509_certificate_info_filter(data, name_encoding="ignore"):
|
||||
def x509_certificate_info_filter(
|
||||
data: str | bytes, name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore"
|
||||
) -> dict[str, t.Any]:
|
||||
"""Extract information from X.509 PEM certificate."""
|
||||
if not isinstance(data, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
@@ -347,7 +351,7 @@ def x509_certificate_info_filter(data, name_encoding="ignore"):
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"x509_certificate_info": x509_certificate_info_filter,
|
||||
}
|
||||
|
||||
@@ -155,6 +155,7 @@ _value:
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
@@ -172,7 +173,11 @@ from ansible_collections.community.crypto.plugins.plugin_utils.filter_module imp
|
||||
)
|
||||
|
||||
|
||||
def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates=True):
|
||||
def x509_crl_info_filter(
|
||||
data: str | bytes,
|
||||
name_encoding: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
list_revoked_certificates: bool = True,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Extract information from X.509 PEM certificate."""
|
||||
if not isinstance(data, (str, bytes)):
|
||||
raise AnsibleFilterError(
|
||||
@@ -192,17 +197,19 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates
|
||||
f'The name_encoding option must be one of the values "ignore", "idna", or "unicode", not "{name_encoding}"'
|
||||
)
|
||||
|
||||
data = to_bytes(data)
|
||||
if not identify_pem_format(data):
|
||||
data_bytes = to_bytes(data)
|
||||
if not identify_pem_format(data_bytes):
|
||||
try:
|
||||
data = base64.b64decode(to_native(data))
|
||||
data_bytes = base64.b64decode(to_native(data_bytes))
|
||||
except (binascii.Error, TypeError, ValueError, UnicodeEncodeError):
|
||||
pass
|
||||
|
||||
module = FilterModuleMock({"name_encoding": name_encoding})
|
||||
try:
|
||||
return get_crl_info(
|
||||
module, content=data, list_revoked_certificates=list_revoked_certificates
|
||||
module,
|
||||
content=data_bytes,
|
||||
list_revoked_certificates=list_revoked_certificates,
|
||||
)
|
||||
except OpenSSLObjectError as exc:
|
||||
raise AnsibleFilterError(str(exc))
|
||||
@@ -211,7 +218,7 @@ def x509_crl_info_filter(data, name_encoding="ignore", list_revoked_certificates
|
||||
class FilterModule:
|
||||
"""Ansible jinja2 filters"""
|
||||
|
||||
def filters(self):
|
||||
def filters(self) -> dict[str, t.Callable]:
|
||||
return {
|
||||
"x509_crl_info": x509_crl_info_filter,
|
||||
}
|
||||
|
||||
@@ -42,7 +42,11 @@ _value:
|
||||
elements: string
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleLookupError
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
from ansible.plugins.lookup import LookupBase
|
||||
from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import (
|
||||
GPGError,
|
||||
@@ -54,14 +58,20 @@ from ansible_collections.community.crypto.plugins.plugin_utils.gnupg import (
|
||||
|
||||
|
||||
class LookupModule(LookupBase):
|
||||
def run(self, terms, variables=None, **kwargs):
|
||||
def run(self, terms: list[t.Any], variables=None, **kwargs) -> list[str]:
|
||||
self.set_options(direct=kwargs)
|
||||
if self._loader is None:
|
||||
raise AssertionError("Contract violation: self._loader is None")
|
||||
|
||||
try:
|
||||
gpg = PluginGPGRunner(cwd=self._loader.get_basedir())
|
||||
result = []
|
||||
for path in terms:
|
||||
result.append(get_fingerprint_from_file(gpg, path))
|
||||
for i, path in enumerate(terms):
|
||||
if not isinstance(path, (str, bytes, os.PathLike)):
|
||||
raise AnsibleLookupError(
|
||||
f"Lookup parameter #{i} should be string or a path object, but got {type(path)}"
|
||||
)
|
||||
result.append(get_fingerprint_from_file(gpg, to_native(path)))
|
||||
return result
|
||||
except GPGError as exc:
|
||||
raise AnsibleLookupError(str(exc))
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common._collections_compat import Mapping
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ACMEProtocolException,
|
||||
@@ -12,26 +14,29 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
class ACMEAccount:
|
||||
"""
|
||||
ACME account object. Allows to create new accounts, check for existence of accounts,
|
||||
retrieve account data.
|
||||
"""
|
||||
|
||||
def __init__(self, client):
|
||||
def __init__(self, client: ACMEClient) -> None:
|
||||
# Set to true to enable logging of all signed requests
|
||||
self._debug = False
|
||||
self._debug: bool = False
|
||||
|
||||
self.client = client
|
||||
|
||||
def _new_reg(
|
||||
self,
|
||||
contact=None,
|
||||
agreement=None,
|
||||
terms_agreed=False,
|
||||
allow_creation=True,
|
||||
external_account_binding=None,
|
||||
):
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any] | None]:
|
||||
"""
|
||||
Registers a new ACME account. Returns a pair ``(created, data)``.
|
||||
Here, ``created`` is ``True`` if the account was created and
|
||||
@@ -63,7 +68,7 @@ class ACMEAccount:
|
||||
return created, data
|
||||
# An account does not yet exist. Try to create one next.
|
||||
|
||||
new_reg = {"contact": contact}
|
||||
new_reg: dict[str, t.Any] = {"contact": contact}
|
||||
if not allow_creation:
|
||||
# https://tools.ietf.org/html/rfc8555#section-7.3.1
|
||||
new_reg["onlyReturnExisting"] = True
|
||||
@@ -99,7 +104,7 @@ class ACMEAccount:
|
||||
self.client.module,
|
||||
msg="Invalid account creation reply from ACME server",
|
||||
info=info,
|
||||
content=result,
|
||||
content_json=result,
|
||||
)
|
||||
|
||||
if info["status"] == 201:
|
||||
@@ -152,7 +157,7 @@ class ACMEAccount:
|
||||
content_json=result,
|
||||
)
|
||||
|
||||
def get_account_data(self):
|
||||
def get_account_data(self) -> dict[str, t.Any] | None:
|
||||
"""
|
||||
Retrieve account information. Can only be called when the account
|
||||
URI is already known (such as after calling setup_account).
|
||||
@@ -161,7 +166,7 @@ class ACMEAccount:
|
||||
if self.client.account_uri is None:
|
||||
raise ModuleFailException("Account URI unknown")
|
||||
# try POST-as-GET first (draft-15 or newer)
|
||||
data = None
|
||||
data: dict[str, t.Any] | None = None
|
||||
result, info = self.client.send_signed_request(
|
||||
self.client.account_uri, data, fail_on_error=False
|
||||
)
|
||||
@@ -180,7 +185,7 @@ class ACMEAccount:
|
||||
self.client.module,
|
||||
msg="Invalid account data retrieved from ACME server",
|
||||
info=info,
|
||||
content=result,
|
||||
content_json=result,
|
||||
)
|
||||
if (
|
||||
info["status"] in (400, 403)
|
||||
@@ -203,15 +208,34 @@ class ACMEAccount:
|
||||
)
|
||||
return result
|
||||
|
||||
@t.overload
|
||||
def setup_account(
|
||||
self,
|
||||
contact=None,
|
||||
agreement=None,
|
||||
terms_agreed=False,
|
||||
allow_creation=True,
|
||||
remove_account_uri_if_not_exists=False,
|
||||
external_account_binding=None,
|
||||
):
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: t.Literal[True] = True,
|
||||
remove_account_uri_if_not_exists: bool = False,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any]]: ...
|
||||
|
||||
@t.overload
|
||||
def setup_account(
|
||||
self,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
remove_account_uri_if_not_exists: bool = False,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any] | None]: ...
|
||||
|
||||
def setup_account(
|
||||
self,
|
||||
contact: list[str] | None = None,
|
||||
terms_agreed: bool = False,
|
||||
allow_creation: bool = True,
|
||||
remove_account_uri_if_not_exists: bool = False,
|
||||
external_account_binding: dict[str, t.Any] | None = None,
|
||||
) -> tuple[bool, dict[str, t.Any] | None]:
|
||||
"""
|
||||
Detect or create an account on the ACME server. For ACME v1,
|
||||
as the only way (without knowing an account URI) to test if an
|
||||
@@ -253,7 +277,6 @@ class ACMEAccount:
|
||||
else:
|
||||
created, account_data = self._new_reg(
|
||||
contact,
|
||||
agreement=agreement,
|
||||
terms_agreed=terms_agreed,
|
||||
allow_creation=allow_creation and not self.client.module.check_mode,
|
||||
external_account_binding=external_account_binding,
|
||||
@@ -267,7 +290,9 @@ class ACMEAccount:
|
||||
account_data = {"contact": contact or []}
|
||||
return created, account_data
|
||||
|
||||
def update_account(self, account_data, contact=None):
|
||||
def update_account(
|
||||
self, account_data: dict[str, t.Any], contact: list[str] | None = None
|
||||
) -> tuple[bool, dict[str, t.Any]]:
|
||||
"""
|
||||
Update an account on the ACME server. Check mode is fully respected.
|
||||
|
||||
@@ -280,8 +305,11 @@ class ACMEAccount:
|
||||
|
||||
https://tools.ietf.org/html/rfc8555#section-7.3.2
|
||||
"""
|
||||
if self.client.account_uri is None:
|
||||
raise ModuleFailException("Cannot update account without account URI")
|
||||
|
||||
# Create request
|
||||
update_request = {}
|
||||
update_request: dict[str, t.Any] = {}
|
||||
if contact is not None and account_data.get("contact", []) != contact:
|
||||
update_request["contact"] = list(contact)
|
||||
|
||||
@@ -302,7 +330,7 @@ class ACMEAccount:
|
||||
self.client.module,
|
||||
msg="Invalid account updating reply from ACME server",
|
||||
info=info,
|
||||
content=account_data,
|
||||
content_json=account_data,
|
||||
)
|
||||
|
||||
return True, account_data
|
||||
|
||||
@@ -10,6 +10,7 @@ import datetime
|
||||
import json
|
||||
import locale
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -41,13 +42,24 @@ from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import os
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .account import ACMEAccount
|
||||
from .backends import CertificateInformation, CryptoBackend
|
||||
|
||||
|
||||
# -1 usually means connection problems
|
||||
RETRY_STATUS_CODES = (-1, 408, 429, 503)
|
||||
|
||||
RETRY_COUNT = 10
|
||||
|
||||
|
||||
def _decode_retry(module, response, info, retry_count):
|
||||
def _decode_retry(
|
||||
module: AnsibleModule, response: t.Any, info: dict[str, t.Any], retry_count: int
|
||||
) -> bool:
|
||||
if info["status"] not in RETRY_STATUS_CODES:
|
||||
return False
|
||||
|
||||
@@ -61,7 +73,8 @@ def _decode_retry(module, response, info, retry_count):
|
||||
|
||||
# 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
|
||||
try:
|
||||
retry_after = min(max(1, int(info.get("retry-after"))), 60)
|
||||
# TODO: use utils.parse_retry_after()
|
||||
retry_after = min(max(1, int(info.get("retry-after", "10"))), 60)
|
||||
except (TypeError, ValueError):
|
||||
retry_after = 10
|
||||
module.log(
|
||||
@@ -73,13 +86,13 @@ def _decode_retry(module, response, info, retry_count):
|
||||
|
||||
|
||||
def _assert_fetch_url_success(
|
||||
module,
|
||||
response,
|
||||
info,
|
||||
allow_redirect=False,
|
||||
allow_client_error=True,
|
||||
allow_server_error=True,
|
||||
):
|
||||
module: AnsibleModule,
|
||||
response: t.Any,
|
||||
info: dict[str, t.Any],
|
||||
allow_redirect: bool = False,
|
||||
allow_client_error: bool = True,
|
||||
allow_server_error: bool = True,
|
||||
) -> None:
|
||||
if info["status"] < 0:
|
||||
raise NetworkException(msg=f"Failure downloading {info['url']}, {info['msg']}")
|
||||
|
||||
@@ -91,7 +104,9 @@ def _assert_fetch_url_success(
|
||||
raise ACMEProtocolException(module, info=info, response=response)
|
||||
|
||||
|
||||
def _is_failed(info, expected_status_codes=None):
|
||||
def _is_failed(
|
||||
info: dict[str, t.Any], expected_status_codes: t.Iterable[int] | None = None
|
||||
) -> bool:
|
||||
if info["status"] < 200 or info["status"] >= 400:
|
||||
return True
|
||||
if (
|
||||
@@ -111,12 +126,12 @@ class ACMEDirectory:
|
||||
https://tools.ietf.org/html/rfc8555#section-7.1.1
|
||||
"""
|
||||
|
||||
def __init__(self, module, account):
|
||||
def __init__(self, module: AnsibleModule, client: ACMEClient) -> None:
|
||||
self.module = module
|
||||
self.directory_root = module.params["acme_directory"]
|
||||
self.version = module.params["acme_version"]
|
||||
|
||||
self.directory, dummy = account.get_request(self.directory_root, get_only=True)
|
||||
self.directory, dummy = client.get_request(self.directory_root, get_only=True)
|
||||
|
||||
self.request_timeout = module.params["request_timeout"]
|
||||
|
||||
@@ -131,16 +146,16 @@ class ACMEDirectory:
|
||||
if "meta" not in self.directory:
|
||||
self.directory["meta"] = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
return self.directory[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.directory
|
||||
|
||||
def get(self, key, default_value=None):
|
||||
def get(self, key: str, default_value: t.Any = None) -> t.Any:
|
||||
return self.directory.get(key, default_value)
|
||||
|
||||
def get_nonce(self, resource=None):
|
||||
def get_nonce(self, resource: str | None = None) -> str:
|
||||
url = self.directory["newNonce"]
|
||||
if resource is not None:
|
||||
url = resource
|
||||
@@ -170,7 +185,7 @@ class ACMEDirectory:
|
||||
)
|
||||
retry_count += 1
|
||||
|
||||
def has_renewal_info_endpoint(self):
|
||||
def has_renewal_info_endpoint(self) -> bool:
|
||||
return "renewalInfo" in self.directory
|
||||
|
||||
|
||||
@@ -180,7 +195,7 @@ class ACMEClient:
|
||||
ACME server.
|
||||
"""
|
||||
|
||||
def __init__(self, module, backend):
|
||||
def __init__(self, module: AnsibleModule, backend: CryptoBackend) -> None:
|
||||
# Set to true to enable logging of all signed requests
|
||||
self._debug = False
|
||||
|
||||
@@ -221,16 +236,22 @@ class ACMEClient:
|
||||
|
||||
self.directory = ACMEDirectory(module, self)
|
||||
|
||||
def set_account_uri(self, uri):
|
||||
def set_account_uri(self, uri: str) -> None:
|
||||
"""
|
||||
Set account URI. For ACME v2, it needs to be used to sending signed
|
||||
requests.
|
||||
"""
|
||||
self.account_uri = uri
|
||||
self.account_jws_header.pop("jwk")
|
||||
self.account_jws_header["kid"] = self.account_uri
|
||||
if self.account_jws_header:
|
||||
self.account_jws_header.pop("jwk", None)
|
||||
self.account_jws_header["kid"] = self.account_uri
|
||||
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
In case of an error, raises KeyParsingError.
|
||||
@@ -239,7 +260,13 @@ class ACMEClient:
|
||||
raise AssertionError("One of key_file and key_content must be specified!")
|
||||
return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
|
||||
|
||||
def sign_request(self, protected, payload, key_data, encode_payload=True):
|
||||
def sign_request(
|
||||
self,
|
||||
protected: dict[str, t.Any],
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
key_data: dict[str, t.Any],
|
||||
encode_payload: bool = True,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Signs an ACME request.
|
||||
"""
|
||||
@@ -260,7 +287,7 @@ class ACMEClient:
|
||||
|
||||
return self.backend.sign(payload64, protected64, key_data)
|
||||
|
||||
def _log(self, msg, data=None):
|
||||
def _log(self, msg: str, data: t.Any = None) -> None:
|
||||
"""
|
||||
Write arguments to acme.log when logging is enabled.
|
||||
"""
|
||||
@@ -275,18 +302,49 @@ class ACMEClient:
|
||||
)
|
||||
)
|
||||
|
||||
@t.overload
|
||||
def send_signed_request(
|
||||
self,
|
||||
url,
|
||||
payload,
|
||||
key_data=None,
|
||||
jws_header=None,
|
||||
parse_json_result=True,
|
||||
encode_payload=True,
|
||||
fail_on_error=True,
|
||||
error_msg=None,
|
||||
expected_status_codes=None,
|
||||
):
|
||||
url: str,
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
*,
|
||||
key_data: dict[str, t.Any] | None = None,
|
||||
jws_header: dict[str, t.Any] | None = None,
|
||||
parse_json_result: t.Literal[True] = True,
|
||||
encode_payload: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
|
||||
|
||||
@t.overload
|
||||
def send_signed_request(
|
||||
self,
|
||||
url: str,
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
*,
|
||||
key_data: dict[str, t.Any] | None = None,
|
||||
jws_header: dict[str, t.Any] | None = None,
|
||||
parse_json_result: t.Literal[False],
|
||||
encode_payload: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[bytes, dict[str, t.Any]]: ...
|
||||
|
||||
def send_signed_request(
|
||||
self,
|
||||
url: str,
|
||||
payload: str | dict[str, t.Any] | None,
|
||||
*,
|
||||
key_data: dict[str, t.Any] | None = None,
|
||||
jws_header: dict[str, t.Any] | None = None,
|
||||
parse_json_result: bool = True,
|
||||
encode_payload: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
|
||||
"""
|
||||
Sends a JWS signed HTTP POST request to the ACME server and returns
|
||||
the response as dictionary (if parse_json_result is True) or in raw form
|
||||
@@ -297,7 +355,11 @@ class ACMEClient:
|
||||
(https://tools.ietf.org/html/rfc8555#section-6.3)
|
||||
"""
|
||||
key_data = key_data or self.account_key_data
|
||||
if key_data is None:
|
||||
raise ModuleFailException("Missing key data")
|
||||
jws_header = jws_header or self.account_jws_header
|
||||
if jws_header is None:
|
||||
raise ModuleFailException("Missing JWS header")
|
||||
failed_tries = 0
|
||||
while True:
|
||||
protected = copy.deepcopy(jws_header)
|
||||
@@ -382,16 +444,43 @@ class ACMEClient:
|
||||
)
|
||||
return result, info
|
||||
|
||||
@t.overload
|
||||
def get_request(
|
||||
self,
|
||||
uri,
|
||||
parse_json_result=True,
|
||||
headers=None,
|
||||
get_only=False,
|
||||
fail_on_error=True,
|
||||
error_msg=None,
|
||||
expected_status_codes=None,
|
||||
):
|
||||
uri: str,
|
||||
*,
|
||||
parse_json_result: t.Literal[True] = True,
|
||||
headers: dict[str, str] | None = None,
|
||||
get_only: bool = False,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
|
||||
|
||||
@t.overload
|
||||
def get_request(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
parse_json_result: t.Literal[False],
|
||||
headers: dict[str, str] | None = None,
|
||||
get_only: bool = False,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[bytes, dict[str, t.Any]]: ...
|
||||
|
||||
def get_request(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
parse_json_result: bool = True,
|
||||
headers: dict[str, str] | None = None,
|
||||
get_only: bool = False,
|
||||
fail_on_error: bool = True,
|
||||
error_msg: str | None = None,
|
||||
expected_status_codes: t.Iterable[int] | None = None,
|
||||
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]:
|
||||
"""
|
||||
Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback
|
||||
to GET if server replies with a status code of 405.
|
||||
@@ -436,6 +525,7 @@ class ACMEClient:
|
||||
|
||||
# Process result
|
||||
parsed_json_result = False
|
||||
result: dict[str, t.Any] | bytes
|
||||
if parse_json_result:
|
||||
result = {}
|
||||
if content:
|
||||
@@ -445,7 +535,7 @@ class ACMEClient:
|
||||
parsed_json_result = True
|
||||
except ValueError:
|
||||
raise NetworkException(
|
||||
f"Failed to parse the ACME response: {uri} {content}"
|
||||
f"Failed to parse the ACME response: {uri} {content!r}"
|
||||
)
|
||||
else:
|
||||
result = content
|
||||
@@ -460,19 +550,21 @@ class ACMEClient:
|
||||
msg=error_msg,
|
||||
info=info,
|
||||
content=content,
|
||||
content_json=result if parsed_json_result else None,
|
||||
content_json=(
|
||||
t.cast(dict[str, t.Any], result) if parsed_json_result else None
|
||||
),
|
||||
)
|
||||
return result, info
|
||||
|
||||
def get_renewal_info(
|
||||
self,
|
||||
cert_id=None,
|
||||
cert_info=None,
|
||||
cert_filename=None,
|
||||
cert_content=None,
|
||||
include_retry_after=False,
|
||||
retry_after_relative_with_timezone=True,
|
||||
):
|
||||
cert_id: str | None = None,
|
||||
cert_info: CertificateInformation | None = None,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
include_retry_after: bool = False,
|
||||
retry_after_relative_with_timezone: bool = True,
|
||||
) -> dict[str, t.Any]:
|
||||
if not self.directory.has_renewal_info_endpoint():
|
||||
raise ModuleFailException(
|
||||
"The ACME endpoint does not support ACME Renewal Information retrieval"
|
||||
@@ -504,10 +596,10 @@ class ACMEClient:
|
||||
|
||||
|
||||
def create_default_argspec(
|
||||
with_account=True,
|
||||
require_account_key=True,
|
||||
with_certificate=False,
|
||||
):
|
||||
with_account: bool = True,
|
||||
require_account_key: bool = True,
|
||||
with_certificate: bool = False,
|
||||
) -> ArgumentSpec:
|
||||
"""
|
||||
Provides default argument spec for the options documented in the acme doc fragment.
|
||||
"""
|
||||
@@ -544,7 +636,7 @@ def create_default_argspec(
|
||||
return result
|
||||
|
||||
|
||||
def create_backend(module, needs_acme_v2=True):
|
||||
def create_backend(module: AnsibleModule, needs_acme_v2: bool = True) -> CryptoBackend:
|
||||
backend = module.params["select_crypto_backend"]
|
||||
|
||||
# Backend autodetect
|
||||
@@ -552,6 +644,7 @@ def create_backend(module, needs_acme_v2=True):
|
||||
backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl"
|
||||
|
||||
# Create backend object
|
||||
module_backend: CryptoBackend
|
||||
if backend == "cryptography":
|
||||
if CRYPTOGRAPHY_ERROR is not None:
|
||||
# Either we could not import cryptography at all, or there was an unexpected error
|
||||
|
||||
@@ -9,6 +9,7 @@ import base64
|
||||
import binascii
|
||||
import os
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
@@ -75,10 +76,19 @@ else:
|
||||
CRYPTOGRAPHY_MINIMAL_VERSION
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .certificates import CertificateChain, Criterium
|
||||
|
||||
|
||||
class CryptographyChainMatcher(ChainMatcher):
|
||||
@staticmethod
|
||||
def _parse_key_identifier(key_identifier, name, criterium_idx, module):
|
||||
def _parse_key_identifier(
|
||||
key_identifier: str | None, name: str, criterium_idx: int, module: AnsibleModule
|
||||
) -> bytes | None:
|
||||
if key_identifier:
|
||||
try:
|
||||
return binascii.unhexlify(key_identifier.replace(":", ""))
|
||||
@@ -94,11 +104,11 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
)
|
||||
return None
|
||||
|
||||
def __init__(self, criterium, module):
|
||||
def __init__(self, criterium: Criterium, module: AnsibleModule) -> None:
|
||||
self.criterium = criterium
|
||||
self.test_certificates = criterium.test_certificates
|
||||
self.subject = []
|
||||
self.issuer = []
|
||||
self.subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
|
||||
self.issuer: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]] = []
|
||||
if criterium.subject:
|
||||
self.subject = [
|
||||
(cryptography_name_to_oid(k), to_native(v))
|
||||
@@ -121,8 +131,13 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
criterium.index,
|
||||
module,
|
||||
)
|
||||
self.module = module
|
||||
|
||||
def _match_subject(self, x509_subject, match_subject):
|
||||
def _match_subject(
|
||||
self,
|
||||
x509_subject: cryptography.x509.Name,
|
||||
match_subject: list[tuple[cryptography.x509.oid.ObjectIdentifier, str]],
|
||||
) -> bool:
|
||||
for oid, value in match_subject:
|
||||
found = False
|
||||
for attribute in x509_subject:
|
||||
@@ -133,7 +148,7 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
return False
|
||||
return True
|
||||
|
||||
def match(self, certificate):
|
||||
def match(self, certificate: CertificateChain) -> bool:
|
||||
"""
|
||||
Check whether an alternate chain matches the specified criterium.
|
||||
"""
|
||||
@@ -152,19 +167,22 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
matches = False
|
||||
if self.subject_key_identifier:
|
||||
try:
|
||||
ext = x509.extensions.get_extension_for_class(
|
||||
ext_ski = x509.extensions.get_extension_for_class(
|
||||
cryptography.x509.SubjectKeyIdentifier
|
||||
)
|
||||
if self.subject_key_identifier != ext.value.digest:
|
||||
if self.subject_key_identifier != ext_ski.value.digest:
|
||||
matches = False
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
matches = False
|
||||
if self.authority_key_identifier:
|
||||
try:
|
||||
ext = x509.extensions.get_extension_for_class(
|
||||
ext_aki = x509.extensions.get_extension_for_class(
|
||||
cryptography.x509.AuthorityKeyIdentifier
|
||||
)
|
||||
if self.authority_key_identifier != ext.value.key_identifier:
|
||||
if (
|
||||
self.authority_key_identifier
|
||||
!= ext_aki.value.key_identifier
|
||||
):
|
||||
matches = False
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
matches = False
|
||||
@@ -176,59 +194,68 @@ class CryptographyChainMatcher(ChainMatcher):
|
||||
|
||||
|
||||
class CryptographyBackend(CryptoBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CryptographyBackend, self).__init__(
|
||||
module, with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
)
|
||||
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
Raises KeyParsingError in case of errors.
|
||||
"""
|
||||
# If key_content is not given, read key_file
|
||||
if key_content is None:
|
||||
key_content = read_file(key_file)
|
||||
if key_file is None:
|
||||
raise KeyParsingError(
|
||||
"one of key_file and key_content must be specified"
|
||||
)
|
||||
b_key_content = read_file(key_file)
|
||||
else:
|
||||
key_content = to_bytes(key_content)
|
||||
b_key_content = to_bytes(key_content)
|
||||
# Parse key
|
||||
try:
|
||||
key = cryptography.hazmat.primitives.serialization.load_pem_private_key(
|
||||
key_content,
|
||||
b_key_content,
|
||||
password=to_bytes(passphrase) if passphrase is not None else None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise KeyParsingError(f"error while loading key: {e}")
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
|
||||
pk = key.public_key().public_numbers()
|
||||
rsa_pk = key.public_key().public_numbers()
|
||||
return {
|
||||
"key_obj": key,
|
||||
"type": "rsa",
|
||||
"alg": "RS256",
|
||||
"jwk": {
|
||||
"kty": "RSA",
|
||||
"e": nopad_b64(convert_int_to_bytes(pk.e)),
|
||||
"n": nopad_b64(convert_int_to_bytes(pk.n)),
|
||||
"e": nopad_b64(convert_int_to_bytes(rsa_pk.e)),
|
||||
"n": nopad_b64(convert_int_to_bytes(rsa_pk.n)),
|
||||
},
|
||||
"hash": "sha256",
|
||||
}
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
|
||||
):
|
||||
pk = key.public_key().public_numbers()
|
||||
if pk.curve.name == "secp256r1":
|
||||
ec_pk = key.public_key().public_numbers()
|
||||
if ec_pk.curve.name == "secp256r1":
|
||||
bits = 256
|
||||
alg = "ES256"
|
||||
hashalg = "sha256"
|
||||
point_size = 32
|
||||
curve = "P-256"
|
||||
elif pk.curve.name == "secp384r1":
|
||||
elif ec_pk.curve.name == "secp384r1":
|
||||
bits = 384
|
||||
alg = "ES384"
|
||||
hashalg = "sha384"
|
||||
point_size = 48
|
||||
curve = "P-384"
|
||||
elif pk.curve.name == "secp521r1":
|
||||
elif ec_pk.curve.name == "secp521r1":
|
||||
# Not yet supported on Let's Encrypt side, see
|
||||
# https://github.com/letsencrypt/boulder/issues/2217
|
||||
bits = 521
|
||||
@@ -237,7 +264,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
point_size = 66
|
||||
curve = "P-521"
|
||||
else:
|
||||
raise KeyParsingError(f"unknown elliptic curve: {pk.curve.name}")
|
||||
raise KeyParsingError(f"unknown elliptic curve: {ec_pk.curve.name}")
|
||||
num_bytes = (bits + 7) // 8
|
||||
return {
|
||||
"key_obj": key,
|
||||
@@ -246,8 +273,8 @@ class CryptographyBackend(CryptoBackend):
|
||||
"jwk": {
|
||||
"kty": "EC",
|
||||
"crv": curve,
|
||||
"x": nopad_b64(convert_int_to_bytes(pk.x, count=num_bytes)),
|
||||
"y": nopad_b64(convert_int_to_bytes(pk.y, count=num_bytes)),
|
||||
"x": nopad_b64(convert_int_to_bytes(ec_pk.x, count=num_bytes)),
|
||||
"y": nopad_b64(convert_int_to_bytes(ec_pk.y, count=num_bytes)),
|
||||
},
|
||||
"hash": hashalg,
|
||||
"point_size": point_size,
|
||||
@@ -255,8 +282,11 @@ class CryptographyBackend(CryptoBackend):
|
||||
else:
|
||||
raise KeyParsingError(f'unknown key type "{type(key)}"')
|
||||
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
sign_payload = f"{protected64}.{payload64}".encode("utf8")
|
||||
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
|
||||
if "mac_obj" in key_data:
|
||||
mac = key_data["mac_obj"]()
|
||||
mac.update(sign_payload)
|
||||
@@ -292,8 +322,9 @@ class CryptographyBackend(CryptoBackend):
|
||||
"signature": nopad_b64(signature),
|
||||
}
|
||||
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
hashalg: type[cryptography.hazmat.primitives.hashes.HashAlgorithm]
|
||||
if alg == "HS256":
|
||||
hashalg = cryptography.hazmat.primitives.hashes.SHA256
|
||||
hashbytes = 32
|
||||
@@ -324,7 +355,11 @@ class CryptographyBackend(CryptoBackend):
|
||||
},
|
||||
}
|
||||
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -334,15 +369,19 @@ class CryptographyBackend(CryptoBackend):
|
||||
as the first element in the result.
|
||||
"""
|
||||
if csr_content is None:
|
||||
csr_content = read_file(csr_filename)
|
||||
if csr_filename is None:
|
||||
raise BackendException(
|
||||
"One of csr_content and csr_filename has to be provided"
|
||||
)
|
||||
b_csr_content = read_file(csr_filename)
|
||||
else:
|
||||
csr_content = to_bytes(csr_content)
|
||||
csr = cryptography.x509.load_pem_x509_csr(csr_content)
|
||||
b_csr_content = to_bytes(csr_content)
|
||||
csr = cryptography.x509.load_pem_x509_csr(b_csr_content)
|
||||
|
||||
identifiers = set()
|
||||
result = []
|
||||
|
||||
def add_identifier(identifier):
|
||||
def add_identifier(identifier: tuple[str, str]) -> None:
|
||||
if identifier in identifiers:
|
||||
return
|
||||
identifiers.add(identifier)
|
||||
@@ -350,7 +389,7 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
for sub in csr.subject:
|
||||
if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME:
|
||||
add_identifier(("dns", sub.value))
|
||||
add_identifier(("dns", t.cast(str, sub.value)))
|
||||
for extension in csr.extensions:
|
||||
if (
|
||||
extension.oid
|
||||
@@ -367,7 +406,11 @@ class CryptographyBackend(CryptoBackend):
|
||||
)
|
||||
return result
|
||||
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""
|
||||
Return a set of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -379,7 +422,12 @@ class CryptographyBackend(CryptoBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Return the days the certificate in cert_filename remains valid and -1
|
||||
if the file was not found. If cert_filename contains more than one
|
||||
@@ -398,10 +446,10 @@ class CryptographyBackend(CryptoBackend):
|
||||
return -1
|
||||
|
||||
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
|
||||
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
|
||||
try:
|
||||
cert = cryptography.x509.load_pem_x509_certificate(cert_content)
|
||||
cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
|
||||
except Exception as e:
|
||||
if cert_filename is None:
|
||||
raise BackendException(f"Cannot parse certificate: {e}")
|
||||
@@ -413,13 +461,17 @@ class CryptographyBackend(CryptoBackend):
|
||||
now = add_or_remove_timezone(now, with_timezone=CRYPTOGRAPHY_TIMEZONE)
|
||||
return (get_not_valid_after(cert) - now).days
|
||||
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
return CryptographyChainMatcher(criterium, self.module)
|
||||
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
@@ -429,10 +481,10 @@ class CryptographyBackend(CryptoBackend):
|
||||
cert_content = to_bytes(cert_content)
|
||||
|
||||
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
|
||||
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
b_cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
|
||||
|
||||
try:
|
||||
cert = cryptography.x509.load_pem_x509_certificate(cert_content)
|
||||
cert = cryptography.x509.load_pem_x509_certificate(b_cert_content)
|
||||
except Exception as e:
|
||||
if cert_filename is None:
|
||||
raise BackendException(f"Cannot parse certificate: {e}")
|
||||
@@ -440,19 +492,19 @@ class CryptographyBackend(CryptoBackend):
|
||||
|
||||
ski = None
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(
|
||||
ext_ski = cert.extensions.get_extension_for_class(
|
||||
cryptography.x509.SubjectKeyIdentifier
|
||||
)
|
||||
ski = ext.value.digest
|
||||
ski = ext_ski.value.digest
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
pass
|
||||
|
||||
aki = None
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(
|
||||
ext_aki = cert.extensions.get_extension_for_class(
|
||||
cryptography.x509.AuthorityKeyIdentifier
|
||||
)
|
||||
aki = ext.value.key_identifier
|
||||
aki = ext_aki.value.key_identifier
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
pass
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
@@ -34,12 +35,23 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .certificates import Criterium
|
||||
|
||||
|
||||
_OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C")
|
||||
|
||||
|
||||
def _extract_date(out_text, name, cert_filename_suffix=""):
|
||||
def _extract_date(
|
||||
out_text: str, name: str, cert_filename_suffix: str = ""
|
||||
) -> datetime.datetime:
|
||||
matcher = re.search(rf"\s+{name}\s*:\s+(.*)", out_text)
|
||||
if matcher is None:
|
||||
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
|
||||
date_str = matcher.group(1)
|
||||
try:
|
||||
date_str = re.search(rf"\s+{name}\s*:\s+(.*)", out_text).group(1)
|
||||
# For some reason Python's strptime() does not return any timezone information,
|
||||
# even though the information is there and a supported timezone for all supported
|
||||
# Python implementations (GMT). So we have to modify the datetime object by
|
||||
@@ -47,19 +59,40 @@ def _extract_date(out_text, name, cert_filename_suffix=""):
|
||||
return ensure_utc_timezone(
|
||||
datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
|
||||
)
|
||||
except AttributeError:
|
||||
raise BackendException(f"No '{name}' date found{cert_filename_suffix}")
|
||||
except ValueError as exc:
|
||||
raise BackendException(
|
||||
f"Failed to parse '{name}' date{cert_filename_suffix}: {exc}"
|
||||
)
|
||||
|
||||
|
||||
def _decode_octets(octets_text):
|
||||
def _decode_octets(octets_text: str) -> bytes:
|
||||
return binascii.unhexlify(re.sub(r"(\s|:)", "", octets_text).encode("utf-8"))
|
||||
|
||||
|
||||
def _extract_octets(out_text, name, required=True, potential_prefixes=None):
|
||||
@t.overload
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
name: str,
|
||||
required: t.Literal[False],
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
) -> bytes | None: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
name: str,
|
||||
required: t.Literal[True],
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
) -> bytes: ...
|
||||
|
||||
|
||||
def _extract_octets(
|
||||
out_text: str,
|
||||
name: str,
|
||||
required: bool = True,
|
||||
potential_prefixes: t.Iterable[str] | None = None,
|
||||
) -> bytes | None:
|
||||
part = (
|
||||
f"(?:{'|'.join(re.escape(pp) for pp in potential_prefixes)})"
|
||||
if potential_prefixes
|
||||
@@ -75,13 +108,20 @@ def _extract_octets(out_text, name, required=True, potential_prefixes=None):
|
||||
|
||||
|
||||
class OpenSSLCLIBackend(CryptoBackend):
|
||||
def __init__(self, module, openssl_binary=None):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, openssl_binary: str | None = None
|
||||
) -> None:
|
||||
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
|
||||
if openssl_binary is None:
|
||||
openssl_binary = module.get_bin_path("openssl", True)
|
||||
self.openssl_binary = openssl_binary
|
||||
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
Raises KeyParsingError in case of errors.
|
||||
@@ -90,6 +130,10 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
raise KeyParsingError("openssl backend does not support key passphrases")
|
||||
# If key_file is not given, but key_content, write that to a temporary file
|
||||
if key_file is None:
|
||||
if key_content is None:
|
||||
raise KeyParsingError(
|
||||
"one of key_file and key_content must be specified"
|
||||
)
|
||||
fd, tmpsrc = tempfile.mkstemp()
|
||||
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
|
||||
f = os.fdopen(fd, "wb")
|
||||
@@ -108,8 +152,8 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
f.close()
|
||||
# Parse key
|
||||
account_key_type = None
|
||||
with open(key_file, "rt") as f:
|
||||
for line in f:
|
||||
with open(key_file, "rt") as fi:
|
||||
for line in fi:
|
||||
m = re.match(
|
||||
r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line
|
||||
)
|
||||
@@ -129,38 +173,44 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
self.openssl_binary,
|
||||
account_key_type,
|
||||
"-in",
|
||||
key_file,
|
||||
str(key_file),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
rc, out, err = self.module.run_command(
|
||||
rc, out, stderr = self.module.run_command(
|
||||
openssl_keydump_cmd,
|
||||
check_rc=False,
|
||||
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
|
||||
)
|
||||
if rc != 0:
|
||||
raise BackendException(
|
||||
f"Error while running {' '.join(openssl_keydump_cmd)}: {err}"
|
||||
f"Error while running {' '.join(openssl_keydump_cmd)}: {stderr}"
|
||||
)
|
||||
|
||||
out_text = to_text(out, errors="surrogate_or_strict")
|
||||
|
||||
if account_key_type == "rsa":
|
||||
pub_hex = re.search(
|
||||
matcher = re.search(
|
||||
r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent",
|
||||
out_text,
|
||||
re.MULTILINE | re.DOTALL,
|
||||
).group(1)
|
||||
)
|
||||
if matcher is None:
|
||||
raise KeyParsingError("cannot parse RSA key: modulus not found")
|
||||
pub_hex = matcher.group(1)
|
||||
|
||||
pub_exp = re.search(
|
||||
matcher = re.search(
|
||||
r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL
|
||||
).group(1)
|
||||
)
|
||||
if matcher is None:
|
||||
raise KeyParsingError("cannot parse RSA key: public exponent not found")
|
||||
pub_exp = matcher.group(1)
|
||||
pub_exp = f"{int(pub_exp):x}"
|
||||
if len(pub_exp) % 2:
|
||||
pub_exp = f"0{pub_exp}"
|
||||
|
||||
return {
|
||||
"key_file": key_file,
|
||||
"key_file": str(key_file),
|
||||
"type": "rsa",
|
||||
"alg": "RS256",
|
||||
"jwk": {
|
||||
@@ -223,8 +273,13 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
"hash": hashalg,
|
||||
"point_size": point_size,
|
||||
}
|
||||
raise KeyParsingError(
|
||||
f"Internal error: unexpected account_key_type = {account_key_type!r}"
|
||||
)
|
||||
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
sign_payload = f"{protected64}.{payload64}".encode("utf8")
|
||||
if key_data["type"] == "hmac":
|
||||
hex_key = (
|
||||
@@ -284,7 +339,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
"signature": nopad_b64(to_bytes(out)),
|
||||
}
|
||||
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
if alg == "HS256":
|
||||
hashalg = "sha256"
|
||||
@@ -315,14 +370,18 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_ip(ip):
|
||||
def _normalize_ip(ip: str) -> str:
|
||||
try:
|
||||
return ipaddress.ip_address(to_text(ip)).compressed
|
||||
return ipaddress.ip_address(ip).compressed
|
||||
except ValueError:
|
||||
# We do not want to error out on something IPAddress() cannot parse
|
||||
return ip
|
||||
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -335,13 +394,13 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
data = None
|
||||
if csr_content is not None:
|
||||
filename = "/dev/stdin"
|
||||
data = csr_content.encode("utf-8")
|
||||
data = to_bytes(csr_content)
|
||||
|
||||
openssl_csr_cmd = [
|
||||
self.openssl_binary,
|
||||
"req",
|
||||
"-in",
|
||||
filename,
|
||||
str(filename),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
@@ -360,7 +419,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
identifiers = set()
|
||||
result = []
|
||||
|
||||
def add_identifier(identifier):
|
||||
def add_identifier(identifier: tuple[str, str]) -> None:
|
||||
if identifier in identifiers:
|
||||
return
|
||||
identifiers.add(identifier)
|
||||
@@ -389,7 +448,11 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
raise BackendException(f'Found unsupported SAN identifier "{san}"')
|
||||
return result
|
||||
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""
|
||||
Return a set of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -401,7 +464,12 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Return the days the certificate in cert_filename remains valid and -1
|
||||
if the file was not found. If cert_filename contains more than one
|
||||
@@ -413,7 +481,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
data = None
|
||||
if cert_content is not None:
|
||||
filename = "/dev/stdin"
|
||||
data = cert_content.encode("utf-8")
|
||||
data = to_bytes(cert_content)
|
||||
cert_filename_suffix = ""
|
||||
elif cert_filename is not None:
|
||||
if not os.path.exists(cert_filename):
|
||||
@@ -426,7 +494,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
self.openssl_binary,
|
||||
"x509",
|
||||
"-in",
|
||||
filename,
|
||||
str(filename),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
@@ -452,7 +520,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
now = ensure_utc_timezone(now)
|
||||
return (not_after - now).days
|
||||
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> t.NoReturn:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
@@ -460,7 +528,11 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
'Alternate chain matching can only be used with the "cryptography" backend.'
|
||||
)
|
||||
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
@@ -477,7 +549,7 @@ class OpenSSLCLIBackend(CryptoBackend):
|
||||
self.openssl_binary,
|
||||
"x509",
|
||||
"-in",
|
||||
filename,
|
||||
str(filename),
|
||||
"-noout",
|
||||
"-text",
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import datetime
|
||||
import re
|
||||
from collections import namedtuple
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
BackendException,
|
||||
@@ -27,16 +27,20 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
CertificateInformation = namedtuple(
|
||||
"CertificateInformation",
|
||||
(
|
||||
"not_valid_after",
|
||||
"not_valid_before",
|
||||
"serial_number",
|
||||
"subject_key_identifier",
|
||||
"authority_key_identifier",
|
||||
),
|
||||
)
|
||||
if t.TYPE_CHECKING:
|
||||
import os
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .certificates import ChainMatcher, Criterium
|
||||
|
||||
|
||||
class CertificateInformation(t.NamedTuple):
|
||||
not_valid_after: datetime.datetime
|
||||
not_valid_before: datetime.datetime
|
||||
serial_number: int
|
||||
subject_key_identifier: bytes | None
|
||||
authority_key_identifier: bytes | None
|
||||
|
||||
|
||||
_FRACTIONAL_MATCHER = re.compile(
|
||||
@@ -44,7 +48,7 @@ _FRACTIONAL_MATCHER = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _reduce_fractional_digits(timestamp_str):
|
||||
def _reduce_fractional_digits(timestamp_str: str) -> str:
|
||||
"""
|
||||
Given a RFC 3339 timestamp that includes too many digits for the fractional seconds part, reduces these to at most 6.
|
||||
"""
|
||||
@@ -60,7 +64,7 @@ def _reduce_fractional_digits(timestamp_str):
|
||||
return f"{timestamp}{fractional}{timezone}"
|
||||
|
||||
|
||||
def _parse_acme_timestamp(timestamp_str, with_timezone):
|
||||
def _parse_acme_timestamp(timestamp_str: str, with_timezone: bool) -> datetime.datetime:
|
||||
"""
|
||||
Parses a RFC 3339 timestamp.
|
||||
"""
|
||||
@@ -86,34 +90,42 @@ def _parse_acme_timestamp(timestamp_str, with_timezone):
|
||||
|
||||
|
||||
class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, with_timezone=False):
|
||||
def __init__(self, module: AnsibleModule, with_timezone: bool = False) -> None:
|
||||
self.module = module
|
||||
self._with_timezone = with_timezone
|
||||
|
||||
def get_now(self):
|
||||
def get_now(self) -> datetime.datetime:
|
||||
return get_now_datetime(with_timezone=self._with_timezone)
|
||||
|
||||
def parse_acme_timestamp(self, timestamp_str):
|
||||
def parse_acme_timestamp(self, timestamp_str: str) -> datetime.datetime:
|
||||
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
|
||||
return _parse_acme_timestamp(timestamp_str, with_timezone=self._with_timezone)
|
||||
|
||||
def parse_module_parameter(self, value, name):
|
||||
def parse_module_parameter(self, value: str, name: str) -> datetime.datetime:
|
||||
try:
|
||||
return get_relative_time_option(
|
||||
result = get_relative_time_option(
|
||||
value, name, with_timezone=self._with_timezone
|
||||
)
|
||||
if result is None:
|
||||
raise BackendException(f"Invalid value for {name}: {value!r}")
|
||||
return result
|
||||
except OpenSSLObjectError as exc:
|
||||
raise BackendException(str(exc))
|
||||
|
||||
def interpolate_timestamp(self, timestamp_start, timestamp_end, percentage):
|
||||
def interpolate_timestamp(
|
||||
self,
|
||||
timestamp_start: datetime.datetime,
|
||||
timestamp_end: datetime.datetime,
|
||||
percentage: float,
|
||||
) -> datetime.datetime:
|
||||
start = get_epoch_seconds(timestamp_start)
|
||||
end = get_epoch_seconds(timestamp_end)
|
||||
return from_epoch_seconds(
|
||||
start + percentage * (end - start), with_timezone=self._with_timezone
|
||||
)
|
||||
|
||||
def get_utc_datetime(self, *args, **kwargs):
|
||||
kwargs_ext = dict(kwargs)
|
||||
def get_utc_datetime(self, *args, **kwargs) -> datetime.datetime:
|
||||
kwargs_ext: dict[str, t.Any] = dict(kwargs)
|
||||
if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8):
|
||||
kwargs_ext["tzinfo"] = UTC
|
||||
result = datetime.datetime(*args, **kwargs_ext)
|
||||
@@ -122,22 +134,33 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_key(self, key_file=None, key_content=None, passphrase=None):
|
||||
def parse_key(
|
||||
self,
|
||||
key_file: str | os.PathLike | None = None,
|
||||
key_content: str | None = None,
|
||||
passphrase: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
|
||||
Raises KeyParsingError in case of errors.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sign(self, payload64, protected64, key_data):
|
||||
def sign(
|
||||
self, payload64: str, protected64: str, key_data: dict[str, t.Any]
|
||||
) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_mac_key(self, alg, key):
|
||||
def create_mac_key(self, alg: str, key: str) -> dict[str, t.Any]:
|
||||
"""Create a MAC key."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_ordered_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return a list of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -148,7 +171,11 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
|
||||
def get_csr_identifiers(
|
||||
self,
|
||||
csr_filename: str | os.PathLike | None = None,
|
||||
csr_content: str | bytes | None = None,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""
|
||||
Return a set of requested identifiers (CN and SANs) for the CSR.
|
||||
Each identifier is a pair (type, identifier), where type is either
|
||||
@@ -156,7 +183,12 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
|
||||
def get_cert_days(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Return the days the certificate in cert_filename remains valid and -1
|
||||
if the file was not found. If cert_filename contains more than one
|
||||
@@ -166,13 +198,17 @@ class CryptoBackend(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_chain_matcher(self, criterium):
|
||||
def create_chain_matcher(self, criterium: Criterium) -> ChainMatcher:
|
||||
"""
|
||||
Given a Criterium object, creates a ChainMatcher object.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_cert_information(self, cert_filename=None, cert_content=None):
|
||||
def get_cert_information(
|
||||
self,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
) -> CertificateInformation:
|
||||
"""
|
||||
Return some information on a X.509 certificate as a CertificateInformation object.
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
@@ -30,6 +31,14 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .backends import CryptoBackend
|
||||
from .certificates import ChainMatcher
|
||||
from .challenges import Challenge
|
||||
|
||||
|
||||
class ACMECertificateClient:
|
||||
"""
|
||||
ACME v2 client class. Uses an ACME account object and a CSR to
|
||||
@@ -37,7 +46,13 @@ class ACMECertificateClient:
|
||||
certificates.
|
||||
"""
|
||||
|
||||
def __init__(self, module, backend, client=None, account=None):
|
||||
def __init__(
|
||||
self,
|
||||
module: AnsibleModule,
|
||||
backend: CryptoBackend,
|
||||
client: ACMEClient | None = None,
|
||||
account: ACMEAccount | None = None,
|
||||
) -> None:
|
||||
self.module = module
|
||||
self.version = module.params["acme_version"]
|
||||
self.csr = module.params.get("csr")
|
||||
@@ -66,13 +81,17 @@ class ACMECertificateClient:
|
||||
|
||||
# Extract list of identifiers from CSR
|
||||
if self.csr is not None or self.csr_content is not None:
|
||||
self.identifiers = self.client.backend.get_ordered_csr_identifiers(
|
||||
csr_filename=self.csr, csr_content=self.csr_content
|
||||
self.identifiers: list[tuple[str, str]] | None = (
|
||||
self.client.backend.get_ordered_csr_identifiers(
|
||||
csr_filename=self.csr, csr_content=self.csr_content
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.identifiers = None
|
||||
|
||||
def parse_select_chain(self, select_chain):
|
||||
def parse_select_chain(
|
||||
self, select_chain: list[dict[str, t.Any]] | None
|
||||
) -> list[ChainMatcher]:
|
||||
select_chain_matcher = []
|
||||
if select_chain:
|
||||
for criterium_idx, criterium in enumerate(select_chain):
|
||||
@@ -88,14 +107,16 @@ class ACMECertificateClient:
|
||||
)
|
||||
return select_chain_matcher
|
||||
|
||||
def load_order(self):
|
||||
def load_order(self) -> Order:
|
||||
if not self.order_uri:
|
||||
raise ModuleFailException("The order URI has not been provided")
|
||||
order = Order.from_url(self.client, self.order_uri)
|
||||
order.load_authorizations(self.client)
|
||||
return order
|
||||
|
||||
def create_order(self, replaces_cert_id=None, profile=None):
|
||||
def create_order(
|
||||
self, replaces_cert_id: str | None = None, profile: str | None = None
|
||||
) -> Order:
|
||||
"""
|
||||
Create a new order.
|
||||
"""
|
||||
@@ -114,31 +135,31 @@ class ACMECertificateClient:
|
||||
order.load_authorizations(self.client)
|
||||
return order
|
||||
|
||||
def get_challenges_data(self, order):
|
||||
def get_challenges_data(
|
||||
self, order: Order
|
||||
) -> tuple[list[dict[str, t.Any]], dict[str, list[str]]]:
|
||||
"""
|
||||
Get challenge details.
|
||||
|
||||
Return a tuple of generic challenge details, and specialized DNS challenge details.
|
||||
"""
|
||||
# Get general challenge data
|
||||
data = []
|
||||
data: list[dict[str, t.Any]] = []
|
||||
data_dns: dict[str, list[str]] = {}
|
||||
dns_challenge_type = "dns-01"
|
||||
for authz in order.authorizations.values():
|
||||
# Skip valid authentications: their challenges are already valid
|
||||
# and do not need to be returned
|
||||
if authz.status == "valid":
|
||||
continue
|
||||
challenge_data = authz.get_challenge_data(self.client)
|
||||
data.append(
|
||||
dict(
|
||||
identifier=authz.identifier,
|
||||
identifier_type=authz.identifier_type,
|
||||
challenges=authz.get_challenge_data(self.client),
|
||||
challenges=challenge_data,
|
||||
)
|
||||
)
|
||||
# Get DNS challenge data
|
||||
data_dns = {}
|
||||
dns_challenge_type = "dns-01"
|
||||
for entry in data:
|
||||
dns_challenge = entry["challenges"].get(dns_challenge_type)
|
||||
dns_challenge = challenge_data.get(dns_challenge_type)
|
||||
if dns_challenge:
|
||||
values = data_dns.get(dns_challenge["record"])
|
||||
if values is None:
|
||||
@@ -147,7 +168,7 @@ class ACMECertificateClient:
|
||||
values.append(dns_challenge["resource_value"])
|
||||
return data, data_dns
|
||||
|
||||
def check_that_authorizations_can_be_used(self, order):
|
||||
def check_that_authorizations_can_be_used(self, order: Order) -> None:
|
||||
bad_authzs = []
|
||||
for authz in order.authorizations.values():
|
||||
if authz.status not in ("valid", "pending"):
|
||||
@@ -155,27 +176,32 @@ class ACMECertificateClient:
|
||||
f"{authz.combined_identifier} (status={authz.status!r})"
|
||||
)
|
||||
if bad_authzs:
|
||||
bad_authzs = ", ".join(sorted(bad_authzs))
|
||||
bad_authzs_str = ", ".join(sorted(bad_authzs))
|
||||
raise ModuleFailException(
|
||||
"Some of the authorizations for the order are in a bad state, so the order"
|
||||
f" can no longer be satisfied: {bad_authzs}",
|
||||
f" can no longer be satisfied: {bad_authzs_str}",
|
||||
)
|
||||
|
||||
def collect_invalid_authzs(self, order):
|
||||
def collect_invalid_authzs(self, order: Order) -> list[Authorization]:
|
||||
return [
|
||||
authz
|
||||
for authz in order.authorizations.values()
|
||||
if authz.status == "invalid"
|
||||
]
|
||||
|
||||
def collect_pending_authzs(self, order):
|
||||
def collect_pending_authzs(self, order: Order) -> list[Authorization]:
|
||||
return [
|
||||
authz
|
||||
for authz in order.authorizations.values()
|
||||
if authz.status == "pending"
|
||||
]
|
||||
|
||||
def call_validate(self, pending_authzs, get_challenge, wait=True):
|
||||
def call_validate(
|
||||
self,
|
||||
pending_authzs: list[Authorization],
|
||||
get_challenge: t.Callable[[Authorization], str],
|
||||
wait: bool = True,
|
||||
) -> list[tuple[Authorization, str, Challenge | None]]:
|
||||
authzs_with_challenges_to_wait_for = []
|
||||
for authz in pending_authzs:
|
||||
challenge_type = get_challenge(authz)
|
||||
@@ -185,10 +211,12 @@ class ACMECertificateClient:
|
||||
)
|
||||
return authzs_with_challenges_to_wait_for
|
||||
|
||||
def wait_for_validation(self, authzs_to_wait_for):
|
||||
def wait_for_validation(self, authzs_to_wait_for: list[Authorization]) -> None:
|
||||
wait_for_validation(authzs_to_wait_for, self.client)
|
||||
|
||||
def _download_alternate_chains(self, cert):
|
||||
def _download_alternate_chains(
|
||||
self, cert: CertificateChain
|
||||
) -> list[CertificateChain]:
|
||||
alternate_chains = []
|
||||
for alternate in cert.alternates:
|
||||
try:
|
||||
@@ -206,13 +234,30 @@ class ACMECertificateClient:
|
||||
)
|
||||
return alternate_chains
|
||||
|
||||
def download_certificate(self, order, download_all_chains=True):
|
||||
@t.overload
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[True] = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain]]: ...
|
||||
|
||||
@t.overload
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[False]
|
||||
) -> tuple[CertificateChain, None]: ...
|
||||
|
||||
@t.overload
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
def download_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]:
|
||||
"""
|
||||
Download certificate from a valid oder.
|
||||
"""
|
||||
if order.status != "valid":
|
||||
raise ModuleFailException(
|
||||
f"The order must be valid, but has state {order.state!r}!"
|
||||
f"The order must be valid, but has state {order.status!r}!"
|
||||
)
|
||||
|
||||
if not order.certificate_uri:
|
||||
@@ -232,7 +277,24 @@ class ACMECertificateClient:
|
||||
|
||||
return cert, alternate_chains
|
||||
|
||||
def get_certificate(self, order, download_all_chains=True):
|
||||
@t.overload
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[True] = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
@t.overload
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: t.Literal[False]
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
@t.overload
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]: ...
|
||||
|
||||
def get_certificate(
|
||||
self, order: Order, *, download_all_chains: bool = True
|
||||
) -> tuple[CertificateChain, list[CertificateChain] | None]:
|
||||
"""
|
||||
Request a new certificate and downloads it, and optionally all certificate chains.
|
||||
First verifies whether all authorizations are valid; if not, aborts with an error.
|
||||
@@ -250,7 +312,11 @@ class ACMECertificateClient:
|
||||
|
||||
return self.download_certificate(order, download_all_chains=download_all_chains)
|
||||
|
||||
def find_matching_chain(self, chains, select_chain_matcher):
|
||||
def find_matching_chain(
|
||||
self,
|
||||
chains: list[CertificateChain],
|
||||
select_chain_matcher: t.Iterable[ChainMatcher],
|
||||
) -> CertificateChain | None:
|
||||
for criterium_idx, matcher in enumerate(select_chain_matcher):
|
||||
for chain in chains:
|
||||
if matcher.match(chain):
|
||||
@@ -261,9 +327,15 @@ class ACMECertificateClient:
|
||||
return None
|
||||
|
||||
def write_cert_chain(
|
||||
self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None
|
||||
):
|
||||
self,
|
||||
cert: CertificateChain,
|
||||
cert_dest: str | os.PathLike | None = None,
|
||||
fullchain_dest: str | os.PathLike | None = None,
|
||||
chain_dest: str | os.PathLike | None = None,
|
||||
) -> bool:
|
||||
changed = False
|
||||
if cert.cert is None:
|
||||
raise ValueError("Certificate is not present")
|
||||
|
||||
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
|
||||
changed = True
|
||||
@@ -282,7 +354,7 @@ class ACMECertificateClient:
|
||||
|
||||
return changed
|
||||
|
||||
def deactivate_authzs(self, order):
|
||||
def deactivate_authzs(self, order: Order) -> None:
|
||||
"""
|
||||
Deactivates all valid authz's. Does not raise exceptions.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ModuleFailException,
|
||||
@@ -19,20 +20,29 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
_CertificateChain = t.TypeVar("_CertificateChain", bound="CertificateChain")
|
||||
|
||||
|
||||
class CertificateChain:
|
||||
"""
|
||||
Download and parse the certificate chain.
|
||||
https://tools.ietf.org/html/rfc8555#section-7.4.2
|
||||
"""
|
||||
|
||||
def __init__(self, url):
|
||||
def __init__(self, url: str):
|
||||
self.url = url
|
||||
self.cert = None
|
||||
self.chain = []
|
||||
self.alternates = []
|
||||
self.cert: str | None = None
|
||||
self.chain: list[str] = []
|
||||
self.alternates: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def download(cls, client, url):
|
||||
def download(
|
||||
cls: t.Type[_CertificateChain], client: ACMEClient, url: str
|
||||
) -> _CertificateChain:
|
||||
content, info = client.get_request(
|
||||
url,
|
||||
parse_json_result=False,
|
||||
@@ -43,7 +53,7 @@ class CertificateChain:
|
||||
"application/pem-certificate-chain"
|
||||
):
|
||||
raise ModuleFailException(
|
||||
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content} (headers: {info})"
|
||||
f"Cannot download certificate chain from {url}, as content type is not application/pem-certificate-chain: {content!r} (headers: {info})"
|
||||
)
|
||||
|
||||
result = cls(url)
|
||||
@@ -60,12 +70,12 @@ class CertificateChain:
|
||||
|
||||
if result.cert is None:
|
||||
raise ModuleFailException(
|
||||
f"Failed to parse certificate chain download from {url}: {content} (headers: {info})"
|
||||
f"Failed to parse certificate chain download from {url}: {content!r} (headers: {info})"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _process_links(self, client, link, relation):
|
||||
def _process_links(self, client: ACMEClient, link: str, relation: str) -> None:
|
||||
if relation == "up":
|
||||
# Process link-up headers if there was no chain in reply
|
||||
if not self.chain:
|
||||
@@ -77,7 +87,9 @@ class CertificateChain:
|
||||
elif relation == "alternate":
|
||||
self.alternates.append(link)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> dict[str, bytes]:
|
||||
if self.cert is None:
|
||||
raise ValueError("Has no certificate")
|
||||
cert = self.cert.encode("utf8")
|
||||
chain = ("\n".join(self.chain)).encode("utf8")
|
||||
return {
|
||||
@@ -88,18 +100,22 @@ class CertificateChain:
|
||||
|
||||
|
||||
class Criterium:
|
||||
def __init__(self, criterium, index=None):
|
||||
def __init__(self, criterium: dict[str, t.Any], index: int):
|
||||
self.index = index
|
||||
self.test_certificates = criterium["test_certificates"]
|
||||
self.subject = criterium["subject"]
|
||||
self.issuer = criterium["issuer"]
|
||||
self.subject_key_identifier = criterium["subject_key_identifier"]
|
||||
self.authority_key_identifier = criterium["authority_key_identifier"]
|
||||
self.test_certificates: t.Literal["first", "last", "all"] = criterium[
|
||||
"test_certificates"
|
||||
]
|
||||
self.subject: dict[str, t.Any] | None = criterium["subject"]
|
||||
self.issuer: dict[str, t.Any] | None = criterium["issuer"]
|
||||
self.subject_key_identifier: str | None = criterium["subject_key_identifier"]
|
||||
self.authority_key_identifier: str | None = criterium[
|
||||
"authority_key_identifier"
|
||||
]
|
||||
|
||||
|
||||
class ChainMatcher(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def match(self, certificate):
|
||||
def match(self, certificate: CertificateChain) -> bool:
|
||||
"""
|
||||
Check whether a certificate chain (CertificateChain instance) matches.
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ import ipaddress
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
@@ -23,7 +24,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def create_key_authorization(client, token):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
def create_key_authorization(client: ACMEClient, token: str) -> str:
|
||||
"""
|
||||
Returns the key authorization for the given token
|
||||
https://tools.ietf.org/html/rfc8555#section-8.1
|
||||
@@ -35,41 +42,49 @@ def create_key_authorization(client, token):
|
||||
return f"{token}.{thumbprint}"
|
||||
|
||||
|
||||
def combine_identifier(identifier_type, identifier):
|
||||
def combine_identifier(identifier_type: str, identifier: str) -> str:
|
||||
return f"{identifier_type}:{identifier}"
|
||||
|
||||
|
||||
def normalize_combined_identifier(identifier):
|
||||
def normalize_combined_identifier(identifier: str) -> str:
|
||||
identifier_type, identifier = split_identifier(identifier)
|
||||
# Normalize DNS names and IPs
|
||||
identifier = identifier.lower()
|
||||
return combine_identifier(identifier_type, identifier)
|
||||
|
||||
|
||||
def split_identifier(identifier):
|
||||
def split_identifier(identifier: str) -> tuple[str, str]:
|
||||
parts = identifier.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
raise ModuleFailException(
|
||||
f'Identifier "{identifier}" is not of the form <type>:<identifier>'
|
||||
)
|
||||
return parts
|
||||
return parts[0], parts[1]
|
||||
|
||||
|
||||
_Challenge = t.TypeVar("_Challenge", bound="Challenge")
|
||||
|
||||
|
||||
class Challenge:
|
||||
def __init__(self, data, url):
|
||||
def __init__(self, data: dict[str, t.Any], url: str) -> None:
|
||||
self.data = data
|
||||
|
||||
self.type = data["type"]
|
||||
self.type: str = data["type"]
|
||||
self.url = url
|
||||
self.status = data["status"]
|
||||
self.token = data.get("token")
|
||||
self.status: str = data["status"]
|
||||
self.token: str | None = data.get("token")
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, client, data, url=None):
|
||||
def from_json(
|
||||
cls: t.Type[_Challenge],
|
||||
client: ACMEClient,
|
||||
data: dict[str, t.Any],
|
||||
url: str | None = None,
|
||||
) -> _Challenge:
|
||||
return cls(data, url or data["url"])
|
||||
|
||||
def call_validate(self, client):
|
||||
challenge_response = {}
|
||||
def call_validate(self, client: ACMEClient) -> None:
|
||||
challenge_response: dict[str, t.Any] = {}
|
||||
client.send_signed_request(
|
||||
self.url,
|
||||
challenge_response,
|
||||
@@ -77,10 +92,15 @@ class Challenge:
|
||||
expected_status_codes=[200, 202],
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> dict[str, t.Any]:
|
||||
return self.data.copy()
|
||||
|
||||
def get_validation_data(self, client, identifier_type, identifier):
|
||||
def get_validation_data(
|
||||
self, client: ACMEClient, identifier_type: str, identifier: str
|
||||
) -> dict[str, t.Any] | None:
|
||||
if self.token is None:
|
||||
return None
|
||||
|
||||
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
|
||||
key_authorization = create_key_authorization(client, token)
|
||||
|
||||
@@ -113,21 +133,33 @@ class Challenge:
|
||||
resource += "."
|
||||
else:
|
||||
resource = identifier
|
||||
value = base64.b64encode(
|
||||
b_value = base64.b64encode(
|
||||
hashlib.sha256(to_bytes(key_authorization)).digest()
|
||||
)
|
||||
return {
|
||||
"resource": resource,
|
||||
"resource_original": combine_identifier(identifier_type, identifier),
|
||||
"resource_value": value,
|
||||
"resource_value": b_value,
|
||||
}
|
||||
|
||||
# Unknown challenge type: ignore
|
||||
return None
|
||||
|
||||
|
||||
_Authorization = t.TypeVar("_Authorization", bound="Authorization")
|
||||
|
||||
|
||||
class Authorization:
|
||||
def _setup(self, client, data):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
self.data: dict[str, t.Any] | None = None
|
||||
self.challenges: list[Challenge] = []
|
||||
self.status: str | None = None
|
||||
self.identifier_type: str | None = None
|
||||
self.identifier: str | None = None
|
||||
|
||||
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
data["uri"] = self.url
|
||||
self.data = data
|
||||
# While 'challenges' is a required field, apparently not every CA cares
|
||||
@@ -145,29 +177,32 @@ class Authorization:
|
||||
if data.get("wildcard", False):
|
||||
self.identifier = f"*.{self.identifier}"
|
||||
|
||||
def __init__(self, url):
|
||||
self.url = url
|
||||
|
||||
self.data = None
|
||||
self.challenges = []
|
||||
self.status = None
|
||||
self.identifier_type = None
|
||||
self.identifier = None
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, client, data, url):
|
||||
def from_json(
|
||||
cls: t.Type[_Authorization],
|
||||
client: ACMEClient,
|
||||
data: dict[str, t.Any],
|
||||
url: str,
|
||||
) -> _Authorization:
|
||||
result = cls(url)
|
||||
result._setup(client, data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, client, url):
|
||||
def from_url(
|
||||
cls: t.Type[_Authorization], client: ACMEClient, url: str
|
||||
) -> _Authorization:
|
||||
result = cls(url)
|
||||
result.refresh(client)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def create(cls, client, identifier_type, identifier):
|
||||
def create(
|
||||
cls: t.Type[_Authorization],
|
||||
client: ACMEClient,
|
||||
identifier_type: str,
|
||||
identifier: str,
|
||||
) -> _Authorization:
|
||||
"""
|
||||
Create a new authorization for the given identifier.
|
||||
Return the authorization object of the new authorization
|
||||
@@ -194,23 +229,29 @@ class Authorization:
|
||||
return cls.from_json(client, result, info["location"])
|
||||
|
||||
@property
|
||||
def combined_identifier(self):
|
||||
def combined_identifier(self) -> str:
|
||||
if self.identifier_type is None or self.identifier is None:
|
||||
raise ValueError("Data not present")
|
||||
return combine_identifier(self.identifier_type, self.identifier)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> dict[str, t.Any]:
|
||||
if self.data is None:
|
||||
raise ValueError("Data not present")
|
||||
return self.data.copy()
|
||||
|
||||
def refresh(self, client):
|
||||
def refresh(self, client: ACMEClient) -> bool:
|
||||
result, dummy = client.get_request(self.url)
|
||||
changed = self.data != result
|
||||
self._setup(client, result)
|
||||
return changed
|
||||
|
||||
def get_challenge_data(self, client):
|
||||
def get_challenge_data(self, client: ACMEClient) -> dict[str, t.Any]:
|
||||
"""
|
||||
Returns a dict with the data for all proposed (and supported) challenges
|
||||
of the given authorization.
|
||||
"""
|
||||
if self.identifier_type is None or self.identifier is None:
|
||||
raise ValueError("Data not present")
|
||||
data = {}
|
||||
for challenge in self.challenges:
|
||||
validation_data = challenge.get_validation_data(
|
||||
@@ -220,7 +261,7 @@ class Authorization:
|
||||
data[challenge.type] = validation_data
|
||||
return data
|
||||
|
||||
def raise_error(self, error_msg, module=None):
|
||||
def raise_error(self, error_msg: str, module: AnsibleModule) -> t.NoReturn:
|
||||
"""
|
||||
Aborts with a specific error for a challenge.
|
||||
"""
|
||||
@@ -246,13 +287,13 @@ class Authorization:
|
||||
),
|
||||
)
|
||||
|
||||
def find_challenge(self, challenge_type):
|
||||
def find_challenge(self, challenge_type: str) -> Challenge | None:
|
||||
for challenge in self.challenges:
|
||||
if challenge_type == challenge.type:
|
||||
return challenge
|
||||
return None
|
||||
|
||||
def wait_for_validation(self, client, callenge_type):
|
||||
def wait_for_validation(self, client: ACMEClient, callenge_type: str) -> bool:
|
||||
while True:
|
||||
self.refresh(client)
|
||||
if self.status in ["valid", "invalid", "revoked"]:
|
||||
@@ -264,7 +305,9 @@ class Authorization:
|
||||
|
||||
return self.status == "valid"
|
||||
|
||||
def call_validate(self, client, challenge_type, wait=True):
|
||||
def call_validate(
|
||||
self, client: ACMEClient, challenge_type: str, wait: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the authorization provided in the auth dict. Returns True
|
||||
when the validation was successful and False when it was not.
|
||||
@@ -281,7 +324,7 @@ class Authorization:
|
||||
return self.status == "valid"
|
||||
return self.wait_for_validation(client, challenge_type)
|
||||
|
||||
def can_deactivate(self):
|
||||
def can_deactivate(self) -> bool:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
@@ -289,14 +332,14 @@ class Authorization:
|
||||
"""
|
||||
return self.status in ("valid", "pending")
|
||||
|
||||
def deactivate(self, client):
|
||||
def deactivate(self, client: ACMEClient) -> bool | None:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
https://tools.ietf.org/html/rfc8555#section-7.5.2
|
||||
"""
|
||||
if not self.can_deactivate():
|
||||
return
|
||||
return None
|
||||
authz_deactivate = {"status": "deactivated"}
|
||||
result, info = client.send_signed_request(
|
||||
self.url, authz_deactivate, fail_on_error=False
|
||||
@@ -307,7 +350,9 @@ class Authorization:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def deactivate_url(cls, client, url):
|
||||
def deactivate_url(
|
||||
cls: t.Type[_Authorization], client: ACMEClient, url: str
|
||||
) -> _Authorization:
|
||||
"""
|
||||
Deactivates this authorization.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
@@ -322,7 +367,7 @@ class Authorization:
|
||||
return authz
|
||||
|
||||
|
||||
def wait_for_validation(authzs, client):
|
||||
def wait_for_validation(authzs: t.Iterable[Authorization], client: ACMEClient) -> None:
|
||||
"""
|
||||
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
|
||||
"""
|
||||
|
||||
@@ -5,19 +5,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from http.client import responses as http_responses
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
|
||||
|
||||
def format_http_status(status_code):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def format_http_status(status_code: int) -> str:
|
||||
expl = http_responses.get(status_code)
|
||||
if not expl:
|
||||
return str(status_code)
|
||||
return f"{status_code} {expl}"
|
||||
|
||||
|
||||
def format_error_problem(problem, subproblem_prefix=""):
|
||||
def format_error_problem(problem: dict[str, t.Any], subproblem_prefix: str = "") -> str:
|
||||
error_type = problem.get(
|
||||
"type", "about:blank"
|
||||
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
|
||||
@@ -32,8 +37,10 @@ def format_error_problem(problem, subproblem_prefix=""):
|
||||
msg = f"{msg} Subproblems:"
|
||||
for index, problem in enumerate(subproblems):
|
||||
index_str = f"{subproblem_prefix}{index}"
|
||||
problem = format_error_problem(problem, subproblem_prefix=f"{index_str}.")
|
||||
msg = f"{msg}\n({index_str}) {problem}"
|
||||
problem_str = format_error_problem(
|
||||
problem, subproblem_prefix=f"{index_str}."
|
||||
)
|
||||
msg = f"{msg}\n({index_str}) {problem_str}"
|
||||
return msg
|
||||
|
||||
|
||||
@@ -42,25 +49,25 @@ class ModuleFailException(Exception):
|
||||
If raised, module.fail_json() will be called with the given parameters after cleanup.
|
||||
"""
|
||||
|
||||
def __init__(self, msg, **args):
|
||||
def __init__(self, msg: str, **args: t.Any) -> None:
|
||||
super(ModuleFailException, self).__init__(self, msg)
|
||||
self.msg = msg
|
||||
self.module_fail_args = args
|
||||
|
||||
def do_fail(self, module, **arguments):
|
||||
def do_fail(self, module: AnsibleModule, **arguments) -> t.NoReturn:
|
||||
module.fail_json(msg=self.msg, other=self.module_fail_args, **arguments)
|
||||
|
||||
|
||||
class ACMEProtocolException(ModuleFailException):
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
msg=None,
|
||||
info=None,
|
||||
module: AnsibleModule,
|
||||
msg: str | None = None,
|
||||
info: dict[str, t.Any] | None = None,
|
||||
response=None,
|
||||
content=None,
|
||||
content_json=None,
|
||||
extras=None,
|
||||
content: bytes | None = None,
|
||||
content_json: dict[str, t.Any] | None = None,
|
||||
extras: dict[str, t.Any] | None = None,
|
||||
):
|
||||
# Try to get hold of content, if response is given and content is not provided
|
||||
if content is None and content_json is None and response is not None:
|
||||
@@ -71,7 +78,8 @@ class ACMEProtocolException(ModuleFailException):
|
||||
raise TypeError
|
||||
content = response.read()
|
||||
except (AttributeError, TypeError):
|
||||
content = info.pop("body", None)
|
||||
if info is not None:
|
||||
content = info.pop("body", None)
|
||||
|
||||
# Make sure that content_json is None or a dictionary
|
||||
if content_json is not None and not isinstance(content_json, dict):
|
||||
@@ -139,8 +147,8 @@ class ACMEProtocolException(ModuleFailException):
|
||||
add_msg = f" The raw result: {to_text(content)}"
|
||||
|
||||
super(ACMEProtocolException, self).__init__(f"{msg}.{add_msg}", **extras)
|
||||
self.problem = {}
|
||||
self.subproblems = []
|
||||
self.problem: dict[str, t.Any] = {}
|
||||
self.subproblems: list[dict[str, t.Any]] = []
|
||||
self.error_code = error_code
|
||||
self.error_type = error_type
|
||||
for k, v in extras.items():
|
||||
|
||||
@@ -10,22 +10,27 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ModuleFailException,
|
||||
)
|
||||
|
||||
|
||||
def read_file(fn, mode="b"):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def read_file(fn: str | os.PathLike) -> bytes:
|
||||
try:
|
||||
with open(fn, "r" + mode) as f:
|
||||
with open(fn, "rb") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
raise ModuleFailException(f'Error while reading file "{fn}": {e}')
|
||||
|
||||
|
||||
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
|
||||
def write_file(module, dest, content):
|
||||
def write_file(module: AnsibleModule, dest: str | os.PathLike, content: bytes) -> bool:
|
||||
"""
|
||||
Write content to destination file dest, only if the content
|
||||
has changed.
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
|
||||
Authorization,
|
||||
@@ -13,14 +14,35 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.challenges i
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
ACMEProtocolException,
|
||||
ModuleFailException,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.utils import (
|
||||
nopad_b64,
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .acme import ACMEClient
|
||||
|
||||
|
||||
_Order = t.TypeVar("_Order", bound="Order")
|
||||
|
||||
|
||||
class Order:
|
||||
def _setup(self, client, data):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
self.data: dict[str, t.Any] | None = None
|
||||
|
||||
self.status = None
|
||||
self.identifiers: list[tuple[str, str]] = []
|
||||
self.replaces_cert_id = None
|
||||
self.finalize_uri = None
|
||||
self.certificate_uri = None
|
||||
self.authorization_uris: list[str] = []
|
||||
self.authorizations: dict[str, Authorization] = {}
|
||||
|
||||
def _setup(self, client: ACMEClient, data: dict[str, t.Any]) -> None:
|
||||
self.data = data
|
||||
|
||||
self.status = data["status"]
|
||||
@@ -33,33 +55,28 @@ class Order:
|
||||
self.authorization_uris = data["authorizations"]
|
||||
self.authorizations = {}
|
||||
|
||||
def __init__(self, url):
|
||||
self.url = url
|
||||
|
||||
self.data = None
|
||||
|
||||
self.status = None
|
||||
self.identifiers = []
|
||||
self.replaces_cert_id = None
|
||||
self.finalize_uri = None
|
||||
self.certificate_uri = None
|
||||
self.authorization_uris = []
|
||||
self.authorizations = {}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, client, data, url):
|
||||
def from_json(
|
||||
cls: t.Type[_Order], client: ACMEClient, data: dict[str, t.Any], url: str
|
||||
) -> _Order:
|
||||
result = cls(url)
|
||||
result._setup(client, data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, client, url):
|
||||
def from_url(cls: t.Type[_Order], client: ACMEClient, url: str) -> _Order:
|
||||
result = cls(url)
|
||||
result.refresh(client)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def create(cls, client, identifiers, replaces_cert_id=None, profile=None):
|
||||
def create(
|
||||
cls: t.Type[_Order],
|
||||
client: ACMEClient,
|
||||
identifiers: list[tuple[str, str]],
|
||||
replaces_cert_id: str | None = None,
|
||||
profile: str | None = None,
|
||||
) -> _Order:
|
||||
"""
|
||||
Start a new certificate order (ACME v2 protocol).
|
||||
https://tools.ietf.org/html/rfc8555#section-7.4
|
||||
@@ -72,7 +89,7 @@ class Order:
|
||||
"value": identifier,
|
||||
}
|
||||
)
|
||||
new_order = {"identifiers": acme_identifiers}
|
||||
new_order: dict[str, t.Any] = {"identifiers": acme_identifiers}
|
||||
if replaces_cert_id is not None:
|
||||
new_order["replaces"] = replaces_cert_id
|
||||
if profile is not None:
|
||||
@@ -87,15 +104,17 @@ class Order:
|
||||
|
||||
@classmethod
|
||||
def create_with_error_handling(
|
||||
cls,
|
||||
client,
|
||||
identifiers,
|
||||
error_strategy="auto",
|
||||
error_max_retries=3,
|
||||
replaces_cert_id=None,
|
||||
profile=None,
|
||||
message_callback=None,
|
||||
):
|
||||
cls: t.Type[_Order],
|
||||
client: ACMEClient,
|
||||
identifiers: list[tuple[str, str]],
|
||||
error_strategy: t.Literal[
|
||||
"auto", "fail", "always", "retry_without_replaces_cert_id"
|
||||
] = "auto",
|
||||
error_max_retries: int = 3,
|
||||
replaces_cert_id: str | None = None,
|
||||
profile: str | None = None,
|
||||
message_callback: t.Callable[[str], None] | None = None,
|
||||
) -> _Order:
|
||||
"""
|
||||
error_strategy can be one of the following strings:
|
||||
|
||||
@@ -140,20 +159,20 @@ class Order:
|
||||
|
||||
raise
|
||||
|
||||
def refresh(self, client):
|
||||
def refresh(self, client: ACMEClient) -> bool:
|
||||
result, dummy = client.get_request(self.url)
|
||||
changed = self.data != result
|
||||
self._setup(client, result)
|
||||
return changed
|
||||
|
||||
def load_authorizations(self, client):
|
||||
def load_authorizations(self, client: ACMEClient) -> None:
|
||||
for auth_uri in self.authorization_uris:
|
||||
authz = Authorization.from_url(client, auth_uri)
|
||||
self.authorizations[
|
||||
normalize_combined_identifier(authz.combined_identifier)
|
||||
] = authz
|
||||
|
||||
def wait_for_finalization(self, client):
|
||||
def wait_for_finalization(self, client: ACMEClient) -> None:
|
||||
while True:
|
||||
self.refresh(client)
|
||||
if self.status in ["valid", "invalid", "pending", "ready"]:
|
||||
@@ -167,12 +186,14 @@ class Order:
|
||||
content_json=self.data,
|
||||
)
|
||||
|
||||
def finalize(self, client, csr_der, wait=True):
|
||||
def finalize(self, client: ACMEClient, csr_der: bytes, wait: bool = True) -> None:
|
||||
"""
|
||||
Create a new certificate based on the csr.
|
||||
Return the certificate object as dict
|
||||
https://tools.ietf.org/html/rfc8555#section-7.4
|
||||
"""
|
||||
if self.finalize_uri is None:
|
||||
raise ModuleFailException("finalize_uri must be set")
|
||||
new_cert = {
|
||||
"csr": nopad_b64(csr_der),
|
||||
}
|
||||
|
||||
@@ -7,9 +7,11 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import traceback
|
||||
import typing as t
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.errors import (
|
||||
@@ -23,11 +25,15 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
def nopad_b64(data):
|
||||
if t.TYPE_CHECKING:
|
||||
from .backends import CertificateInformation, CryptoBackend
|
||||
|
||||
|
||||
def nopad_b64(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "")
|
||||
|
||||
|
||||
def der_to_pem(der_cert):
|
||||
def der_to_pem(der_cert: bytes) -> str:
|
||||
"""
|
||||
Convert the DER format certificate in der_cert to a PEM format certificate and return it.
|
||||
"""
|
||||
@@ -35,7 +41,9 @@ def der_to_pem(der_cert):
|
||||
return f"-----BEGIN CERTIFICATE-----\n{content}\n-----END CERTIFICATE-----\n"
|
||||
|
||||
|
||||
def pem_to_der(pem_filename=None, pem_content=None):
|
||||
def pem_to_der(
|
||||
pem_filename: str | os.PathLike | None = None, pem_content: str | None = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Load PEM file, or use PEM file's content, and convert to DER.
|
||||
|
||||
@@ -70,7 +78,9 @@ def pem_to_der(pem_filename=None, pem_content=None):
|
||||
return base64.b64decode("".join(certificate_lines))
|
||||
|
||||
|
||||
def process_links(info, callback):
|
||||
def process_links(
|
||||
info: dict[str, t.Any], callback: t.Callable[[str, str], None]
|
||||
) -> None:
|
||||
"""
|
||||
Process link header, calls callback for every link header with the URL and relation as options.
|
||||
|
||||
@@ -82,7 +92,11 @@ def process_links(info, callback):
|
||||
callback(unquote(url), relation)
|
||||
|
||||
|
||||
def parse_retry_after(value, relative_with_timezone=True, now=None):
|
||||
def parse_retry_after(
|
||||
value: str,
|
||||
relative_with_timezone: bool = True,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime:
|
||||
"""
|
||||
Parse the value of a Retry-After header and return a timestamp.
|
||||
|
||||
@@ -106,12 +120,12 @@ def parse_retry_after(value, relative_with_timezone=True, now=None):
|
||||
|
||||
|
||||
def compute_cert_id(
|
||||
backend,
|
||||
cert_info=None,
|
||||
cert_filename=None,
|
||||
cert_content=None,
|
||||
none_if_required_information_is_missing=False,
|
||||
):
|
||||
backend: CryptoBackend,
|
||||
cert_info: CertificateInformation | None = None,
|
||||
cert_filename: str | os.PathLike | None = None,
|
||||
cert_content: str | bytes | None = None,
|
||||
none_if_required_information_is_missing: bool = False,
|
||||
) -> str | None:
|
||||
# Obtain certificate info if not provided
|
||||
if cert_info is None:
|
||||
cert_info = backend.get_cert_information(
|
||||
|
||||
@@ -4,10 +4,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def _ensure_list(value):
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
def _ensure_list(value: list[_T] | tuple[_T] | None) -> list[_T]:
|
||||
if value is None:
|
||||
return []
|
||||
return list(value)
|
||||
@@ -16,13 +21,19 @@ def _ensure_list(value):
|
||||
class ArgumentSpec:
|
||||
def __init__(
|
||||
self,
|
||||
argument_spec=None,
|
||||
mutually_exclusive=None,
|
||||
required_together=None,
|
||||
required_one_of=None,
|
||||
required_if=None,
|
||||
required_by=None,
|
||||
):
|
||||
argument_spec: dict[str, t.Any] | None = None,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_if: (
|
||||
list[
|
||||
tuple[str, t.Any, list[str] | tuple[str, ...]]
|
||||
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
|
||||
) -> None:
|
||||
self.argument_spec = argument_spec or {}
|
||||
self.mutually_exclusive = _ensure_list(mutually_exclusive)
|
||||
self.required_together = _ensure_list(required_together)
|
||||
@@ -30,17 +41,23 @@ class ArgumentSpec:
|
||||
self.required_if = _ensure_list(required_if)
|
||||
self.required_by = required_by or {}
|
||||
|
||||
def update_argspec(self, **kwargs):
|
||||
def update_argspec(self, **kwargs) -> t.Self:
|
||||
self.argument_spec.update(kwargs)
|
||||
return self
|
||||
|
||||
def update(
|
||||
self,
|
||||
mutually_exclusive=None,
|
||||
required_together=None,
|
||||
required_one_of=None,
|
||||
required_if=None,
|
||||
required_by=None,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_if: (
|
||||
list[
|
||||
tuple[str, t.Any, list[str] | tuple[str, ...]]
|
||||
| tuple[str, t.Any, list[str] | tuple[str, ...], bool]
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
|
||||
):
|
||||
if mutually_exclusive:
|
||||
self.mutually_exclusive.extend(mutually_exclusive)
|
||||
@@ -57,7 +74,7 @@ class ArgumentSpec:
|
||||
self.required_by[k] = v
|
||||
return self
|
||||
|
||||
def merge(self, other):
|
||||
def merge(self, other: t.Self) -> t.Self:
|
||||
self.update_argspec(**other.argument_spec)
|
||||
self.update(
|
||||
mutually_exclusive=other.mutually_exclusive,
|
||||
@@ -68,8 +85,22 @@ class ArgumentSpec:
|
||||
)
|
||||
return self
|
||||
|
||||
def create_ansible_module_helper(self, clazz, args, **kwargs):
|
||||
return clazz(
|
||||
def create_ansible_module_helper(
|
||||
self, clazz: type[_T], args: tuple, **kwargs: t.Any
|
||||
) -> _T:
|
||||
for forbidden_name in (
|
||||
"argument_spec",
|
||||
"mutually_exclusive",
|
||||
"required_together",
|
||||
"required_one_of",
|
||||
"required_if",
|
||||
"required_by",
|
||||
):
|
||||
if forbidden_name in kwargs:
|
||||
raise ValueError(
|
||||
f"You must not provide a {forbidden_name} keyword parameter to create_ansible_module_helper()"
|
||||
)
|
||||
instance = clazz( # type: ignore
|
||||
*args,
|
||||
argument_spec=self.argument_spec,
|
||||
mutually_exclusive=self.mutually_exclusive,
|
||||
@@ -79,8 +110,9 @@ class ArgumentSpec:
|
||||
required_by=self.required_by,
|
||||
**kwargs,
|
||||
)
|
||||
return instance
|
||||
|
||||
def create_ansible_module(self, **kwargs):
|
||||
def create_ansible_module(self, **kwargs: t.Any) -> AnsibleModule:
|
||||
return self.create_ansible_module_helper(AnsibleModule, (), **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import re
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -32,7 +33,7 @@ ASN1_STRING_REGEX = re.compile(
|
||||
)
|
||||
|
||||
|
||||
class TagClass:
|
||||
class TagClass(enum.Enum):
|
||||
universal = 0
|
||||
application = 1
|
||||
context_specific = 2
|
||||
@@ -40,11 +41,11 @@ class TagClass:
|
||||
|
||||
|
||||
# Universal tag numbers that can be encoded.
|
||||
class TagNumber:
|
||||
class TagNumber(enum.Enum):
|
||||
utf8_string = 12
|
||||
|
||||
|
||||
def _pack_octet_integer(value):
|
||||
def _pack_octet_integer(value: int) -> bytes:
|
||||
"""Packs an integer value into 1 or multiple octets."""
|
||||
# NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value.
|
||||
octets = bytearray()
|
||||
@@ -66,7 +67,7 @@ def _pack_octet_integer(value):
|
||||
return bytes(octets)
|
||||
|
||||
|
||||
def serialize_asn1_string_as_der(value):
|
||||
def serialize_asn1_string_as_der(value: str) -> bytes:
|
||||
"""Deserializes an ASN.1 string to a DER encoded byte string."""
|
||||
asn1_match = ASN1_STRING_REGEX.match(value)
|
||||
if not asn1_match:
|
||||
@@ -92,7 +93,7 @@ def serialize_asn1_string_as_der(value):
|
||||
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
|
||||
|
||||
if tag_type:
|
||||
tag_class = {
|
||||
tag_class_enum = {
|
||||
"U": TagClass.universal,
|
||||
"A": TagClass.application,
|
||||
"P": TagClass.private,
|
||||
@@ -100,13 +101,15 @@ def serialize_asn1_string_as_der(value):
|
||||
}[tag_class]
|
||||
|
||||
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
|
||||
constructed = tag_type == "EXPLICIT" and tag_class != TagClass.universal
|
||||
b_value = pack_asn1(tag_class, constructed, int(tag_number), b_value)
|
||||
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
|
||||
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
|
||||
|
||||
return b_value
|
||||
|
||||
|
||||
def pack_asn1(tag_class, constructed, tag_number, b_data):
|
||||
def pack_asn1(
|
||||
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
|
||||
) -> bytes:
|
||||
"""Pack the value into an ASN.1 data structure.
|
||||
|
||||
The structure for an ASN.1 element is
|
||||
@@ -115,16 +118,15 @@ def pack_asn1(tag_class, constructed, tag_number, b_data):
|
||||
"""
|
||||
b_asn1_data = bytearray()
|
||||
|
||||
if tag_class < 0 or tag_class > 3:
|
||||
raise ValueError(f"tag_class must be between 0 and 3 not {tag_class}")
|
||||
|
||||
# Bit 8 and 7 denotes the class.
|
||||
identifier_octets = tag_class << 6
|
||||
identifier_octets = tag_class.value << 6
|
||||
# Bit 6 denotes whether the value is primitive or constructed.
|
||||
identifier_octets |= (1 if constructed else 0) << 5
|
||||
|
||||
# Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits
|
||||
# then they are set and another octet(s) is used to denote the tag number.
|
||||
if isinstance(tag_number, TagNumber):
|
||||
tag_number = tag_number.value
|
||||
if tag_number < 31:
|
||||
identifier_octets |= tag_number
|
||||
b_asn1_data.append(identifier_octets)
|
||||
|
||||
@@ -34,7 +34,7 @@ from __future__ import annotations
|
||||
# cryptography versions!
|
||||
|
||||
|
||||
def obj2txt(openssl_lib, openssl_ffi, obj):
|
||||
def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
|
||||
# Set to 80 on the recommendation of
|
||||
# https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values
|
||||
#
|
||||
|
||||
@@ -7,9 +7,9 @@ from __future__ import annotations
|
||||
from ._objects_data import OID_MAP
|
||||
|
||||
|
||||
OID_LOOKUP = dict()
|
||||
NORMALIZE_NAMES = dict()
|
||||
NORMALIZE_NAMES_SHORT = dict()
|
||||
OID_LOOKUP: dict[str, str] = dict()
|
||||
NORMALIZE_NAMES: dict[str, str] = dict()
|
||||
NORMALIZE_NAMES_SHORT: dict[str, str] = dict()
|
||||
|
||||
for dotted, names in OID_MAP.items():
|
||||
for name in names:
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
LooseVersion as _LooseVersion,
|
||||
)
|
||||
@@ -21,6 +23,10 @@ from .basic import HAS_CRYPTOGRAPHY
|
||||
from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_name
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
|
||||
# TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this
|
||||
# to True and adjust get_invalidity_date() accordingly.
|
||||
# (https://github.com/pyca/cryptography/issues/10818)
|
||||
@@ -55,7 +61,9 @@ else:
|
||||
REVOCATION_REASON_MAP_INVERSE = dict()
|
||||
|
||||
|
||||
def cryptography_decode_revoked_certificate(cert):
|
||||
def cryptography_decode_revoked_certificate(
|
||||
cert: x509.RevokedCertificate,
|
||||
) -> dict[str, t.Any]:
|
||||
result = {
|
||||
"serial_number": cert.serial_number,
|
||||
"revocation_date": get_revocation_date(cert),
|
||||
@@ -67,27 +75,30 @@ def cryptography_decode_revoked_certificate(cert):
|
||||
"invalidity_date_critical": False,
|
||||
}
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
|
||||
result["issuer"] = list(ext.value)
|
||||
result["issuer_critical"] = ext.critical
|
||||
ext_ci = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
|
||||
result["issuer"] = list(ext_ci.value)
|
||||
result["issuer_critical"] = ext_ci.critical
|
||||
except x509.ExtensionNotFound:
|
||||
pass
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(x509.CRLReason)
|
||||
result["reason"] = ext.value.reason
|
||||
result["reason_critical"] = ext.critical
|
||||
ext_cr = cert.extensions.get_extension_for_class(x509.CRLReason)
|
||||
result["reason"] = ext_cr.value.reason
|
||||
result["reason_critical"] = ext_cr.critical
|
||||
except x509.ExtensionNotFound:
|
||||
pass
|
||||
try:
|
||||
ext = cert.extensions.get_extension_for_class(x509.InvalidityDate)
|
||||
result["invalidity_date"] = get_invalidity_date(ext.value)
|
||||
result["invalidity_date_critical"] = ext.critical
|
||||
ext_id = cert.extensions.get_extension_for_class(x509.InvalidityDate)
|
||||
result["invalidity_date"] = get_invalidity_date(ext_id.value)
|
||||
result["invalidity_date_critical"] = ext_id.critical
|
||||
except x509.ExtensionNotFound:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
|
||||
def cryptography_dump_revoked(
|
||||
entry: dict[str, t.Any],
|
||||
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
) -> dict[str, t.Any]:
|
||||
return {
|
||||
"serial_number": entry["serial_number"],
|
||||
"revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT),
|
||||
@@ -115,48 +126,56 @@ def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
|
||||
}
|
||||
|
||||
|
||||
def cryptography_get_signature_algorithm_oid_from_crl(crl):
|
||||
def cryptography_get_signature_algorithm_oid_from_crl(
|
||||
crl: x509.CertificateRevocationList,
|
||||
) -> x509.oid.ObjectIdentifier:
|
||||
try:
|
||||
return crl.signature_algorithm_oid
|
||||
except AttributeError:
|
||||
# Older cryptography versions do not have signature_algorithm_oid yet
|
||||
dotted = obj2txt(
|
||||
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm
|
||||
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm # type: ignore
|
||||
)
|
||||
return x509.oid.ObjectIdentifier(dotted)
|
||||
|
||||
|
||||
def get_next_update(obj):
|
||||
def get_next_update(obj: x509.CertificateRevocationList) -> datetime.datetime | None:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.next_update_utc
|
||||
return obj.next_update
|
||||
|
||||
|
||||
def get_last_update(obj):
|
||||
def get_last_update(obj: x509.CertificateRevocationList) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.last_update_utc
|
||||
return obj.last_update
|
||||
|
||||
|
||||
def get_revocation_date(obj):
|
||||
def get_revocation_date(obj: x509.RevokedCertificate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.revocation_date_utc
|
||||
return obj.revocation_date
|
||||
|
||||
|
||||
def get_invalidity_date(obj):
|
||||
def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE:
|
||||
return obj.invalidity_date_utc
|
||||
return obj.invalidity_date
|
||||
|
||||
|
||||
def set_next_update(builder, value):
|
||||
def set_next_update(
|
||||
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateRevocationListBuilder:
|
||||
return builder.next_update(value)
|
||||
|
||||
|
||||
def set_last_update(builder, value):
|
||||
def set_last_update(
|
||||
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateRevocationListBuilder:
|
||||
return builder.last_update(value)
|
||||
|
||||
|
||||
def set_revocation_date(builder, value):
|
||||
def set_revocation_date(
|
||||
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
|
||||
) -> x509.RevokedCertificateBuilder:
|
||||
return builder.revocation_date(value)
|
||||
|
||||
@@ -9,6 +9,7 @@ import binascii
|
||||
import ipaddress
|
||||
import re
|
||||
import traceback
|
||||
import typing as t
|
||||
from urllib.parse import (
|
||||
ParseResult,
|
||||
urlparse,
|
||||
@@ -40,6 +41,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import cryptography.hazmat.primitives.asymmetric.dh
|
||||
import cryptography.hazmat.primitives.asymmetric.ed448
|
||||
import cryptography.hazmat.primitives.asymmetric.ed25519
|
||||
import cryptography.hazmat.primitives.asymmetric.rsa
|
||||
@@ -55,7 +57,7 @@ try:
|
||||
)
|
||||
except ImportError:
|
||||
# Error handled in the calling module.
|
||||
_load_pkcs12 = None
|
||||
_load_pkcs12 = None # type: ignore
|
||||
|
||||
try:
|
||||
import idna
|
||||
@@ -74,6 +76,50 @@ from .basic import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateKey, DHPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import (
|
||||
DSAPrivateKey,
|
||||
DSAPublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||
RSAPrivateKey,
|
||||
RSAPublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
CertificateIssuerPublicKeyTypes,
|
||||
CertificatePublicKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
PublicKeyTypes,
|
||||
)
|
||||
from cryptography.hazmat.primitives.serialization.pkcs12 import (
|
||||
PKCS12KeyAndCertificates,
|
||||
)
|
||||
|
||||
CertificatePrivateKeyTypes = (
|
||||
CertificateIssuerPrivateKeyTypes
|
||||
| cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey
|
||||
| cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
|
||||
)
|
||||
PublicKeyTypesWOEdwards = (
|
||||
DHPublicKey | DSAPublicKey | EllipticCurvePublicKey | RSAPublicKey
|
||||
)
|
||||
PrivateKeyTypesWOEdwards = (
|
||||
DHPrivateKey | DSAPrivateKey | EllipticCurvePrivateKey | RSAPrivateKey
|
||||
)
|
||||
else:
|
||||
PublicKeyTypesWOEdwards = None
|
||||
PrivateKeyTypesWOEdwards = None
|
||||
|
||||
|
||||
CRYPTOGRAPHY_TIMEZONE = False
|
||||
_CRYPTOGRAPHY_36_0_OR_NEWER = False
|
||||
if _HAS_CRYPTOGRAPHY:
|
||||
@@ -88,7 +134,9 @@ if _HAS_CRYPTOGRAPHY:
|
||||
DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$")
|
||||
|
||||
|
||||
def cryptography_get_extensions_from_cert(cert):
|
||||
def cryptography_get_extensions_from_cert(
|
||||
cert: x509.Certificate,
|
||||
) -> dict[str, dict[str, bool | str]]:
|
||||
result = dict()
|
||||
|
||||
if _CRYPTOGRAPHY_36_0_OR_NEWER:
|
||||
@@ -105,7 +153,7 @@ def cryptography_get_extensions_from_cert(cert):
|
||||
|
||||
backend = default_backend()
|
||||
|
||||
x509_obj = cert._x509
|
||||
x509_obj = cert._x509 # type: ignore
|
||||
# With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does
|
||||
# not allow to get the raw value of an extension, so we have to use this ugly hack:
|
||||
exts = list(cert.extensions)
|
||||
@@ -135,7 +183,9 @@ def cryptography_get_extensions_from_cert(cert):
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_get_extensions_from_csr(csr):
|
||||
def cryptography_get_extensions_from_csr(
|
||||
csr: x509.CertificateSigningRequest,
|
||||
) -> dict[str, dict[str, bool | str]]:
|
||||
result = dict()
|
||||
|
||||
if _CRYPTOGRAPHY_36_0_OR_NEWER:
|
||||
@@ -153,7 +203,7 @@ def cryptography_get_extensions_from_csr(csr):
|
||||
|
||||
backend = default_backend()
|
||||
|
||||
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req)
|
||||
extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) # type: ignore
|
||||
extensions = backend._ffi.gc(
|
||||
extensions,
|
||||
lambda ext: backend._lib.sk_X509_EXTENSION_pop_free(
|
||||
@@ -175,7 +225,7 @@ def cryptography_get_extensions_from_csr(csr):
|
||||
crit = backend._lib.X509_EXTENSION_get_critical(ext)
|
||||
data = backend._lib.X509_EXTENSION_get_data(ext)
|
||||
backend.openssl_assert(data != backend._ffi.NULL)
|
||||
der = backend._ffi.buffer(data.data, data.length)[:]
|
||||
der: bytes = backend._ffi.buffer(data.data, data.length)[:] # type: ignore
|
||||
entry = dict(
|
||||
critical=(crit == 1),
|
||||
value=base64.b64encode(der).decode("ascii"),
|
||||
@@ -193,7 +243,7 @@ def cryptography_get_extensions_from_csr(csr):
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_name_to_oid(name):
|
||||
def cryptography_name_to_oid(name: str) -> x509.oid.ObjectIdentifier:
|
||||
dotted = OID_LOOKUP.get(name)
|
||||
if dotted is None:
|
||||
if DOTTED_OID.match(name):
|
||||
@@ -202,7 +252,9 @@ def cryptography_name_to_oid(name):
|
||||
return x509.oid.ObjectIdentifier(dotted)
|
||||
|
||||
|
||||
def cryptography_oid_to_name(oid, short=False):
|
||||
def cryptography_oid_to_name(
|
||||
oid: x509.oid.ObjectIdentifier, short: bool = False
|
||||
) -> str:
|
||||
dotted_string = oid.dotted_string
|
||||
names = OID_MAP.get(dotted_string)
|
||||
if names:
|
||||
@@ -217,15 +269,22 @@ def cryptography_oid_to_name(oid, short=False):
|
||||
return NORMALIZE_NAMES.get(name, name)
|
||||
|
||||
|
||||
def _get_hex(bytesstr):
|
||||
def _get_hex(bytesstr: bytes) -> str:
|
||||
if bytesstr is None:
|
||||
return bytesstr
|
||||
data = binascii.hexlify(bytesstr)
|
||||
data = to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
|
||||
return data
|
||||
return to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
|
||||
|
||||
|
||||
def _parse_hex(bytesstr):
|
||||
@t.overload
|
||||
def _parse_hex(bytesstr: bytes | str) -> bytes: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None: ...
|
||||
|
||||
|
||||
def _parse_hex(bytesstr: bytes | str | None) -> bytes | None:
|
||||
if bytesstr is None:
|
||||
return bytesstr
|
||||
data = "".join(
|
||||
@@ -234,19 +293,20 @@ def _parse_hex(bytesstr):
|
||||
for p in to_text(bytesstr).split(":")
|
||||
]
|
||||
)
|
||||
data = binascii.unhexlify(data)
|
||||
return data
|
||||
return binascii.unhexlify(data)
|
||||
|
||||
|
||||
DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *")
|
||||
DN_HEX_LETTER = b"0123456789abcdef"
|
||||
|
||||
|
||||
def _int_to_byte(value):
|
||||
def _int_to_byte(value: int) -> bytes:
|
||||
return bytes((value,))
|
||||
|
||||
|
||||
def _parse_dn_component(name, sep=b",", decode_remainder=True):
|
||||
def _parse_dn_component(
|
||||
name: bytes, sep: bytes = b",", decode_remainder: bool = True
|
||||
) -> tuple[x509.NameAttribute, bytes]:
|
||||
m = DN_COMPONENT_START_RE.match(name)
|
||||
if not m:
|
||||
raise OpenSSLObjectError(f'cannot start part in "{to_text(name)}"')
|
||||
@@ -305,7 +365,7 @@ def _parse_dn_component(name, sep=b",", decode_remainder=True):
|
||||
return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:]
|
||||
|
||||
|
||||
def _parse_dn(name):
|
||||
def _parse_dn(name: bytes) -> list[x509.NameAttribute]:
|
||||
"""
|
||||
Parse a Distinguished Name.
|
||||
|
||||
@@ -323,31 +383,33 @@ def _parse_dn(name):
|
||||
attribute, name = _parse_dn_component(name, sep=sep)
|
||||
except OpenSSLObjectError as e:
|
||||
raise OpenSSLObjectError(
|
||||
f'Error while parsing distinguished name "{to_text(original_name)}": {e}'
|
||||
f"Error while parsing distinguished name {to_text(original_name)!r}: {e}"
|
||||
)
|
||||
result.append(attribute)
|
||||
if name:
|
||||
if name[0:1] != sep or len(name) < 2:
|
||||
raise OpenSSLObjectError(
|
||||
f'Error while parsing distinguished name "{to_text(original_name)}": unexpected end of string'
|
||||
f"Error while parsing distinguished name {to_text(original_name)!r}: unexpected end of string"
|
||||
)
|
||||
name = name[1:]
|
||||
return result
|
||||
|
||||
|
||||
def cryptography_parse_relative_distinguished_name(rdn):
|
||||
def cryptography_parse_relative_distinguished_name(
|
||||
rdn: list[str | bytes],
|
||||
) -> cryptography.x509.RelativeDistinguishedName:
|
||||
names = []
|
||||
for part in rdn:
|
||||
try:
|
||||
names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0])
|
||||
except OpenSSLObjectError as e:
|
||||
raise OpenSSLObjectError(
|
||||
f'Error while parsing relative distinguished name "{part}": {e}'
|
||||
f"Error while parsing relative distinguished name {to_text(part)!r}: {e}"
|
||||
)
|
||||
return cryptography.x509.RelativeDistinguishedName(names)
|
||||
|
||||
|
||||
def _is_ascii(value):
|
||||
def _is_ascii(value: str) -> bool:
|
||||
"""Check whether the Unicode string `value` contains only ASCII characters."""
|
||||
try:
|
||||
value.encode("ascii")
|
||||
@@ -356,7 +418,7 @@ def _is_ascii(value):
|
||||
return False
|
||||
|
||||
|
||||
def _adjust_idn(value, idn_rewrite):
|
||||
def _adjust_idn(value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]) -> str:
|
||||
if idn_rewrite == "ignore" or not value:
|
||||
return value
|
||||
if idn_rewrite == "idna" and _is_ascii(value):
|
||||
@@ -399,16 +461,20 @@ def _adjust_idn(value, idn_rewrite):
|
||||
return ".".join(parts)
|
||||
|
||||
|
||||
def _adjust_idn_email(value, idn_rewrite):
|
||||
def _adjust_idn_email(
|
||||
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
idx = value.find("@")
|
||||
if idx < 0:
|
||||
return value
|
||||
return f"{value[:idx]}@{_adjust_idn(value[idx + 1:], idn_rewrite)}"
|
||||
|
||||
|
||||
def _adjust_idn_url(value, idn_rewrite):
|
||||
def _adjust_idn_url(
|
||||
value: str, idn_rewrite: t.Literal["ignore", "idna", "unicode"]
|
||||
) -> str:
|
||||
url = urlparse(value)
|
||||
host = _adjust_idn(url.hostname, idn_rewrite)
|
||||
host = _adjust_idn(url.hostname, idn_rewrite) if url.hostname else None
|
||||
if url.username is not None and url.password is not None:
|
||||
host = f"{url.username}:{url.password}@{host}"
|
||||
elif url.username is not None:
|
||||
@@ -418,7 +484,7 @@ def _adjust_idn_url(value, idn_rewrite):
|
||||
return urlunparse(
|
||||
ParseResult(
|
||||
scheme=url.scheme,
|
||||
netloc=host,
|
||||
netloc=host or "",
|
||||
path=url.path,
|
||||
params=url.params,
|
||||
query=url.query,
|
||||
@@ -427,7 +493,9 @@ def _adjust_idn_url(value, idn_rewrite):
|
||||
)
|
||||
|
||||
|
||||
def cryptography_get_name(name, what="Subject Alternative Name"):
|
||||
def cryptography_get_name(
|
||||
name: str, what: str = "Subject Alternative Name"
|
||||
) -> x509.GeneralName:
|
||||
"""
|
||||
Given a name string, returns a cryptography x509.GeneralName object.
|
||||
Raises an OpenSSLObjectError if the name is unknown or cannot be parsed.
|
||||
@@ -490,7 +558,7 @@ def cryptography_get_name(name, what="Subject Alternative Name"):
|
||||
)
|
||||
|
||||
|
||||
def _dn_escape_value(value):
|
||||
def _dn_escape_value(value: str) -> str:
|
||||
"""
|
||||
Escape Distinguished Name's attribute value.
|
||||
"""
|
||||
@@ -505,7 +573,10 @@ def _dn_escape_value(value):
|
||||
return value
|
||||
|
||||
|
||||
def cryptography_decode_name(name, idn_rewrite="ignore"):
|
||||
def cryptography_decode_name(
|
||||
name: x509.GeneralName,
|
||||
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
|
||||
) -> str:
|
||||
"""
|
||||
Given a cryptography x509.GeneralName object, returns a string.
|
||||
Raises an OpenSSLObjectError if the name is not supported.
|
||||
@@ -529,7 +600,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
|
||||
# list needs to be reversed, and joined by commas
|
||||
return "dirName:" + ",".join(
|
||||
[
|
||||
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(attribute.value)}"
|
||||
f"{to_text(cryptography_oid_to_name(attribute.oid, short=True))}={_dn_escape_value(to_text(attribute.value))}"
|
||||
for attribute in reversed(list(name.value))
|
||||
]
|
||||
)
|
||||
@@ -540,7 +611,7 @@ def cryptography_decode_name(name, idn_rewrite="ignore"):
|
||||
raise OpenSSLObjectError(f'Cannot decode name "{name}"')
|
||||
|
||||
|
||||
def _cryptography_get_keyusage(usage):
|
||||
def _cryptography_get_keyusage(usage: str) -> str:
|
||||
"""
|
||||
Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage().
|
||||
Raises an OpenSSLObjectError if the identifier is unknown.
|
||||
@@ -566,7 +637,7 @@ def _cryptography_get_keyusage(usage):
|
||||
raise OpenSSLObjectError(f'Unknown key usage "{usage}"')
|
||||
|
||||
|
||||
def cryptography_parse_key_usage_params(usages):
|
||||
def cryptography_parse_key_usage_params(usages: t.Iterable[str]) -> dict[str, bool]:
|
||||
"""
|
||||
Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage().
|
||||
Raises an OpenSSLObjectError if an identifier is unknown.
|
||||
@@ -587,13 +658,15 @@ def cryptography_parse_key_usage_params(usages):
|
||||
return params
|
||||
|
||||
|
||||
def cryptography_get_basic_constraints(constraints):
|
||||
def cryptography_get_basic_constraints(
|
||||
constraints: t.Iterable[str] | None,
|
||||
) -> tuple[bool, int | None]:
|
||||
"""
|
||||
Given a list of constraints, returns a tuple (ca, path_length).
|
||||
Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed.
|
||||
"""
|
||||
ca = False
|
||||
path_length = None
|
||||
path_length: int | None = None
|
||||
if constraints:
|
||||
for constraint in constraints:
|
||||
if constraint.startswith("CA:"):
|
||||
@@ -618,7 +691,9 @@ def cryptography_get_basic_constraints(constraints):
|
||||
return ca, path_length
|
||||
|
||||
|
||||
def cryptography_key_needs_digest_for_signing(key):
|
||||
def cryptography_key_needs_digest_for_signing(
|
||||
key: CertificateIssuerPrivateKeyTypes,
|
||||
) -> bool:
|
||||
"""Tests whether the given private key requires a digest algorithm for signing.
|
||||
|
||||
Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm.
|
||||
@@ -632,19 +707,27 @@ def cryptography_key_needs_digest_for_signing(key):
|
||||
return True
|
||||
|
||||
|
||||
def _compare_public_keys(key1, key2, clazz):
|
||||
def _compare_public_keys(
|
||||
key1: PublicKeyTypes, key2: PublicKeyTypes, clazz: type[PublicKeyTypes]
|
||||
) -> bool | None:
|
||||
a = isinstance(key1, clazz)
|
||||
b = isinstance(key2, clazz)
|
||||
if not (a or b):
|
||||
return None
|
||||
if not a or not b:
|
||||
return False
|
||||
a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
return a == b
|
||||
a_bytes = key1.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
b_bytes = key2.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
return a_bytes == b_bytes
|
||||
|
||||
|
||||
def cryptography_compare_public_keys(key1, key2):
|
||||
def cryptography_compare_public_keys(
|
||||
key1: PublicKeyTypes, key2: PublicKeyTypes
|
||||
) -> bool:
|
||||
"""Tests whether two public keys are the same.
|
||||
|
||||
Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers().
|
||||
@@ -654,6 +737,13 @@ def cryptography_compare_public_keys(key1, key2):
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
key1,
|
||||
key2,
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
res = _compare_public_keys(
|
||||
@@ -661,10 +751,20 @@ def cryptography_compare_public_keys(key1, key2):
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
return key1.public_numbers() == key2.public_numbers()
|
||||
res = _compare_public_keys(
|
||||
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
return (
|
||||
t.cast(PublicKeyTypesWOEdwards, key1).public_numbers()
|
||||
== t.cast(PublicKeyTypesWOEdwards, key2).public_numbers()
|
||||
)
|
||||
|
||||
|
||||
def _compare_private_keys(key1, key2, clazz):
|
||||
def _compare_private_keys(
|
||||
key1: PrivateKeyTypes, key2: PrivateKeyTypes, clazz: type[PrivateKeyTypes]
|
||||
) -> bool | None:
|
||||
a = isinstance(key1, clazz)
|
||||
b = isinstance(key2, clazz)
|
||||
if not (a or b):
|
||||
@@ -672,20 +772,22 @@ def _compare_private_keys(key1, key2, clazz):
|
||||
if not a or not b:
|
||||
return False
|
||||
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
a = key1.private_bytes(
|
||||
a_bytes = key1.private_bytes(
|
||||
serialization.Encoding.Raw,
|
||||
serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
b = key2.private_bytes(
|
||||
b_bytes = key2.private_bytes(
|
||||
serialization.Encoding.Raw,
|
||||
serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
return a == b
|
||||
return a_bytes == b_bytes
|
||||
|
||||
|
||||
def cryptography_compare_private_keys(key1, key2):
|
||||
def cryptography_compare_private_keys(
|
||||
key1: PrivateKeyTypes, key2: PrivateKeyTypes
|
||||
) -> bool:
|
||||
"""Tests whether two private keys are the same.
|
||||
|
||||
Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers().
|
||||
@@ -714,25 +816,39 @@ def cryptography_compare_private_keys(key1, key2):
|
||||
)
|
||||
if res is not None:
|
||||
return res
|
||||
return key1.private_numbers() == key2.private_numbers()
|
||||
return (
|
||||
t.cast(PrivateKeyTypesWOEdwards, key1).private_numbers()
|
||||
== t.cast(PrivateKeyTypesWOEdwards, key2).private_numbers()
|
||||
)
|
||||
|
||||
|
||||
def parse_pkcs12(pkcs12_bytes, passphrase=None):
|
||||
def parse_pkcs12(pkcs12_bytes: bytes, passphrase: bytes | str | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
"""Returns a tuple (private_key, certificate, additional_certificates, friendly_name)."""
|
||||
passphrase_bytes = None
|
||||
if passphrase is not None:
|
||||
passphrase = to_bytes(passphrase)
|
||||
passphrase_bytes = to_bytes(passphrase)
|
||||
|
||||
# Main code for cryptography 36.0.0 and forward
|
||||
if _load_pkcs12 is not None:
|
||||
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase)
|
||||
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase_bytes)
|
||||
|
||||
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
|
||||
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase)
|
||||
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase_bytes)
|
||||
|
||||
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase)
|
||||
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase_bytes)
|
||||
|
||||
|
||||
def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
|
||||
def _parse_pkcs12_36_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
# Requires cryptography 36.0.0 or newer
|
||||
pkcs12 = _load_pkcs12(pkcs12_bytes, passphrase)
|
||||
additional_certificates = [cert.certificate for cert in pkcs12.additional_certs]
|
||||
@@ -745,7 +861,12 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
|
||||
def _parse_pkcs12_35_0_0(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
# Backwards compatibility code for cryptography 35.x
|
||||
private_key, certificate, additional_certificates = _load_key_and_certificates(
|
||||
pkcs12_bytes, passphrase
|
||||
@@ -787,7 +908,12 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
|
||||
def _parse_pkcs12_legacy(pkcs12_bytes: bytes, passphrase: bytes | None = None) -> tuple[
|
||||
PrivateKeyTypes | None,
|
||||
x509.Certificate | None,
|
||||
list[x509.Certificate],
|
||||
bytes | None,
|
||||
]:
|
||||
# Backwards compatibility code for cryptography < 35.0.0
|
||||
private_key, certificate, additional_certificates = _load_key_and_certificates(
|
||||
pkcs12_bytes, passphrase
|
||||
@@ -796,14 +922,19 @@ def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
|
||||
friendly_name = None
|
||||
if certificate:
|
||||
# See https://github.com/pyca/cryptography/issues/5760#issuecomment-842687238
|
||||
backend = certificate._backend
|
||||
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL)
|
||||
backend = certificate._backend # type: ignore
|
||||
maybe_name = backend._lib.X509_alias_get0(certificate._x509, backend._ffi.NULL) # type: ignore
|
||||
if maybe_name != backend._ffi.NULL:
|
||||
friendly_name = backend._ffi.string(maybe_name)
|
||||
return private_key, certificate, additional_certificates, friendly_name
|
||||
|
||||
|
||||
def cryptography_verify_signature(signature, data, hash_algorithm, signer_public_key):
|
||||
def cryptography_verify_signature(
|
||||
signature: bytes,
|
||||
data: bytes,
|
||||
hash_algorithm: hashes.HashAlgorithm | None,
|
||||
signer_public_key: PublicKeyTypes,
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether the given signature of the given data was signed by the given public key object.
|
||||
"""
|
||||
@@ -812,6 +943,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
signer_public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey,
|
||||
):
|
||||
if hash_algorithm is None:
|
||||
raise OpenSSLObjectError("Need hash_algorithm for RSA keys")
|
||||
signer_public_key.verify(
|
||||
signature, data, padding.PKCS1v15(), hash_algorithm
|
||||
)
|
||||
@@ -820,6 +953,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
signer_public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
|
||||
):
|
||||
if hash_algorithm is None:
|
||||
raise OpenSSLObjectError("Need hash_algorithm for ECC keys")
|
||||
signer_public_key.verify(
|
||||
signature,
|
||||
data,
|
||||
@@ -830,6 +965,8 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
signer_public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey,
|
||||
):
|
||||
if hash_algorithm is None:
|
||||
raise OpenSSLObjectError("Need hash_algorithm for DSA keys")
|
||||
signer_public_key.verify(signature, data, hash_algorithm)
|
||||
return True
|
||||
if isinstance(
|
||||
@@ -851,7 +988,9 @@ def cryptography_verify_signature(signature, data, hash_algorithm, signer_public
|
||||
return False
|
||||
|
||||
|
||||
def cryptography_verify_certificate_signature(certificate, signer_public_key):
|
||||
def cryptography_verify_certificate_signature(
|
||||
certificate: x509.Certificate, signer_public_key: PublicKeyTypes
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether the given X509 certificate object was signed by the given public key object.
|
||||
"""
|
||||
@@ -863,21 +1002,65 @@ def cryptography_verify_certificate_signature(certificate, signer_public_key):
|
||||
)
|
||||
|
||||
|
||||
def get_not_valid_after(obj):
|
||||
def get_not_valid_after(obj: x509.Certificate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.not_valid_after_utc
|
||||
return obj.not_valid_after
|
||||
|
||||
|
||||
def get_not_valid_before(obj):
|
||||
def get_not_valid_before(obj: x509.Certificate) -> datetime.datetime:
|
||||
if CRYPTOGRAPHY_TIMEZONE:
|
||||
return obj.not_valid_before_utc
|
||||
return obj.not_valid_before
|
||||
|
||||
|
||||
def set_not_valid_after(builder, value):
|
||||
def set_not_valid_after(
|
||||
builder: x509.CertificateBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateBuilder:
|
||||
return builder.not_valid_after(value)
|
||||
|
||||
|
||||
def set_not_valid_before(builder, value):
|
||||
def set_not_valid_before(
|
||||
builder: x509.CertificateBuilder, value: datetime.datetime
|
||||
) -> x509.CertificateBuilder:
|
||||
return builder.not_valid_before(value)
|
||||
|
||||
|
||||
def is_potential_certificate_private_key(
|
||||
key: PrivateKeyTypes,
|
||||
) -> t.TypeGuard[CertificatePrivateKeyTypes]:
|
||||
return not isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey
|
||||
)
|
||||
|
||||
|
||||
def is_potential_certificate_issuer_private_key(
|
||||
key: PrivateKeyTypes,
|
||||
) -> t.TypeGuard[CertificateIssuerPrivateKeyTypes]:
|
||||
return not isinstance(
|
||||
key,
|
||||
(
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def is_potential_certificate_public_key(
|
||||
key: PublicKeyTypes,
|
||||
) -> t.TypeGuard[CertificatePublicKeyTypes]:
|
||||
return not isinstance(key, DHPublicKey)
|
||||
|
||||
|
||||
def is_potential_certificate_issuer_public_key(
|
||||
key: PublicKeyTypes,
|
||||
) -> t.TypeGuard[CertificateIssuerPublicKeyTypes]:
|
||||
return not isinstance(
|
||||
key,
|
||||
(
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey,
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPublicKey,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def binary_exp_mod(f, e, m):
|
||||
def binary_exp_mod(f: int, e: int, m: int) -> int:
|
||||
"""Computes f^e mod m in O(log e) multiplications modulo m."""
|
||||
# Compute len_e = floor(log_2(e))
|
||||
len_e = -1
|
||||
@@ -22,14 +22,14 @@ def binary_exp_mod(f, e, m):
|
||||
return result
|
||||
|
||||
|
||||
def simple_gcd(a, b):
|
||||
def simple_gcd(a: int, b: int) -> int:
|
||||
"""Compute GCD of its two inputs."""
|
||||
while b != 0:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
|
||||
|
||||
def quick_is_not_prime(n):
|
||||
def quick_is_not_prime(n: int) -> bool:
|
||||
"""Does some quick checks to see if we can poke a hole into the primality of n.
|
||||
|
||||
A result of `False` does **not** mean that the number is prime; it just means
|
||||
@@ -97,7 +97,7 @@ def quick_is_not_prime(n):
|
||||
return False
|
||||
|
||||
|
||||
def count_bytes(no):
|
||||
def count_bytes(no: int) -> int:
|
||||
"""
|
||||
Given an integer, compute the number of bytes necessary to store its absolute value.
|
||||
"""
|
||||
@@ -107,7 +107,7 @@ def count_bytes(no):
|
||||
return (no.bit_length() + 7) // 8
|
||||
|
||||
|
||||
def count_bits(no):
|
||||
def count_bits(no: int) -> int:
|
||||
"""
|
||||
Given an integer, compute the number of bits necessary to store its absolute value.
|
||||
"""
|
||||
@@ -117,19 +117,7 @@ def count_bits(no):
|
||||
return no.bit_length()
|
||||
|
||||
|
||||
def _convert_int_to_bytes(count, no):
|
||||
return no.to_bytes(count, byteorder="big")
|
||||
|
||||
|
||||
def _convert_bytes_to_int(data):
|
||||
return int.from_bytes(data, byteorder="big", signed=False)
|
||||
|
||||
|
||||
def _to_hex(no):
|
||||
return f"{no:x}"
|
||||
|
||||
|
||||
def convert_int_to_bytes(no, count=None):
|
||||
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
|
||||
"""
|
||||
Convert the absolute value of an integer to a byte string in network byte order.
|
||||
|
||||
@@ -142,10 +130,10 @@ def convert_int_to_bytes(no, count=None):
|
||||
no = abs(no)
|
||||
if count is None:
|
||||
count = count_bytes(no)
|
||||
return _convert_int_to_bytes(count, no)
|
||||
return no.to_bytes(count, byteorder="big")
|
||||
|
||||
|
||||
def convert_int_to_hex(no, digits=None):
|
||||
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
|
||||
"""
|
||||
Convert the absolute value of an integer to a string of hexadecimal digits.
|
||||
|
||||
@@ -154,14 +142,14 @@ def convert_int_to_hex(no, digits=None):
|
||||
the string will be longer.
|
||||
"""
|
||||
no = abs(no)
|
||||
value = _to_hex(no)
|
||||
value = f"{no:x}"
|
||||
if digits is not None and len(value) < digits:
|
||||
value = "0" * (digits - len(value)) + value
|
||||
return value
|
||||
|
||||
|
||||
def convert_bytes_to_int(data):
|
||||
def convert_bytes_to_int(data: bytes) -> int:
|
||||
"""
|
||||
Convert a byte string to an unsigned integer in network byte order.
|
||||
"""
|
||||
return _convert_bytes_to_int(data)
|
||||
return int.from_bytes(data, byteorder="big", signed=False)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
ArgumentSpec,
|
||||
@@ -24,8 +25,8 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
load_certificate,
|
||||
load_certificate_privatekey,
|
||||
load_certificate_request,
|
||||
load_privatekey,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ..cryptography_support import CertificatePrivateKeyTypes
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -47,41 +59,45 @@ class CertificateError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.force = module.params["force"]
|
||||
self.ignore_timestamps = module.params["ignore_timestamps"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.csr_path = module.params["csr_path"]
|
||||
self.csr_content = module.params["csr_content"]
|
||||
if self.csr_content is not None:
|
||||
self.csr_content = self.csr_content.encode("utf-8")
|
||||
self.force: bool = module.params["force"]
|
||||
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.csr_path: str | None = module.params["csr_path"]
|
||||
csr_content = module.params["csr_content"]
|
||||
if csr_content is not None:
|
||||
self.csr_content: bytes | None = csr_content.encode("utf-8")
|
||||
else:
|
||||
self.csr_content = None
|
||||
|
||||
# The following are default values which make sure check() works as
|
||||
# before if providers do not explicitly change these properties.
|
||||
self.create_subject_key_identifier = "never_create"
|
||||
self.create_authority_key_identifier = False
|
||||
self.create_subject_key_identifier: str = "never_create"
|
||||
self.create_authority_key_identifier: bool = False
|
||||
|
||||
self.privatekey = None
|
||||
self.csr = None
|
||||
self.cert = None
|
||||
self.existing_certificate = None
|
||||
self.existing_certificate_bytes = None
|
||||
self.privatekey: CertificatePrivateKeyTypes | None = None
|
||||
self.csr: x509.CertificateSigningRequest | None = None
|
||||
self.cert: x509.Certificate | None = None
|
||||
self.existing_certificate: x509.Certificate | None = None
|
||||
self.existing_certificate_bytes: bytes | None = None
|
||||
|
||||
self.check_csr_subject = True
|
||||
self.check_csr_extensions = True
|
||||
self.check_csr_subject: bool = True
|
||||
self.check_csr_extensions: bool = True
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
return {}
|
||||
try:
|
||||
result = get_certificate_info(
|
||||
self.module, data, prefer_one_fingerprint=True
|
||||
@@ -92,34 +108,34 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return dict(can_parse_certificate=False)
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
pass
|
||||
|
||||
def set_existing(self, certificate_bytes):
|
||||
def set_existing(self, certificate_bytes: bytes | None) -> None:
|
||||
"""Set existing certificate bytes. None indicates that the key does not exist."""
|
||||
self.existing_certificate_bytes = certificate_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(
|
||||
self.existing_certificate_bytes
|
||||
)
|
||||
|
||||
def has_existing(self):
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing certificate is/has been there."""
|
||||
return self.existing_certificate_bytes is not None
|
||||
|
||||
def _ensure_private_key_loaded(self):
|
||||
def _ensure_private_key_loaded(self) -> None:
|
||||
"""Load the provided private key into self.privatekey."""
|
||||
if self.privatekey is not None:
|
||||
return
|
||||
if self.privatekey_path is None and self.privatekey_content is None:
|
||||
return
|
||||
try:
|
||||
self.privatekey = load_privatekey(
|
||||
self.privatekey = load_certificate_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
@@ -127,7 +143,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
except OpenSSLBadPassphraseError as exc:
|
||||
raise CertificateError(exc)
|
||||
|
||||
def _ensure_csr_loaded(self):
|
||||
def _ensure_csr_loaded(self) -> None:
|
||||
"""Load the CSR into self.csr."""
|
||||
if self.csr is not None:
|
||||
return
|
||||
@@ -138,7 +154,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
content=self.csr_content,
|
||||
)
|
||||
|
||||
def _ensure_existing_certificate_loaded(self):
|
||||
def _ensure_existing_certificate_loaded(self) -> None:
|
||||
"""Load the existing certificate into self.existing_certificate."""
|
||||
if self.existing_certificate is not None:
|
||||
return
|
||||
@@ -149,14 +165,28 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
content=self.existing_certificate_bytes,
|
||||
)
|
||||
|
||||
def _check_privatekey(self):
|
||||
def _check_privatekey(self) -> bool:
|
||||
"""Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated."""
|
||||
if self.existing_certificate is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: existing_certificate has not been populated"
|
||||
)
|
||||
if self.privatekey is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: privatekey has not been populated"
|
||||
)
|
||||
return cryptography_compare_public_keys(
|
||||
self.existing_certificate.public_key(), self.privatekey.public_key()
|
||||
)
|
||||
|
||||
def _check_csr(self):
|
||||
def _check_csr(self) -> bool:
|
||||
"""Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated."""
|
||||
if self.existing_certificate is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: existing_certificate has not been populated"
|
||||
)
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
# Verify that CSR is signed by certificate's private key
|
||||
if not self.csr.is_signature_valid:
|
||||
return False
|
||||
@@ -214,8 +244,14 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_subject_key_identifier(self):
|
||||
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate has been populated."""
|
||||
def _check_subject_key_identifier(self) -> bool:
|
||||
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate and self.csr have been populated."""
|
||||
if self.existing_certificate is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: existing_certificate has not been populated"
|
||||
)
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
# Get hold of certificate's SKI
|
||||
try:
|
||||
ext = self.existing_certificate.extensions.get_extension_for_class(
|
||||
@@ -247,7 +283,11 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return False
|
||||
return True
|
||||
|
||||
def needs_regeneration(self, not_before=None, not_after=None):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
"""Check whether a regeneration is necessary."""
|
||||
if self.force or self.existing_certificate_bytes is None:
|
||||
return True
|
||||
@@ -256,6 +296,7 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
self._ensure_existing_certificate_loaded()
|
||||
except Exception:
|
||||
return True
|
||||
assert self.existing_certificate is not None
|
||||
|
||||
# Check whether private key matches
|
||||
self._ensure_private_key_loaded()
|
||||
@@ -285,9 +326,12 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
return True
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = {"privatekey": self.privatekey_path, "csr": self.csr_path}
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
"csr": self.csr_path,
|
||||
}
|
||||
# Get hold of certificate bytes
|
||||
certificate_bytes = self.existing_certificate_bytes
|
||||
if self.cert is not None:
|
||||
@@ -299,35 +343,33 @@ class CertificateBackend(metaclass=abc.ABCMeta):
|
||||
certificate_bytes.decode("utf-8") if certificate_bytes else None
|
||||
)
|
||||
|
||||
result["diff"] = dict(
|
||||
before=self.diff_before,
|
||||
after=self.diff_after,
|
||||
)
|
||||
result["diff"] = {
|
||||
"before": self.diff_before,
|
||||
"after": self.diff_after,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class CertificateProvider(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
"""Check module arguments"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
"""Whether the provider needs to create a version 2 certificate."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_backend(self, module):
|
||||
def create_backend(self, module: AnsibleModule) -> CertificateBackend:
|
||||
"""Create an implementation for a backend.
|
||||
|
||||
Return value must be instance of CertificateBackend.
|
||||
"""
|
||||
|
||||
|
||||
def select_backend(module, provider):
|
||||
"""
|
||||
:type module: AnsibleModule
|
||||
:type provider: CertificateProvider
|
||||
"""
|
||||
def select_backend(
|
||||
module: AnsibleModule, provider: CertificateProvider
|
||||
) -> CertificateBackend:
|
||||
provider.validate_module_args(module)
|
||||
|
||||
assert_required_cryptography_version(
|
||||
@@ -343,7 +385,7 @@ def select_backend(module, provider):
|
||||
return provider.create_backend(module)
|
||||
|
||||
|
||||
def get_certificate_argument_spec():
|
||||
def get_certificate_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
provider=dict(
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
|
||||
@@ -17,22 +18,30 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
class AcmeCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
super(AcmeCertificateBackend, self).__init__(module)
|
||||
self.accountkey_path = module.params["acme_accountkey_path"]
|
||||
self.challenge_path = module.params["acme_challenge_path"]
|
||||
self.use_chain = module.params["acme_chain"]
|
||||
self.acme_directory = module.params["acme_directory"]
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
if self.csr_content is None and self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if self.csr_content is None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
class AcmeCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(AcmeCertificateBackend, self).__init__(module)
|
||||
self.accountkey_path: str = module.params["acme_accountkey_path"]
|
||||
self.challenge_path: str = module.params["acme_challenge_path"]
|
||||
self.use_chain: bool = module.params["acme_chain"]
|
||||
self.acme_directory: str = module.params["acme_directory"]
|
||||
self.cert_bytes: bytes | None = None
|
||||
|
||||
if self.csr_content is None:
|
||||
if self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
|
||||
if not os.path.exists(self.accountkey_path):
|
||||
raise CertificateError(
|
||||
@@ -46,7 +55,7 @@ class AcmeCertificateBackend(CertificateBackend):
|
||||
|
||||
self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True)
|
||||
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
|
||||
command = [self.acme_tiny_path]
|
||||
@@ -77,22 +86,26 @@ class AcmeCertificateBackend(CertificateBackend):
|
||||
command.extend(["--directory-url", self.acme_directory])
|
||||
|
||||
try:
|
||||
self.cert = to_bytes(self.module.run_command(command, check_rc=True)[1])
|
||||
self.cert_bytes = to_bytes(
|
||||
self.module.run_command(command, check_rc=True)[1]
|
||||
)
|
||||
except OSError as exc:
|
||||
raise CertificateError(exc)
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
return self.cert
|
||||
if self.cert_bytes is None:
|
||||
raise AssertionError("Contract violation: cert_bytes is None")
|
||||
return self.cert_bytes
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(AcmeCertificateBackend, self).dump(include_certificate)
|
||||
result["accountkey"] = self.accountkey_path
|
||||
return result
|
||||
|
||||
|
||||
class AcmeCertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
if module.params["acme_accountkey_path"] is None:
|
||||
module.fail_json(
|
||||
msg="The acme_accountkey_path option must be specified for the acme provider."
|
||||
@@ -102,14 +115,14 @@ class AcmeCertificateProvider(CertificateProvider):
|
||||
msg="The acme_challenge_path option must be specified for the acme provider."
|
||||
)
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
return False
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
|
||||
return AcmeCertificateBackend(module)
|
||||
|
||||
|
||||
def add_acme_provider_to_argument_spec(argument_spec):
|
||||
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("acme")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -32,6 +33,12 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
try:
|
||||
from cryptography.x509.oid import NameOID
|
||||
except ImportError:
|
||||
@@ -39,7 +46,7 @@ except ImportError:
|
||||
|
||||
|
||||
class EntrustCertificateBackend(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(EntrustCertificateBackend, self).__init__(module)
|
||||
self.trackingId = None
|
||||
self.notAfter = get_relative_time_option(
|
||||
@@ -48,16 +55,19 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
|
||||
if self.csr_content is None and self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for entrust provider"
|
||||
)
|
||||
if self.csr_content is None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.csr_content is None:
|
||||
if self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for entrust provider"
|
||||
)
|
||||
if not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
|
||||
self._ensure_csr_loaded()
|
||||
if self.csr is None:
|
||||
raise CertificateError("CSR not provided")
|
||||
|
||||
# ECS API defaults to using the validated organization tied to the account.
|
||||
# We want to always force behavior of trying to use the organization provided in the CSR.
|
||||
@@ -93,9 +103,9 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
],
|
||||
)
|
||||
except SessionConfigurationException as e:
|
||||
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e.message}")
|
||||
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
|
||||
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
body = {}
|
||||
|
||||
@@ -104,6 +114,7 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
# csr_content contains bytes
|
||||
body["csr"] = to_native(self.csr_content)
|
||||
else:
|
||||
assert self.csr_path is not None
|
||||
with open(self.csr_path, "r") as csr_file:
|
||||
body["csr"] = csr_file.read()
|
||||
|
||||
@@ -138,11 +149,15 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
content=self.cert_bytes,
|
||||
)
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
return self.cert_bytes
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
parent_check = super(EntrustCertificateBackend, self).needs_regeneration()
|
||||
|
||||
try:
|
||||
@@ -167,12 +182,12 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
|
||||
return parent_check
|
||||
|
||||
def _get_cert_details(self):
|
||||
cert_details = {}
|
||||
def _get_cert_details(self) -> dict[str, t.Any]:
|
||||
cert_details: dict[str, t.Any] = {}
|
||||
try:
|
||||
self._ensure_existing_certificate_loaded()
|
||||
except Exception:
|
||||
return
|
||||
return cert_details
|
||||
if self.existing_certificate:
|
||||
serial_number = f"{self.existing_certificate.serial_number:X}"
|
||||
expiry = get_not_valid_after(self.existing_certificate)
|
||||
@@ -203,17 +218,17 @@ class EntrustCertificateBackend(CertificateBackend):
|
||||
|
||||
|
||||
class EntrustCertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
pass
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> t.Literal[False]:
|
||||
return False
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
|
||||
return EntrustCertificateBackend(module)
|
||||
|
||||
|
||||
def add_entrust_provider_to_argument_spec(argument_spec):
|
||||
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("entrust")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
@@ -248,7 +263,7 @@ def add_entrust_provider_to_argument_spec(argument_spec):
|
||||
)
|
||||
)
|
||||
argument_spec.required_if.append(
|
||||
[
|
||||
(
|
||||
"provider",
|
||||
"entrust",
|
||||
[
|
||||
@@ -260,5 +275,5 @@ def add_entrust_provider_to_argument_spec(argument_spec):
|
||||
"entrust_api_client_cert_path",
|
||||
"entrust_api_client_cert_key_path",
|
||||
],
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -34,6 +35,19 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -48,93 +62,97 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
|
||||
|
||||
|
||||
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, content):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_der_bytes(self):
|
||||
def _get_der_bytes(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_signature_algorithm(self):
|
||||
def _get_signature_algorithm(self) -> str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_ordered(self):
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_issuer_ordered(self):
|
||||
def _get_issuer_ordered(self) -> list[list[str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_version(self):
|
||||
def _get_version(self) -> int | str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_not_before(self):
|
||||
def get_not_before(self) -> datetime.datetime:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_not_after(self):
|
||||
def get_not_after(self) -> datetime.datetime:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> PublicKeyTypes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_serial_number(self):
|
||||
def _get_serial_number(self) -> int:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_ocsp_uri(self):
|
||||
def _get_ocsp_uri(self) -> str | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_issuer_uri(self):
|
||||
def _get_issuer_uri(self) -> str | None:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False, der_support_enabled=False):
|
||||
result = dict()
|
||||
def get_info(
|
||||
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
|
||||
) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
self.cert = load_certificate(
|
||||
None,
|
||||
content=self.content,
|
||||
@@ -194,16 +212,20 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
|
||||
)
|
||||
|
||||
ski = self._get_subject_key_identifier()
|
||||
if ski is not None:
|
||||
ski = binascii.hexlify(ski).decode("ascii")
|
||||
ski_bytes = self._get_subject_key_identifier()
|
||||
if ski_bytes is not None:
|
||||
ski = binascii.hexlify(ski_bytes).decode("ascii")
|
||||
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
|
||||
else:
|
||||
ski = None
|
||||
result["subject_key_identifier"] = ski
|
||||
|
||||
aki, aci, acsn = self._get_authority_key_identifier()
|
||||
if aki is not None:
|
||||
aki = binascii.hexlify(aki).decode("ascii")
|
||||
aki_bytes, aci, acsn = self._get_authority_key_identifier()
|
||||
if aki_bytes is not None:
|
||||
aki = binascii.hexlify(aki_bytes).decode("ascii")
|
||||
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
|
||||
else:
|
||||
aki = None
|
||||
result["authority_key_identifier"] = aki
|
||||
result["authority_cert_issuer"] = aci
|
||||
result["authority_cert_serial_number"] = acsn
|
||||
@@ -219,36 +241,40 @@ class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
"""Validate the supplied cert, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
|
||||
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
|
||||
def _get_der_bytes(self):
|
||||
def _get_der_bytes(self) -> bytes:
|
||||
return self.cert.public_bytes(serialization.Encoding.DER)
|
||||
|
||||
def _get_signature_algorithm(self):
|
||||
def _get_signature_algorithm(self) -> str:
|
||||
return cryptography_oid_to_name(self.cert.signature_algorithm_oid)
|
||||
|
||||
def _get_subject_ordered(self):
|
||||
result = []
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
result: list[list[str]] = []
|
||||
for attribute in self.cert.subject:
|
||||
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
|
||||
result.append(
|
||||
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_issuer_ordered(self):
|
||||
def _get_issuer_ordered(self) -> list[list[str]]:
|
||||
result = []
|
||||
for attribute in self.cert.issuer:
|
||||
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
|
||||
result.append(
|
||||
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_version(self):
|
||||
def _get_version(self) -> int | str:
|
||||
if self.cert.version == x509.Version.v1:
|
||||
return 1
|
||||
if self.cert.version == x509.Version.v3:
|
||||
return 3
|
||||
return "unknown"
|
||||
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
current_key_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.KeyUsage
|
||||
@@ -297,7 +323,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.ExtendedKeyUsage
|
||||
@@ -311,7 +337,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.BasicConstraints
|
||||
@@ -324,7 +350,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
try:
|
||||
tlsfeature_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.TLSFeature
|
||||
@@ -336,7 +362,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
san_ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.SubjectAlternativeName
|
||||
@@ -349,22 +375,22 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def get_not_before(self):
|
||||
def get_not_before(self) -> datetime.datetime:
|
||||
return get_not_valid_before(self.cert)
|
||||
|
||||
def get_not_after(self):
|
||||
def get_not_after(self) -> datetime.datetime:
|
||||
return get_not_valid_after(self.cert)
|
||||
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
return self.cert.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> PublicKeyTypes:
|
||||
return self.cert.public_key()
|
||||
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.SubjectKeyIdentifier
|
||||
@@ -373,7 +399,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None
|
||||
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.AuthorityKeyIdentifier
|
||||
@@ -392,13 +420,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, None, None
|
||||
|
||||
def _get_serial_number(self):
|
||||
def _get_serial_number(self) -> int:
|
||||
return self.cert.serial_number
|
||||
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
return cryptography_get_extensions_from_cert(self.cert)
|
||||
|
||||
def _get_ocsp_uri(self):
|
||||
def _get_ocsp_uri(self) -> str | None:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.AuthorityInformationAccess
|
||||
@@ -411,7 +439,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
pass
|
||||
return None
|
||||
|
||||
def _get_issuer_uri(self):
|
||||
def _get_issuer_uri(self) -> str | None:
|
||||
try:
|
||||
ext = self.cert.extensions.get_extension_for_class(
|
||||
x509.AuthorityInformationAccess
|
||||
@@ -428,12 +456,16 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
|
||||
return None
|
||||
|
||||
|
||||
def get_certificate_info(module, content, prefer_one_fingerprint=False):
|
||||
def get_certificate_info(
|
||||
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
|
||||
) -> dict[str, t.Any]:
|
||||
info = CertificateInfoRetrievalCryptography(module, content)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(module, content):
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule, content: bytes
|
||||
) -> CertificateInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
from random import randrange
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -18,6 +19,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
cryptography_verify_certificate_signature,
|
||||
get_not_valid_after,
|
||||
get_not_valid_before,
|
||||
is_potential_certificate_issuer_public_key,
|
||||
set_not_valid_after,
|
||||
set_not_valid_before,
|
||||
)
|
||||
@@ -28,7 +30,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
load_certificate,
|
||||
load_privatekey,
|
||||
load_certificate_issuer_privatekey,
|
||||
select_message_digest,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
@@ -36,6 +38,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography import x509
|
||||
@@ -45,13 +58,13 @@ except ImportError:
|
||||
|
||||
|
||||
class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(OwnCACertificateBackendCryptography, self).__init__(module)
|
||||
|
||||
self.create_subject_key_identifier = module.params[
|
||||
"ownca_create_subject_key_identifier"
|
||||
]
|
||||
self.create_authority_key_identifier = module.params[
|
||||
self.create_subject_key_identifier: t.Literal[
|
||||
"create_if_not_provided", "always_create", "never_create"
|
||||
] = module.params["ownca_create_subject_key_identifier"]
|
||||
self.create_authority_key_identifier: bool = module.params[
|
||||
"ownca_create_authority_key_identifier"
|
||||
]
|
||||
self.notBefore = get_relative_time_option(
|
||||
@@ -65,31 +78,40 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.digest = select_message_digest(module.params["ownca_digest"])
|
||||
self.version = module.params["ownca_version"]
|
||||
self.version: int = module.params["ownca_version"]
|
||||
self.serial_number = x509.random_serial_number()
|
||||
self.ca_cert_path = module.params["ownca_path"]
|
||||
self.ca_cert_content = module.params["ownca_content"]
|
||||
if self.ca_cert_content is not None:
|
||||
self.ca_cert_content = self.ca_cert_content.encode("utf-8")
|
||||
self.ca_privatekey_path = module.params["ownca_privatekey_path"]
|
||||
self.ca_privatekey_content = module.params["ownca_privatekey_content"]
|
||||
if self.ca_privatekey_content is not None:
|
||||
self.ca_privatekey_content = self.ca_privatekey_content.encode("utf-8")
|
||||
self.ca_privatekey_passphrase = module.params["ownca_privatekey_passphrase"]
|
||||
self.ca_cert_path: str | None = module.params["ownca_path"]
|
||||
ca_cert_content: str | None = module.params["ownca_content"]
|
||||
if ca_cert_content is not None:
|
||||
self.ca_cert_content: bytes | None = ca_cert_content.encode("utf-8")
|
||||
else:
|
||||
self.ca_cert_content = None
|
||||
self.ca_privatekey_path: str | None = module.params["ownca_privatekey_path"]
|
||||
ca_privatekey_content: str | None = module.params["ownca_privatekey_content"]
|
||||
if ca_privatekey_content is not None:
|
||||
self.ca_privatekey_content: bytes | None = ca_privatekey_content.encode(
|
||||
"utf-8"
|
||||
)
|
||||
else:
|
||||
self.ca_privatekey_content = None
|
||||
self.ca_privatekey_passphrase: str | None = module.params[
|
||||
"ownca_privatekey_passphrase"
|
||||
]
|
||||
|
||||
if self.csr_content is None and self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if self.csr_content is None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.ca_cert_content is None and not os.path.exists(self.ca_cert_path):
|
||||
if self.csr_content is None:
|
||||
if self.csr_path is None:
|
||||
raise CertificateError(
|
||||
"csr_path or csr_content is required for ownca provider"
|
||||
)
|
||||
if not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.ca_cert_path is not None and not os.path.exists(self.ca_cert_path):
|
||||
raise CertificateError(
|
||||
f"The CA certificate file {self.ca_cert_path} does not exist"
|
||||
)
|
||||
if self.ca_privatekey_content is None and not os.path.exists(
|
||||
if self.ca_privatekey_path is not None and not os.path.exists(
|
||||
self.ca_privatekey_path
|
||||
):
|
||||
raise CertificateError(
|
||||
@@ -101,8 +123,12 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
path=self.ca_cert_path,
|
||||
content=self.ca_cert_content,
|
||||
)
|
||||
if not is_potential_certificate_issuer_public_key(self.ca_cert.public_key()):
|
||||
raise CertificateError(
|
||||
"CA certificate's public key cannot be used to sign certificates"
|
||||
)
|
||||
try:
|
||||
self.ca_private_key = load_privatekey(
|
||||
self.ca_private_key = load_certificate_issuer_privatekey(
|
||||
path=self.ca_privatekey_path,
|
||||
content=self.ca_privatekey_content,
|
||||
passphrase=self.ca_privatekey_passphrase,
|
||||
@@ -125,8 +151,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
self.digest = None
|
||||
|
||||
def generate_certificate(self):
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
cert_builder = x509.CertificateBuilder()
|
||||
cert_builder = cert_builder.subject_name(self.csr.subject)
|
||||
cert_builder = cert_builder.issuer_name(self.ca_cert.subject)
|
||||
@@ -166,10 +194,10 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
critical=False,
|
||||
)
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
public_key = self.ca_cert.public_key()
|
||||
assert is_potential_certificate_issuer_public_key(public_key)
|
||||
cert_builder = cert_builder.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(
|
||||
self.ca_cert.public_key()
|
||||
),
|
||||
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
|
||||
critical=False,
|
||||
)
|
||||
|
||||
@@ -180,17 +208,24 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
self.cert = certificate
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
if self.cert is None:
|
||||
raise AssertionError("Contract violation: cert has not been populated")
|
||||
return self.cert.public_bytes(Encoding.PEM)
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
if super(OwnCACertificateBackendCryptography, self).needs_regeneration(
|
||||
not_before=self.notBefore, not_after=self.notAfter
|
||||
):
|
||||
return True
|
||||
|
||||
self._ensure_existing_certificate_loaded()
|
||||
assert self.existing_certificate is not None
|
||||
|
||||
# Check whether certificate is signed by CA certificate
|
||||
if not cryptography_verify_certificate_signature(
|
||||
@@ -205,31 +240,33 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
# Check AuthorityKeyIdentifier
|
||||
if self.create_authority_key_identifier:
|
||||
try:
|
||||
ext = self.ca_cert.extensions.get_extension_for_class(
|
||||
ext_ski = self.ca_cert.extensions.get_extension_for_class(
|
||||
x509.SubjectKeyIdentifier
|
||||
)
|
||||
expected_ext = (
|
||||
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
|
||||
ext.value
|
||||
ext_ski.value
|
||||
)
|
||||
)
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
public_key = self.ca_cert.public_key()
|
||||
assert is_potential_certificate_issuer_public_key(public_key)
|
||||
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
|
||||
self.ca_cert.public_key()
|
||||
public_key
|
||||
)
|
||||
|
||||
try:
|
||||
ext = self.existing_certificate.extensions.get_extension_for_class(
|
||||
ext_aki = self.existing_certificate.extensions.get_extension_for_class(
|
||||
x509.AuthorityKeyIdentifier
|
||||
)
|
||||
if ext.value != expected_ext:
|
||||
if ext_aki.value != expected_ext:
|
||||
return True
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(OwnCACertificateBackendCryptography, self).dump(
|
||||
include_certificate
|
||||
)
|
||||
@@ -251,6 +288,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
if self.cert is None:
|
||||
self.cert = self.existing_certificate
|
||||
assert self.cert is not None
|
||||
result.update(
|
||||
{
|
||||
"notBefore": get_not_valid_before(self.cert).strftime(
|
||||
@@ -266,7 +304,7 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
|
||||
return result
|
||||
|
||||
|
||||
def generate_serial_number():
|
||||
def generate_serial_number() -> int:
|
||||
"""Generate a serial number for a certificate"""
|
||||
while True:
|
||||
result = randrange(0, 1 << 160)
|
||||
@@ -275,7 +313,7 @@ def generate_serial_number():
|
||||
|
||||
|
||||
class OwnCACertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
if (
|
||||
module.params["ownca_path"] is None
|
||||
and module.params["ownca_content"] is None
|
||||
@@ -291,14 +329,16 @@ class OwnCACertificateProvider(CertificateProvider):
|
||||
msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider."
|
||||
)
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
return module.params["ownca_version"] == 2
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(
|
||||
self, module: AnsibleModule
|
||||
) -> OwnCACertificateBackendCryptography:
|
||||
return OwnCACertificateBackendCryptography(module)
|
||||
|
||||
|
||||
def add_ownca_provider_to_argument_spec(argument_spec):
|
||||
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("ownca")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
from random import randrange
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -14,6 +15,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
cryptography_verify_certificate_signature,
|
||||
get_not_valid_after,
|
||||
get_not_valid_before,
|
||||
is_potential_certificate_issuer_private_key,
|
||||
set_not_valid_after,
|
||||
set_not_valid_before,
|
||||
)
|
||||
@@ -30,6 +32,17 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ...argspec import ArgumentSpec
|
||||
|
||||
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography import x509
|
||||
@@ -39,12 +52,14 @@ except ImportError:
|
||||
|
||||
|
||||
class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
def __init__(self, module):
|
||||
privatekey: CertificateIssuerPrivateKeyTypes
|
||||
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
|
||||
|
||||
self.create_subject_key_identifier = module.params[
|
||||
"selfsigned_create_subject_key_identifier"
|
||||
]
|
||||
self.create_subject_key_identifier: t.Literal[
|
||||
"create_if_not_provided", "always_create", "never_create"
|
||||
] = module.params["selfsigned_create_subject_key_identifier"]
|
||||
self.notBefore = get_relative_time_option(
|
||||
module.params["selfsigned_not_before"],
|
||||
"selfsigned_not_before",
|
||||
@@ -56,14 +71,16 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.digest = select_message_digest(module.params["selfsigned_digest"])
|
||||
self.version = module.params["selfsigned_version"]
|
||||
self.version: int = module.params["selfsigned_version"]
|
||||
self.serial_number = x509.random_serial_number()
|
||||
|
||||
if self.csr_path is not None and not os.path.exists(self.csr_path):
|
||||
raise CertificateError(
|
||||
f"The certificate signing request file {self.csr_path} does not exist"
|
||||
)
|
||||
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
|
||||
if self.privatekey_path is not None and not os.path.exists(
|
||||
self.privatekey_path
|
||||
):
|
||||
raise CertificateError(
|
||||
f"The private key file {self.privatekey_path} does not exist"
|
||||
)
|
||||
@@ -71,20 +88,10 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
self._module = module
|
||||
|
||||
self._ensure_private_key_loaded()
|
||||
|
||||
self._ensure_csr_loaded()
|
||||
if self.csr is None:
|
||||
# Create empty CSR on the fly
|
||||
csr = cryptography.x509.CertificateSigningRequestBuilder()
|
||||
csr = csr.subject_name(cryptography.x509.Name([]))
|
||||
digest = None
|
||||
if cryptography_key_needs_digest_for_signing(self.privatekey):
|
||||
digest = self.digest
|
||||
if digest is None:
|
||||
self.module.fail_json(
|
||||
msg=f'Unsupported digest "{module.params["selfsigned_digest"]}"'
|
||||
)
|
||||
self.csr = csr.sign(self.privatekey, digest)
|
||||
if self.privatekey is None:
|
||||
raise CertificateError("Private key has not been provided")
|
||||
if not is_potential_certificate_issuer_private_key(self.privatekey):
|
||||
raise CertificateError("Private key cannot be used to sign certificates")
|
||||
|
||||
if cryptography_key_needs_digest_for_signing(self.privatekey):
|
||||
if self.digest is None:
|
||||
@@ -94,8 +101,21 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
self.digest = None
|
||||
|
||||
def generate_certificate(self):
|
||||
self._ensure_csr_loaded()
|
||||
if self.csr is None:
|
||||
# Create empty CSR on the fly
|
||||
csr = cryptography.x509.CertificateSigningRequestBuilder()
|
||||
csr = csr.subject_name(cryptography.x509.Name([]))
|
||||
self.csr = csr.sign(self.privatekey, self.digest)
|
||||
|
||||
def generate_certificate(self) -> None:
|
||||
"""(Re-)Generate certificate."""
|
||||
if self.csr is None:
|
||||
raise AssertionError("Contract violation: csr has not been populated")
|
||||
if self.privatekey is None:
|
||||
raise AssertionError(
|
||||
"Contract violation: privatekey has not been populated"
|
||||
)
|
||||
try:
|
||||
cert_builder = x509.CertificateBuilder()
|
||||
cert_builder = cert_builder.subject_name(self.csr.subject)
|
||||
@@ -130,17 +150,26 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
self.cert = certificate
|
||||
|
||||
def get_certificate_data(self):
|
||||
def get_certificate_data(self) -> bytes:
|
||||
"""Return bytes for self.cert."""
|
||||
if self.cert is None:
|
||||
raise AssertionError("Contract violation: cert has not been populated")
|
||||
return self.cert.public_bytes(Encoding.PEM)
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(
|
||||
self,
|
||||
not_before: datetime.datetime | None = None,
|
||||
not_after: datetime.datetime | None = None,
|
||||
) -> bool:
|
||||
assert self.privatekey is not None
|
||||
|
||||
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(
|
||||
not_before=self.notBefore, not_after=self.notAfter
|
||||
):
|
||||
return True
|
||||
|
||||
self._ensure_existing_certificate_loaded()
|
||||
assert self.existing_certificate is not None
|
||||
|
||||
# Check whether certificate is signed by private key
|
||||
if not cryptography_verify_certificate_signature(
|
||||
@@ -150,7 +179,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
|
||||
return False
|
||||
|
||||
def dump(self, include_certificate):
|
||||
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
|
||||
result = super(SelfSignedCertificateBackendCryptography, self).dump(
|
||||
include_certificate
|
||||
)
|
||||
@@ -166,6 +195,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
else:
|
||||
if self.cert is None:
|
||||
self.cert = self.existing_certificate
|
||||
assert self.cert is not None
|
||||
result.update(
|
||||
{
|
||||
"notBefore": get_not_valid_before(self.cert).strftime(
|
||||
@@ -181,7 +211,7 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
|
||||
return result
|
||||
|
||||
|
||||
def generate_serial_number():
|
||||
def generate_serial_number() -> int:
|
||||
"""Generate a serial number for a certificate"""
|
||||
while True:
|
||||
result = randrange(0, 1 << 160)
|
||||
@@ -190,7 +220,7 @@ def generate_serial_number():
|
||||
|
||||
|
||||
class SelfSignedCertificateProvider(CertificateProvider):
|
||||
def validate_module_args(self, module):
|
||||
def validate_module_args(self, module: AnsibleModule) -> None:
|
||||
if (
|
||||
module.params["privatekey_path"] is None
|
||||
and module.params["privatekey_content"] is None
|
||||
@@ -199,14 +229,16 @@ class SelfSignedCertificateProvider(CertificateProvider):
|
||||
msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider."
|
||||
)
|
||||
|
||||
def needs_version_two_certs(self, module):
|
||||
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
|
||||
return module.params["selfsigned_version"] == 2
|
||||
|
||||
def create_backend(self, module):
|
||||
def create_backend(
|
||||
self, module: AnsibleModule
|
||||
) -> SelfSignedCertificateBackendCryptography:
|
||||
return SelfSignedCertificateBackendCryptography(module)
|
||||
|
||||
|
||||
def add_selfsigned_provider_to_argument_spec(argument_spec):
|
||||
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
|
||||
argument_spec.argument_spec["provider"]["choices"].append("selfsigned")
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import (
|
||||
TIMESTAMP_FORMAT,
|
||||
cryptography_decode_revoked_certificate,
|
||||
@@ -22,6 +24,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
# crypto_utils
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
@@ -33,14 +47,19 @@ except ImportError:
|
||||
|
||||
|
||||
class CRLInfoRetrieval:
|
||||
def __init__(self, module, content, list_revoked_certificates=True):
|
||||
def __init__(
|
||||
self,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
list_revoked_certificates: bool = True,
|
||||
) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.list_revoked_certificates = list_revoked_certificates
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
|
||||
def get_info(self):
|
||||
def get_info(self) -> dict[str, t.Any]:
|
||||
self.crl_pem = identify_pem_format(self.content)
|
||||
try:
|
||||
if self.crl_pem:
|
||||
@@ -50,7 +69,7 @@ class CRLInfoRetrieval:
|
||||
except ValueError as e:
|
||||
self.module.fail_json(msg=f"Error while decoding CRL: {e}")
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": False,
|
||||
"format": "pem" if self.crl_pem else "der",
|
||||
"last_update": None,
|
||||
@@ -61,7 +80,11 @@ class CRLInfoRetrieval:
|
||||
}
|
||||
|
||||
result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = (
|
||||
self.crl.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
if self.crl.next_update
|
||||
else None
|
||||
)
|
||||
result["digest"] = cryptography_oid_to_name(
|
||||
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
|
||||
)
|
||||
@@ -83,7 +106,9 @@ class CRLInfoRetrieval:
|
||||
return result
|
||||
|
||||
|
||||
def get_crl_info(module, content, list_revoked_certificates=True):
|
||||
def get_crl_info(
|
||||
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
|
||||
) -> dict[str, t.Any]:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
@@ -26,13 +27,14 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptograp
|
||||
cryptography_name_to_oid,
|
||||
cryptography_parse_key_usage_params,
|
||||
cryptography_parse_relative_distinguished_name,
|
||||
is_potential_certificate_issuer_public_key,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import (
|
||||
get_csr_info,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
load_certificate_issuer_privatekey,
|
||||
load_certificate_request,
|
||||
load_privatekey,
|
||||
parse_name_field,
|
||||
parse_ordered_name_field,
|
||||
select_message_digest,
|
||||
@@ -43,6 +45,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ..cryptography_support import CertificatePrivateKeyTypes
|
||||
|
||||
_ET = t.TypeVar("_ET", bound="cryptography.x509.ExtensionType")
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -69,49 +83,58 @@ class CertificateSigningRequestError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.digest = module.params["digest"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.version = module.params["version"]
|
||||
self.subjectAltName = module.params["subject_alt_name"]
|
||||
self.subjectAltName_critical = module.params["subject_alt_name_critical"]
|
||||
self.keyUsage = module.params["key_usage"]
|
||||
self.keyUsage_critical = module.params["key_usage_critical"]
|
||||
self.extendedKeyUsage = module.params["extended_key_usage"]
|
||||
self.extendedKeyUsage_critical = module.params["extended_key_usage_critical"]
|
||||
self.basicConstraints = module.params["basic_constraints"]
|
||||
self.basicConstraints_critical = module.params["basic_constraints_critical"]
|
||||
self.ocspMustStaple = module.params["ocsp_must_staple"]
|
||||
self.ocspMustStaple_critical = module.params["ocsp_must_staple_critical"]
|
||||
self.name_constraints_permitted = (
|
||||
self.digest: str = module.params["digest"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.version: t.Literal[1] = module.params["version"]
|
||||
self.subjectAltName: list[str] | None = module.params["subject_alt_name"]
|
||||
self.subjectAltName_critical: bool = module.params["subject_alt_name_critical"]
|
||||
self.keyUsage: list[str] | None = module.params["key_usage"]
|
||||
self.keyUsage_critical: bool = module.params["key_usage_critical"]
|
||||
self.extendedKeyUsage: list[str] | None = module.params["extended_key_usage"]
|
||||
self.extendedKeyUsage_critical: bool = module.params[
|
||||
"extended_key_usage_critical"
|
||||
]
|
||||
self.basicConstraints: list[str] | None = module.params["basic_constraints"]
|
||||
self.basicConstraints_critical: bool = module.params[
|
||||
"basic_constraints_critical"
|
||||
]
|
||||
self.ocspMustStaple: bool = module.params["ocsp_must_staple"]
|
||||
self.ocspMustStaple_critical: bool = module.params["ocsp_must_staple_critical"]
|
||||
self.name_constraints_permitted: list[str] = (
|
||||
module.params["name_constraints_permitted"] or []
|
||||
)
|
||||
self.name_constraints_excluded = (
|
||||
self.name_constraints_excluded: list[str] = (
|
||||
module.params["name_constraints_excluded"] or []
|
||||
)
|
||||
self.name_constraints_critical = module.params["name_constraints_critical"]
|
||||
self.create_subject_key_identifier = module.params[
|
||||
self.name_constraints_critical: bool = module.params[
|
||||
"name_constraints_critical"
|
||||
]
|
||||
self.create_subject_key_identifier: bool = module.params[
|
||||
"create_subject_key_identifier"
|
||||
]
|
||||
self.subject_key_identifier = module.params["subject_key_identifier"]
|
||||
self.authority_key_identifier = module.params["authority_key_identifier"]
|
||||
self.authority_cert_issuer = module.params["authority_cert_issuer"]
|
||||
self.authority_cert_serial_number = module.params[
|
||||
subject_key_identifier: str | None = module.params["subject_key_identifier"]
|
||||
authority_key_identifier: str | None = module.params["authority_key_identifier"]
|
||||
self.authority_cert_issuer: list[str] | None = module.params[
|
||||
"authority_cert_issuer"
|
||||
]
|
||||
self.authority_cert_serial_number: int = module.params[
|
||||
"authority_cert_serial_number"
|
||||
]
|
||||
self.crl_distribution_points = module.params["crl_distribution_points"]
|
||||
self.csr = None
|
||||
self.privatekey = None
|
||||
self.crl_distribution_points: (
|
||||
list[cryptography.x509.DistributionPoint] | None
|
||||
) = None
|
||||
self.csr: cryptography.x509.CertificateSigningRequest | None = None
|
||||
self.privatekey: CertificateIssuerPrivateKeyTypes | None = None
|
||||
|
||||
if (
|
||||
self.create_subject_key_identifier
|
||||
and self.subject_key_identifier is not None
|
||||
):
|
||||
if self.create_subject_key_identifier and subject_key_identifier is not None:
|
||||
module.fail_json(
|
||||
msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true"
|
||||
)
|
||||
@@ -153,35 +176,37 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
self.using_common_name_for_san = True
|
||||
break
|
||||
|
||||
if self.subject_key_identifier is not None:
|
||||
self.subject_key_identifier: bytes | None = None
|
||||
if subject_key_identifier is not None:
|
||||
try:
|
||||
self.subject_key_identifier = binascii.unhexlify(
|
||||
self.subject_key_identifier.replace(":", "")
|
||||
subject_key_identifier.replace(":", "")
|
||||
)
|
||||
except Exception as e:
|
||||
raise CertificateSigningRequestError(
|
||||
f"Cannot parse subject_key_identifier: {e}"
|
||||
)
|
||||
|
||||
if self.authority_key_identifier is not None:
|
||||
self.authority_key_identifier: bytes | None = None
|
||||
if authority_key_identifier is not None:
|
||||
try:
|
||||
self.authority_key_identifier = binascii.unhexlify(
|
||||
self.authority_key_identifier.replace(":", "")
|
||||
authority_key_identifier.replace(":", "")
|
||||
)
|
||||
except Exception as e:
|
||||
raise CertificateSigningRequestError(
|
||||
f"Cannot parse authority_key_identifier: {e}"
|
||||
)
|
||||
|
||||
self.existing_csr = None
|
||||
self.existing_csr_bytes = None
|
||||
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
|
||||
self.existing_csr_bytes: bytes | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
return {}
|
||||
try:
|
||||
result = get_csr_info(
|
||||
self.module,
|
||||
@@ -195,30 +220,28 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
return dict(can_parse_csr=False)
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_csr(self):
|
||||
def generate_csr(self) -> None:
|
||||
"""(Re-)Generate CSR."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_csr_data(self):
|
||||
def get_csr_data(self) -> bytes:
|
||||
"""Return bytes for self.csr."""
|
||||
pass
|
||||
|
||||
def set_existing(self, csr_bytes):
|
||||
def set_existing(self, csr_bytes: bytes | None) -> None:
|
||||
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
|
||||
self.existing_csr_bytes = csr_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
|
||||
|
||||
def has_existing(self):
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing CSR is/has been there."""
|
||||
return self.existing_csr_bytes is not None
|
||||
|
||||
def _ensure_private_key_loaded(self):
|
||||
def _ensure_private_key_loaded(self) -> None:
|
||||
"""Load the provided private key into self.privatekey."""
|
||||
if self.privatekey is not None:
|
||||
return
|
||||
try:
|
||||
self.privatekey = load_privatekey(
|
||||
self.privatekey = load_certificate_issuer_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
@@ -227,11 +250,10 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
raise CertificateSigningRequestError(exc)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_csr(self):
|
||||
def _check_csr(self) -> bool:
|
||||
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
|
||||
pass
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(self) -> bool:
|
||||
"""Check whether a regeneration is necessary."""
|
||||
if self.existing_csr_bytes is None:
|
||||
return True
|
||||
@@ -245,9 +267,9 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
self._ensure_private_key_loaded()
|
||||
return not self._check_csr()
|
||||
|
||||
def dump(self, include_csr):
|
||||
def dump(self, include_csr: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
"subject": self.subject,
|
||||
"subjectAltName": self.subjectAltName,
|
||||
@@ -274,44 +296,49 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
|
||||
def parse_crl_distribution_points(module, crl_distribution_points):
|
||||
def parse_crl_distribution_points(
|
||||
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
|
||||
) -> list[cryptography.x509.DistributionPoint]:
|
||||
result = []
|
||||
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
|
||||
try:
|
||||
params = dict(
|
||||
full_name=None,
|
||||
relative_name=None,
|
||||
crl_issuer=None,
|
||||
reasons=None,
|
||||
)
|
||||
full_name = None
|
||||
relative_name = None
|
||||
crl_issuer = None
|
||||
reasons = None
|
||||
if parse_crl_distribution_point["full_name"] is not None:
|
||||
if not parse_crl_distribution_point["full_name"]:
|
||||
raise OpenSSLObjectError("full_name must not be empty")
|
||||
params["full_name"] = [
|
||||
full_name = [
|
||||
cryptography_get_name(name, "full name")
|
||||
for name in parse_crl_distribution_point["full_name"]
|
||||
]
|
||||
if parse_crl_distribution_point["relative_name"] is not None:
|
||||
if not parse_crl_distribution_point["relative_name"]:
|
||||
raise OpenSSLObjectError("relative_name must not be empty")
|
||||
params["relative_name"] = (
|
||||
cryptography_parse_relative_distinguished_name(
|
||||
parse_crl_distribution_point["relative_name"]
|
||||
)
|
||||
relative_name = cryptography_parse_relative_distinguished_name(
|
||||
parse_crl_distribution_point["relative_name"]
|
||||
)
|
||||
if parse_crl_distribution_point["crl_issuer"] is not None:
|
||||
if not parse_crl_distribution_point["crl_issuer"]:
|
||||
raise OpenSSLObjectError("crl_issuer must not be empty")
|
||||
params["crl_issuer"] = [
|
||||
crl_issuer = [
|
||||
cryptography_get_name(name, "CRL issuer")
|
||||
for name in parse_crl_distribution_point["crl_issuer"]
|
||||
]
|
||||
if parse_crl_distribution_point["reasons"] is not None:
|
||||
reasons = []
|
||||
reasons_list = []
|
||||
for reason in parse_crl_distribution_point["reasons"]:
|
||||
reasons.append(REVOCATION_REASON_MAP[reason])
|
||||
params["reasons"] = frozenset(reasons)
|
||||
result.append(cryptography.x509.DistributionPoint(**params))
|
||||
reasons_list.append(REVOCATION_REASON_MAP[reason])
|
||||
reasons = frozenset(reasons_list)
|
||||
result.append(
|
||||
cryptography.x509.DistributionPoint(
|
||||
full_name=full_name,
|
||||
relative_name=relative_name,
|
||||
crl_issuer=crl_issuer,
|
||||
reasons=reasons,
|
||||
)
|
||||
)
|
||||
except (OpenSSLObjectError, ValueError) as e:
|
||||
raise OpenSSLObjectError(
|
||||
f"Error while parsing CRL distribution point #{index}: {e}"
|
||||
@@ -321,21 +348,25 @@ def parse_crl_distribution_points(module, crl_distribution_points):
|
||||
|
||||
# Implementation with using cryptography
|
||||
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
|
||||
if self.version != 1:
|
||||
module.warn(
|
||||
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
|
||||
)
|
||||
|
||||
if self.crl_distribution_points:
|
||||
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
|
||||
"crl_distribution_points"
|
||||
]
|
||||
if crl_distribution_points:
|
||||
self.crl_distribution_points = parse_crl_distribution_points(
|
||||
module, self.crl_distribution_points
|
||||
module, crl_distribution_points
|
||||
)
|
||||
|
||||
def generate_csr(self):
|
||||
def generate_csr(self) -> None:
|
||||
"""(Re-)Generate CSR."""
|
||||
self._ensure_private_key_loaded()
|
||||
assert self.privatekey is not None
|
||||
|
||||
csr = cryptography.x509.CertificateSigningRequestBuilder()
|
||||
try:
|
||||
@@ -412,6 +443,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
raise OpenSSLObjectError(f"Error while parsing name constraint: {e}")
|
||||
|
||||
if self.create_subject_key_identifier:
|
||||
if not is_potential_certificate_issuer_public_key(
|
||||
self.privatekey.public_key()
|
||||
):
|
||||
raise OpenSSLObjectError(
|
||||
"Private key can not be used to create subject key identifier"
|
||||
)
|
||||
csr = csr.add_extension(
|
||||
cryptography.x509.SubjectKeyIdentifier.from_public_key(
|
||||
self.privatekey.public_key()
|
||||
@@ -450,7 +487,10 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
critical=False,
|
||||
)
|
||||
|
||||
digest = None
|
||||
# csr.sign() does not accept some digests we theoretically could have in digest.
|
||||
# For that reason we use type t.Any here. csr.sign() will complain if
|
||||
# the digest is not acceptable.
|
||||
digest: t.Any | None = None
|
||||
if cryptography_key_needs_digest_for_signing(self.privatekey):
|
||||
digest = select_message_digest(self.digest)
|
||||
if digest is None:
|
||||
@@ -482,16 +522,22 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
+ "This is probably caused by an invalid Subject Alternative DNS Name."
|
||||
)
|
||||
|
||||
def get_csr_data(self):
|
||||
def get_csr_data(self) -> bytes:
|
||||
"""Return bytes for self.csr."""
|
||||
if self.csr is None:
|
||||
raise AssertionError("Violated contract: csr is not populated")
|
||||
return self.csr.public_bytes(
|
||||
cryptography.hazmat.primitives.serialization.Encoding.PEM
|
||||
)
|
||||
|
||||
def _check_csr(self):
|
||||
def _check_csr(self) -> bool:
|
||||
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
|
||||
if self.existing_csr is None:
|
||||
raise AssertionError("Violated contract: existing_csr is not populated")
|
||||
if self.privatekey is None:
|
||||
raise AssertionError("Violated contract: privatekey is not populated")
|
||||
|
||||
def _check_subject(csr):
|
||||
def _check_subject(csr: cryptography.x509.CertificateSigningRequest) -> bool:
|
||||
subject = [
|
||||
(cryptography_name_to_oid(entry[0]), to_text(entry[1]))
|
||||
for entry in self.subject
|
||||
@@ -502,12 +548,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return set(subject) == set(current_subject)
|
||||
|
||||
def _find_extension(extensions, exttype):
|
||||
def _find_extension(
|
||||
extensions: cryptography.x509.Extensions, exttype: type[_ET]
|
||||
) -> cryptography.x509.Extension[_ET] | None:
|
||||
return next(
|
||||
(ext for ext in extensions if isinstance(ext.value, exttype)), None
|
||||
)
|
||||
|
||||
def _check_subjectAltName(extensions):
|
||||
def _check_subjectAltName(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_altnames_ext = _find_extension(
|
||||
extensions, cryptography.x509.SubjectAlternativeName
|
||||
)
|
||||
@@ -526,12 +574,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
)
|
||||
if set(altnames) != set(current_altnames):
|
||||
return False
|
||||
if altnames:
|
||||
if altnames and current_altnames_ext:
|
||||
if current_altnames_ext.critical != self.subjectAltName_critical:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_keyUsage(extensions):
|
||||
def _check_keyUsage(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_keyusage_ext = _find_extension(
|
||||
extensions, cryptography.x509.KeyUsage
|
||||
)
|
||||
@@ -547,7 +595,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_extenededKeyUsage(extensions):
|
||||
def _check_extenededKeyUsage(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_usages_ext = _find_extension(
|
||||
extensions, cryptography.x509.ExtendedKeyUsage
|
||||
)
|
||||
@@ -566,12 +614,12 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
)
|
||||
if set(current_usages) != set(usages):
|
||||
return False
|
||||
if usages:
|
||||
if usages and current_usages_ext:
|
||||
if current_usages_ext.critical != self.extendedKeyUsage_critical:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_basicConstraints(extensions):
|
||||
def _check_basicConstraints(extensions: cryptography.x509.Extensions) -> bool:
|
||||
bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints)
|
||||
current_ca = bc_ext.value.ca if bc_ext else False
|
||||
current_path_length = bc_ext.value.path_length if bc_ext else None
|
||||
@@ -591,7 +639,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return bc_ext is None
|
||||
|
||||
def _check_ocspMustStaple(extensions):
|
||||
def _check_ocspMustStaple(extensions: cryptography.x509.Extensions) -> bool:
|
||||
tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
|
||||
if self.ocspMustStaple:
|
||||
if (
|
||||
@@ -606,7 +654,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return tlsfeature_ext is None
|
||||
|
||||
def _check_nameConstraints(extensions):
|
||||
def _check_nameConstraints(extensions: cryptography.x509.Extensions) -> bool:
|
||||
current_nc_ext = _find_extension(
|
||||
extensions, cryptography.x509.NameConstraints
|
||||
)
|
||||
@@ -638,12 +686,14 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
current_nc_excl
|
||||
):
|
||||
return False
|
||||
if nc_perm or nc_excl:
|
||||
if (nc_perm or nc_excl) and current_nc_ext:
|
||||
if current_nc_ext.critical != self.name_constraints_critical:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_subject_key_identifier(extensions):
|
||||
def _check_subject_key_identifier(
|
||||
extensions: cryptography.x509.Extensions,
|
||||
) -> bool:
|
||||
ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
|
||||
if (
|
||||
self.create_subject_key_identifier
|
||||
@@ -652,6 +702,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
if not ext or ext.critical:
|
||||
return False
|
||||
if self.create_subject_key_identifier:
|
||||
assert self.privatekey is not None
|
||||
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(
|
||||
self.privatekey.public_key()
|
||||
).digest
|
||||
@@ -661,7 +712,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return ext is None
|
||||
|
||||
def _check_authority_key_identifier(extensions):
|
||||
def _check_authority_key_identifier(
|
||||
extensions: cryptography.x509.Extensions,
|
||||
) -> bool:
|
||||
ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
|
||||
if (
|
||||
self.authority_key_identifier is not None
|
||||
@@ -688,7 +741,9 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
else:
|
||||
return ext is None
|
||||
|
||||
def _check_crl_distribution_points(extensions):
|
||||
def _check_crl_distribution_points(
|
||||
extensions: cryptography.x509.Extensions,
|
||||
) -> bool:
|
||||
ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints)
|
||||
if self.crl_distribution_points is None:
|
||||
return ext is None
|
||||
@@ -696,7 +751,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
return False
|
||||
return list(ext.value) == self.crl_distribution_points
|
||||
|
||||
def _check_extensions(csr):
|
||||
def _check_extensions(csr: cryptography.x509.CertificateSigningRequest) -> bool:
|
||||
extensions = csr.extensions
|
||||
return (
|
||||
_check_subjectAltName(extensions)
|
||||
@@ -710,7 +765,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
and _check_crl_distribution_points(extensions)
|
||||
)
|
||||
|
||||
def _check_signature(csr):
|
||||
def _check_signature(csr: cryptography.x509.CertificateSigningRequest) -> bool:
|
||||
if not csr.is_signature_valid:
|
||||
return False
|
||||
# To check whether public key of CSR belongs to private key,
|
||||
@@ -719,6 +774,7 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
cryptography.hazmat.primitives.serialization.Encoding.PEM,
|
||||
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
assert self.privatekey is not None
|
||||
key_b = self.privatekey.public_key().public_bytes(
|
||||
cryptography.hazmat.primitives.serialization.Encoding.PEM,
|
||||
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
@@ -732,14 +788,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
|
||||
)
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(
|
||||
module: AnsibleModule,
|
||||
) -> CertificateSigningRequestCryptographyBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return CertificateSigningRequestCryptographyBackend(module)
|
||||
|
||||
|
||||
def get_csr_argument_spec():
|
||||
def get_csr_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
digest=dict(type="str", default="sha256"),
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
@@ -27,6 +28,19 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificatePublicKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -41,66 +55,69 @@ TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
|
||||
|
||||
|
||||
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, content, validate_signature):
|
||||
# content must be a bytes string
|
||||
def __init__(
|
||||
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
) -> None:
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.validate_signature = validate_signature
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_ordered(self):
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_name_constraints(self):
|
||||
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_signature_valid(self):
|
||||
def _is_signature_valid(self) -> bool:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False):
|
||||
result = dict()
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
self.csr = load_certificate_request(
|
||||
None,
|
||||
content=self.content,
|
||||
@@ -145,15 +162,17 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
}
|
||||
)
|
||||
|
||||
ski = self._get_subject_key_identifier()
|
||||
if ski is not None:
|
||||
ski = binascii.hexlify(ski).decode("ascii")
|
||||
ski_bytes = self._get_subject_key_identifier()
|
||||
ski = None
|
||||
if ski_bytes is not None:
|
||||
ski = binascii.hexlify(ski_bytes).decode("ascii")
|
||||
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
|
||||
result["subject_key_identifier"] = ski
|
||||
|
||||
aki, aci, acsn = self._get_authority_key_identifier()
|
||||
if aki is not None:
|
||||
aki = binascii.hexlify(aki).decode("ascii")
|
||||
aki_bytes, aci, acsn = self._get_authority_key_identifier()
|
||||
aki = None
|
||||
if aki_bytes is not None:
|
||||
aki = binascii.hexlify(aki_bytes).decode("ascii")
|
||||
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
|
||||
result["authority_key_identifier"] = aki
|
||||
result["authority_cert_issuer"] = aci
|
||||
@@ -170,19 +189,25 @@ class CSRInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
"""Validate the supplied CSR, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content, validate_signature):
|
||||
def __init__(
|
||||
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
|
||||
) -> None:
|
||||
super(CSRInfoRetrievalCryptography, self).__init__(
|
||||
module, content, validate_signature
|
||||
)
|
||||
self.name_encoding = module.params.get("name_encoding", "ignore")
|
||||
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
|
||||
"name_encoding", "ignore"
|
||||
)
|
||||
|
||||
def _get_subject_ordered(self):
|
||||
result = []
|
||||
def _get_subject_ordered(self) -> list[list[str]]:
|
||||
result: list[list[str]] = []
|
||||
for attribute in self.csr.subject:
|
||||
result.append([cryptography_oid_to_name(attribute.oid), attribute.value])
|
||||
result.append(
|
||||
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_key_usage(self):
|
||||
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage)
|
||||
current_key_usage = current_key_ext.value
|
||||
@@ -229,7 +254,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_extended_key_usage(self):
|
||||
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.ExtendedKeyUsage
|
||||
@@ -243,7 +268,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_basic_constraints(self):
|
||||
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.BasicConstraints
|
||||
@@ -255,7 +280,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_ocsp_must_staple(self):
|
||||
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
|
||||
try:
|
||||
# This only works with cryptography >= 2.1
|
||||
tlsfeature_ext = self.csr.extensions.get_extension_for_class(
|
||||
@@ -268,7 +293,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_subject_alt_name(self):
|
||||
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
|
||||
try:
|
||||
san_ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.SubjectAlternativeName
|
||||
@@ -281,7 +306,7 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, False
|
||||
|
||||
def _get_name_constraints(self):
|
||||
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
|
||||
try:
|
||||
nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints)
|
||||
permitted = [
|
||||
@@ -296,23 +321,25 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, None, False
|
||||
|
||||
def _get_public_key_pem(self):
|
||||
def _get_public_key_pem(self) -> bytes:
|
||||
return self.csr.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_public_key_object(self):
|
||||
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
|
||||
return self.csr.public_key()
|
||||
|
||||
def _get_subject_key_identifier(self):
|
||||
def _get_subject_key_identifier(self) -> bytes | None:
|
||||
try:
|
||||
ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
|
||||
return ext.value.digest
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None
|
||||
|
||||
def _get_authority_key_identifier(self):
|
||||
def _get_authority_key_identifier(
|
||||
self,
|
||||
) -> tuple[bytes | None, list[str] | None, int | None]:
|
||||
try:
|
||||
ext = self.csr.extensions.get_extension_for_class(
|
||||
x509.AuthorityKeyIdentifier
|
||||
@@ -331,23 +358,28 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
return None, None, None
|
||||
|
||||
def _get_all_extensions(self):
|
||||
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
|
||||
return cryptography_get_extensions_from_csr(self.csr)
|
||||
|
||||
def _is_signature_valid(self):
|
||||
def _is_signature_valid(self) -> bool:
|
||||
return self.csr.is_signature_valid
|
||||
|
||||
|
||||
def get_csr_info(
|
||||
module, content, validate_signature=True, prefer_one_fingerprint=False
|
||||
):
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
validate_signature: bool = True,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = CSRInfoRetrievalCryptography(
|
||||
module, content, validate_signature=validate_signature
|
||||
)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(module, content, validate_signature=True):
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
|
||||
) -> CSRInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import base64
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
@@ -33,6 +34,17 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -64,29 +76,37 @@ class PrivateKeyError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: GeneralAnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.type = module.params["type"]
|
||||
self.size = module.params["size"]
|
||||
self.curve = module.params["curve"]
|
||||
self.passphrase = module.params["passphrase"]
|
||||
self.cipher = module.params["cipher"]
|
||||
self.format = module.params["format"]
|
||||
self.format_mismatch = module.params.get("format_mismatch", "regenerate")
|
||||
self.regenerate = module.params.get("regenerate", "full_idempotence")
|
||||
self.type: t.Literal[
|
||||
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
|
||||
] = module.params["type"]
|
||||
self.size: int = module.params["size"]
|
||||
self.curve: str | None = module.params["curve"]
|
||||
self.passphrase: str | None = module.params["passphrase"]
|
||||
self.cipher: str = module.params["cipher"]
|
||||
self.format: t.Literal["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"] = (
|
||||
module.params["format"]
|
||||
)
|
||||
self.format_mismatch: t.Literal["regenerate", "convert"] = module.params.get(
|
||||
"format_mismatch", "regenerate"
|
||||
)
|
||||
self.regenerate: t.Literal[
|
||||
"never", "fail", "partial_idempotence", "full_idempotence", "always"
|
||||
] = module.params.get("regenerate", "full_idempotence")
|
||||
|
||||
self.private_key = None
|
||||
self.private_key: PrivateKeyTypes | None = None
|
||||
|
||||
self.existing_private_key = None
|
||||
self.existing_private_key_bytes = None
|
||||
self.existing_private_key: PrivateKeyTypes | None = None
|
||||
self.existing_private_key_bytes: bytes | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
result = dict(can_parse_key=False)
|
||||
return {}
|
||||
result: dict[str, t.Any] = {"can_parse_key": False}
|
||||
try:
|
||||
result.update(
|
||||
get_privatekey_info(
|
||||
@@ -106,11 +126,11 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_private_key(self):
|
||||
def generate_private_key(self) -> None:
|
||||
"""(Re-)Generate private key."""
|
||||
pass
|
||||
|
||||
def convert_private_key(self):
|
||||
def convert_private_key(self) -> None:
|
||||
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key).
|
||||
|
||||
This is effectively a copy without active conversion. The conversion is done
|
||||
@@ -121,42 +141,37 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
self.private_key = self.existing_private_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.private_key."""
|
||||
pass
|
||||
|
||||
def set_existing(self, privatekey_bytes):
|
||||
def set_existing(self, privatekey_bytes: bytes | None) -> None:
|
||||
"""Set existing private key bytes. None indicates that the key does not exist."""
|
||||
self.existing_private_key_bytes = privatekey_bytes
|
||||
self.diff_after = self.diff_before = self._get_info(
|
||||
self.existing_private_key_bytes
|
||||
)
|
||||
|
||||
def has_existing(self):
|
||||
def has_existing(self) -> bool:
|
||||
"""Query whether an existing private key is/has been there."""
|
||||
return self.existing_private_key_bytes is not None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_passphrase(self):
|
||||
def _check_passphrase(self) -> bool:
|
||||
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _ensure_existing_private_key_loaded(self):
|
||||
def _ensure_existing_private_key_loaded(self) -> None:
|
||||
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_size_and_type(self):
|
||||
def _check_size_and_type(self) -> bool:
|
||||
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_format(self):
|
||||
def _check_format(self) -> bool:
|
||||
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
|
||||
pass
|
||||
|
||||
def needs_regeneration(self):
|
||||
def needs_regeneration(self) -> bool:
|
||||
"""Check whether a regeneration is necessary."""
|
||||
if self.regenerate == "always":
|
||||
return True
|
||||
@@ -194,7 +209,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
)
|
||||
return False
|
||||
|
||||
def needs_conversion(self):
|
||||
def needs_conversion(self) -> bool:
|
||||
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
|
||||
# During conversion step, convert if format does not match and format_mismatch == 'convert'
|
||||
self._ensure_existing_private_key_loaded()
|
||||
@@ -204,7 +219,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
and not self._check_format()
|
||||
)
|
||||
|
||||
def _get_fingerprint(self):
|
||||
def _get_fingerprint(self) -> dict[str, str] | None:
|
||||
if self.private_key:
|
||||
return get_fingerprint_of_privatekey(self.private_key)
|
||||
try:
|
||||
@@ -214,8 +229,9 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
if self.existing_private_key:
|
||||
return get_fingerprint_of_privatekey(self.existing_private_key)
|
||||
return None
|
||||
|
||||
def dump(self, include_key):
|
||||
def dump(self, include_key: bool) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
if not self.private_key:
|
||||
@@ -224,7 +240,7 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
except Exception:
|
||||
# Ignore errors
|
||||
pass
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"type": self.type,
|
||||
"size": self.size,
|
||||
"fingerprint": self._get_fingerprint(),
|
||||
@@ -253,38 +269,57 @@ class PrivateKeyBackend(metaclass=abc.ABCMeta):
|
||||
return result
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
class _Curve:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
ectype: str,
|
||||
deprecated: bool,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.ectype = ectype
|
||||
self.deprecated = deprecated
|
||||
|
||||
def _get_ec_class(self, ectype):
|
||||
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype)
|
||||
def _get_ec_class(
|
||||
self, module: GeneralAnsibleModule
|
||||
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
|
||||
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
|
||||
if ecclass is None:
|
||||
self.module.fail_json(
|
||||
msg=f"Your cryptography version does not support {ectype}"
|
||||
module.fail_json(
|
||||
msg=f"Your cryptography version does not support {self.ectype}"
|
||||
)
|
||||
return ecclass
|
||||
|
||||
def _add_curve(self, name, ectype, deprecated=False):
|
||||
def create(size):
|
||||
ecclass = self._get_ec_class(ectype)
|
||||
return ecclass()
|
||||
def create(
|
||||
self, size: int, module: GeneralAnsibleModule
|
||||
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
|
||||
ecclass = self._get_ec_class(module)
|
||||
return ecclass()
|
||||
|
||||
def verify(privatekey):
|
||||
ecclass = self._get_ec_class(ectype)
|
||||
return isinstance(
|
||||
privatekey.private_numbers().public_numbers.curve, ecclass
|
||||
)
|
||||
def verify(
|
||||
self,
|
||||
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
|
||||
module: GeneralAnsibleModule,
|
||||
) -> bool:
|
||||
ecclass = self._get_ec_class(module)
|
||||
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
|
||||
|
||||
self.curves[name] = {
|
||||
"create": create,
|
||||
"verify": verify,
|
||||
"deprecated": deprecated,
|
||||
}
|
||||
|
||||
def __init__(self, module):
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
|
||||
def _add_curve(
|
||||
self,
|
||||
name: str,
|
||||
ectype: str,
|
||||
deprecated: bool = False,
|
||||
) -> None:
|
||||
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
|
||||
|
||||
def __init__(self, module: GeneralAnsibleModule) -> None:
|
||||
super(PrivateKeyCryptographyBackend, self).__init__(module=module)
|
||||
|
||||
self.curves = dict()
|
||||
self.curves: dict[str, _Curve] = {}
|
||||
self._add_curve("secp224r1", "SECP224R1")
|
||||
self._add_curve("secp256k1", "SECP256K1")
|
||||
self._add_curve("secp256r1", "SECP256R1")
|
||||
@@ -305,15 +340,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
|
||||
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
|
||||
|
||||
def _get_wanted_format(self):
|
||||
def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
|
||||
if self.format not in ("auto", "auto_ignore"):
|
||||
return self.format
|
||||
return self.format # type: ignore
|
||||
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
|
||||
return "pkcs8"
|
||||
else:
|
||||
return "pkcs1"
|
||||
|
||||
def generate_private_key(self):
|
||||
def generate_private_key(self) -> None:
|
||||
"""(Re-)Generate private key."""
|
||||
try:
|
||||
if self.type == "RSA":
|
||||
@@ -346,13 +381,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
|
||||
)
|
||||
if self.type == "ECC" and self.curve in self.curves:
|
||||
if self.curves[self.curve]["deprecated"]:
|
||||
if self.curves[self.curve].deprecated:
|
||||
self.module.warn(
|
||||
f"Elliptic curves of type {self.curve} should not be used for new keys!"
|
||||
)
|
||||
self.private_key = (
|
||||
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
|
||||
curve=self.curves[self.curve]["create"](self.size),
|
||||
curve=self.curves[self.curve].create(
|
||||
size=self.size, module=self.module
|
||||
),
|
||||
)
|
||||
)
|
||||
except cryptography.exceptions.UnsupportedAlgorithm:
|
||||
@@ -360,22 +397,24 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
msg=f"Cryptography backend does not support the algorithm required for {self.type}"
|
||||
)
|
||||
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.private_key"""
|
||||
if self.private_key is None:
|
||||
raise AssertionError("private_key not set")
|
||||
# Select export format and encoding
|
||||
try:
|
||||
export_format = self._get_wanted_format()
|
||||
export_format_txt = self._get_wanted_format()
|
||||
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
|
||||
if export_format == "pkcs1":
|
||||
if export_format_txt == "pkcs1":
|
||||
# "TraditionalOpenSSL" format is PKCS1
|
||||
export_format = (
|
||||
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
|
||||
)
|
||||
elif export_format == "pkcs8":
|
||||
elif export_format_txt == "pkcs8":
|
||||
export_format = (
|
||||
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
|
||||
)
|
||||
elif export_format == "raw":
|
||||
elif export_format_txt == "raw":
|
||||
export_format = (
|
||||
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
|
||||
)
|
||||
@@ -388,9 +427,9 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
)
|
||||
|
||||
# Select key encryption
|
||||
encryption_algorithm = (
|
||||
cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
)
|
||||
encryption_algorithm: (
|
||||
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
|
||||
) = cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
if self.cipher and self.passphrase:
|
||||
if self.cipher == "auto":
|
||||
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
|
||||
@@ -418,8 +457,10 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
exception=traceback.format_exc(),
|
||||
)
|
||||
|
||||
def _load_privatekey(self):
|
||||
def _load_privatekey(self) -> PrivateKeyTypes:
|
||||
data = self.existing_private_key_bytes
|
||||
if data is None:
|
||||
raise AssertionError("existing_private_key_bytes not set")
|
||||
try:
|
||||
# Interpret bytes depending on format.
|
||||
format = identify_private_key_format(data)
|
||||
@@ -460,11 +501,13 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
except Exception as e:
|
||||
raise PrivateKeyError(e)
|
||||
|
||||
def _ensure_existing_private_key_loaded(self):
|
||||
def _ensure_existing_private_key_loaded(self) -> None:
|
||||
if self.existing_private_key is None and self.has_existing():
|
||||
self.existing_private_key = self._load_privatekey()
|
||||
|
||||
def _check_passphrase(self):
|
||||
def _check_passphrase(self) -> bool:
|
||||
if self.existing_private_key_bytes is None:
|
||||
raise AssertionError("existing_private_key_bytes not set")
|
||||
try:
|
||||
format = identify_private_key_format(self.existing_private_key_bytes)
|
||||
if format == "raw":
|
||||
@@ -475,7 +518,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
# provided.
|
||||
return self.passphrase is None
|
||||
else:
|
||||
return (
|
||||
return bool(
|
||||
cryptography.hazmat.primitives.serialization.load_pem_private_key(
|
||||
self.existing_private_key_bytes,
|
||||
None if self.passphrase is None else to_bytes(self.passphrase),
|
||||
@@ -484,7 +527,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_size_and_type(self):
|
||||
def _check_size_and_type(self) -> bool:
|
||||
if isinstance(
|
||||
self.existing_private_key,
|
||||
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
|
||||
@@ -527,11 +570,15 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
return False
|
||||
if self.curve not in self.curves:
|
||||
return False
|
||||
return self.curves[self.curve]["verify"](self.existing_private_key)
|
||||
return self.curves[self.curve].verify(
|
||||
self.existing_private_key, module=self.module
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def _check_format(self):
|
||||
def _check_format(self) -> bool:
|
||||
if self.existing_private_key_bytes is None:
|
||||
raise AssertionError("existing_private_key_bytes not set")
|
||||
if self.format == "auto_ignore":
|
||||
return True
|
||||
try:
|
||||
@@ -541,14 +588,14 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
|
||||
return False
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyCryptographyBackend(module)
|
||||
|
||||
|
||||
def get_privatekey_argument_spec():
|
||||
def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
size=dict(type="int", default=4096),
|
||||
@@ -607,6 +654,6 @@ def get_privatekey_argument_spec():
|
||||
),
|
||||
),
|
||||
required_if=[
|
||||
["type", "ECC", ["curve"]],
|
||||
("type", "ECC", ["curve"]),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
@@ -27,6 +28,13 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
from ansible_collections.community.crypto.plugins.module_utils.io import load_file
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -58,42 +66,48 @@ class PrivateKeyError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.src_path = module.params["src_path"]
|
||||
self.src_content = module.params["src_content"]
|
||||
self.src_passphrase = module.params["src_passphrase"]
|
||||
self.format = module.params["format"]
|
||||
self.dest_passphrase = module.params["dest_passphrase"]
|
||||
self.src_path: str | None = module.params["src_path"]
|
||||
self.src_content: str | None = module.params["src_content"]
|
||||
self.src_passphrase: str | None = module.params["src_passphrase"]
|
||||
self.format: t.Literal["pkcs1", "pkcs8", "raw"] = module.params["format"]
|
||||
self.dest_passphrase: str | None = module.params["dest_passphrase"]
|
||||
|
||||
self.src_private_key = None
|
||||
self.src_private_key: PrivateKeyTypes | None = None
|
||||
if self.src_path is not None:
|
||||
self.src_private_key_bytes = load_file(self.src_path, module)
|
||||
else:
|
||||
if self.src_content is None:
|
||||
raise AssertionError("src_content is None")
|
||||
self.src_private_key_bytes = self.src_content.encode("utf-8")
|
||||
|
||||
self.dest_private_key = None
|
||||
self.dest_private_key_bytes = None
|
||||
self.dest_private_key: PrivateKeyTypes | None = None
|
||||
self.dest_private_key_bytes: bytes | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.src_private_key in output format."""
|
||||
pass
|
||||
|
||||
def set_existing_destination(self, privatekey_bytes):
|
||||
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
|
||||
"""Set existing private key bytes. None indicates that the key does not exist."""
|
||||
self.dest_private_key_bytes = privatekey_bytes
|
||||
|
||||
def has_existing_destination(self):
|
||||
def has_existing_destination(self) -> bool:
|
||||
"""Query whether an existing private key is/has been there."""
|
||||
return self.dest_private_key_bytes is not None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_private_key(self, data, passphrase, current_hint=None):
|
||||
def _load_private_key(
|
||||
self,
|
||||
data: bytes,
|
||||
passphrase: str | None,
|
||||
current_hint: PrivateKeyTypes | None = None,
|
||||
) -> tuple[str, PrivateKeyTypes]:
|
||||
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
|
||||
pass
|
||||
|
||||
def needs_conversion(self):
|
||||
def needs_conversion(self) -> bool:
|
||||
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
|
||||
dummy, self.src_private_key = self._load_private_key(
|
||||
self.src_private_key_bytes, self.src_passphrase
|
||||
@@ -101,6 +115,7 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
|
||||
if not self.has_existing_destination():
|
||||
return True
|
||||
assert self.dest_private_key_bytes is not None
|
||||
|
||||
try:
|
||||
format, self.dest_private_key = self._load_private_key(
|
||||
@@ -115,18 +130,20 @@ class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
|
||||
self.dest_private_key, self.src_private_key
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
return {}
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
|
||||
|
||||
def get_private_key_data(self):
|
||||
def get_private_key_data(self) -> bytes:
|
||||
"""Return bytes for self.src_private_key in output format"""
|
||||
if self.src_private_key is None:
|
||||
raise AssertionError("src_private_key not set")
|
||||
# Select export format and encoding
|
||||
try:
|
||||
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
|
||||
@@ -152,9 +169,9 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
)
|
||||
|
||||
# Select key encryption
|
||||
encryption_algorithm = (
|
||||
cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
)
|
||||
encryption_algorithm: (
|
||||
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
|
||||
) = cryptography.hazmat.primitives.serialization.NoEncryption()
|
||||
if self.dest_passphrase:
|
||||
encryption_algorithm = (
|
||||
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
|
||||
@@ -179,7 +196,12 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
exception=traceback.format_exc(),
|
||||
)
|
||||
|
||||
def _load_private_key(self, data, passphrase, current_hint=None):
|
||||
def _load_private_key(
|
||||
self,
|
||||
data: bytes,
|
||||
passphrase: str | None,
|
||||
current_hint: PrivateKeyTypes | None = None,
|
||||
) -> tuple[str, PrivateKeyTypes]:
|
||||
try:
|
||||
# Interpret bytes depending on format.
|
||||
format = identify_private_key_format(data)
|
||||
@@ -247,14 +269,14 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
|
||||
raise PrivateKeyError(e)
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PrivateKeyConvertCryptographyBackend(module)
|
||||
|
||||
|
||||
def get_privatekey_argument_spec():
|
||||
def get_privatekey_argument_spec() -> ArgumentSpec:
|
||||
return ArgumentSpec(
|
||||
argument_spec=dict(
|
||||
src_path=dict(type="path"),
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -29,6 +30,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -40,38 +53,49 @@ except ImportError:
|
||||
SIGNATURE_TEST_DATA = b"1234"
|
||||
|
||||
|
||||
def _get_cryptography_private_key_info(key, need_private_key_data=False):
|
||||
def _get_cryptography_private_key_info(
|
||||
key: PrivateKeyTypes, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
|
||||
key_private_data = dict()
|
||||
key_private_data: dict[str, t.Any] = {}
|
||||
if need_private_key_data:
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
|
||||
private_numbers = key.private_numbers()
|
||||
key_private_data["p"] = private_numbers.p
|
||||
key_private_data["q"] = private_numbers.q
|
||||
key_private_data["exponent"] = private_numbers.d
|
||||
rsa_private_numbers = key.private_numbers()
|
||||
key_private_data["p"] = rsa_private_numbers.p
|
||||
key_private_data["q"] = rsa_private_numbers.q
|
||||
key_private_data["exponent"] = rsa_private_numbers.d
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey
|
||||
):
|
||||
private_numbers = key.private_numbers()
|
||||
key_private_data["x"] = private_numbers.x
|
||||
dsa_private_numbers = key.private_numbers()
|
||||
key_private_data["x"] = dsa_private_numbers.x
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
|
||||
):
|
||||
private_numbers = key.private_numbers()
|
||||
key_private_data["multiplier"] = private_numbers.private_value
|
||||
ecc_private_numbers = key.private_numbers()
|
||||
key_private_data["multiplier"] = ecc_private_numbers.private_value
|
||||
return key_type, key_public_data, key_private_data
|
||||
|
||||
|
||||
def _check_dsa_consistency(key_public_data, key_private_data):
|
||||
def _check_dsa_consistency(
|
||||
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
# Get parameters
|
||||
p = key_public_data.get("p")
|
||||
q = key_public_data.get("q")
|
||||
g = key_public_data.get("g")
|
||||
y = key_public_data.get("y")
|
||||
x = key_private_data.get("x")
|
||||
for v in (p, q, g, y, x):
|
||||
if v is None:
|
||||
return None
|
||||
p: int | None = key_public_data.get("p")
|
||||
if p is None:
|
||||
return None
|
||||
q: int | None = key_public_data.get("q")
|
||||
if q is None:
|
||||
return None
|
||||
g: int | None = key_public_data.get("g")
|
||||
if g is None:
|
||||
return None
|
||||
y: int | None = key_public_data.get("y")
|
||||
if y is None:
|
||||
return None
|
||||
x: int | None = key_private_data.get("x")
|
||||
if x is None:
|
||||
return None
|
||||
# Make sure that g is not 0, 1 or -1 in Z/pZ
|
||||
if g < 2 or g >= p - 1:
|
||||
return False
|
||||
@@ -94,13 +118,16 @@ def _check_dsa_consistency(key_public_data, key_private_data):
|
||||
|
||||
|
||||
def _is_cryptography_key_consistent(
|
||||
key, key_public_data, key_private_data, warn_func=None
|
||||
):
|
||||
key: PrivateKeyTypes,
|
||||
key_public_data: dict[str, t.Any],
|
||||
key_private_data: dict[str, t.Any],
|
||||
warn_func: t.Callable[[str], None] | None = None,
|
||||
) -> bool | None:
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
|
||||
# key._backend was removed in cryptography 42.0.0
|
||||
backend = getattr(key, "_backend", None)
|
||||
if backend is not None:
|
||||
return bool(backend._lib.RSA_check_key(key._rsa_cdata))
|
||||
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
|
||||
result = _check_dsa_consistency(key_public_data, key_private_data)
|
||||
if result is not None:
|
||||
@@ -145,9 +172,9 @@ def _is_cryptography_key_consistent(
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
|
||||
has_simple_sign_function = True
|
||||
if has_simple_sign_function:
|
||||
signature = key.sign(SIGNATURE_TEST_DATA)
|
||||
signature = key.sign(SIGNATURE_TEST_DATA) # type: ignore
|
||||
try:
|
||||
key.public_key().verify(signature, SIGNATURE_TEST_DATA)
|
||||
key.public_key().verify(signature, SIGNATURE_TEST_DATA) # type: ignore
|
||||
return True
|
||||
except cryptography.exceptions.InvalidSignature:
|
||||
return False
|
||||
@@ -158,14 +185,14 @@ def _is_cryptography_key_consistent(
|
||||
|
||||
|
||||
class PrivateKeyConsistencyError(OpenSSLObjectError):
|
||||
def __init__(self, msg, result):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
super(PrivateKeyConsistencyError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
|
||||
|
||||
class PrivateKeyParseError(OpenSSLObjectError):
|
||||
def __init__(self, msg, result):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
super(PrivateKeyParseError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
@@ -174,13 +201,12 @@ class PrivateKeyParseError(OpenSSLObjectError):
|
||||
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
content,
|
||||
passphrase=None,
|
||||
return_private_key_data=False,
|
||||
check_consistency=False,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
check_consistency: bool = False,
|
||||
):
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.passphrase = passphrase
|
||||
@@ -188,22 +214,26 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
self.check_consistency = check_consistency
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_info(self, need_private_key_data=False):
|
||||
def _get_key_info(
|
||||
self, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_key_consistent(self, key_public_data, key_private_data):
|
||||
def _is_key_consistent(
|
||||
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False):
|
||||
result = dict(
|
||||
can_parse_key=False,
|
||||
key_is_consistent=None,
|
||||
)
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {
|
||||
"can_parse_key": False,
|
||||
"key_is_consistent": None,
|
||||
}
|
||||
priv_key_detail = self.content
|
||||
try:
|
||||
self.key = load_privatekey(
|
||||
@@ -252,35 +282,39 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
|
||||
"""Validate the supplied private key, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content, **kwargs):
|
||||
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
|
||||
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
|
||||
module, content, **kwargs
|
||||
)
|
||||
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
return self.key.public_key().public_bytes(
|
||||
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_key_info(self, need_private_key_data=False):
|
||||
def _get_key_info(
|
||||
self, need_private_key_data: bool = False
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return _get_cryptography_private_key_info(
|
||||
self.key, need_private_key_data=need_private_key_data
|
||||
)
|
||||
|
||||
def _is_key_consistent(self, key_public_data, key_private_data):
|
||||
def _is_key_consistent(
|
||||
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
|
||||
) -> bool | None:
|
||||
return _is_cryptography_key_consistent(
|
||||
self.key, key_public_data, key_private_data, warn_func=self.module.warn
|
||||
)
|
||||
|
||||
|
||||
def get_privatekey_info(
|
||||
module,
|
||||
content,
|
||||
passphrase=None,
|
||||
return_private_key_data=False,
|
||||
prefer_one_fingerprint=False,
|
||||
):
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = PrivateKeyInfoRetrievalCryptography(
|
||||
module,
|
||||
content,
|
||||
@@ -291,12 +325,12 @@ def get_privatekey_info(
|
||||
|
||||
|
||||
def select_backend(
|
||||
module,
|
||||
content,
|
||||
passphrase=None,
|
||||
return_private_key_data=False,
|
||||
check_consistency=False,
|
||||
):
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes,
|
||||
passphrase: str | None = None,
|
||||
return_private_key_data: bool = False,
|
||||
check_consistency: bool = False,
|
||||
) -> PrivateKeyInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -19,6 +20,18 @@ from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PublicKeyTypes,
|
||||
)
|
||||
|
||||
from ....plugin_utils.action_module import AnsibleActionModule
|
||||
from ....plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
try:
|
||||
@@ -32,23 +45,25 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _get_cryptography_public_key_info(key):
|
||||
key_public_data = dict()
|
||||
def _get_cryptography_public_key_info(
|
||||
key: PublicKeyTypes,
|
||||
) -> tuple[str, dict[str, t.Any]]:
|
||||
key_public_data: dict[str, t.Any] = {}
|
||||
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
|
||||
key_type = "RSA"
|
||||
public_numbers = key.public_numbers()
|
||||
rsa_public_numbers = key.public_numbers()
|
||||
key_public_data["size"] = key.key_size
|
||||
key_public_data["modulus"] = public_numbers.n
|
||||
key_public_data["exponent"] = public_numbers.e
|
||||
key_public_data["modulus"] = rsa_public_numbers.n
|
||||
key_public_data["exponent"] = rsa_public_numbers.e
|
||||
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
|
||||
key_type = "DSA"
|
||||
parameter_numbers = key.parameters().parameter_numbers()
|
||||
public_numbers = key.public_numbers()
|
||||
dsa_parameter_numbers = key.parameters().parameter_numbers()
|
||||
dsa_public_numbers = key.public_numbers()
|
||||
key_public_data["size"] = key.key_size
|
||||
key_public_data["p"] = parameter_numbers.p
|
||||
key_public_data["q"] = parameter_numbers.q
|
||||
key_public_data["g"] = parameter_numbers.g
|
||||
key_public_data["y"] = public_numbers.y
|
||||
key_public_data["p"] = dsa_parameter_numbers.p
|
||||
key_public_data["q"] = dsa_parameter_numbers.q
|
||||
key_public_data["g"] = dsa_parameter_numbers.g
|
||||
key_public_data["y"] = dsa_public_numbers.y
|
||||
elif isinstance(
|
||||
key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey
|
||||
):
|
||||
@@ -67,10 +82,10 @@ def _get_cryptography_public_key_info(key):
|
||||
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey
|
||||
):
|
||||
key_type = "ECC"
|
||||
public_numbers = key.public_numbers()
|
||||
ecc_public_numbers = key.public_numbers()
|
||||
key_public_data["curve"] = key.curve.name
|
||||
key_public_data["x"] = public_numbers.x
|
||||
key_public_data["y"] = public_numbers.y
|
||||
key_public_data["x"] = ecc_public_numbers.x
|
||||
key_public_data["y"] = ecc_public_numbers.y
|
||||
key_public_data["exponent_size"] = key.curve.key_size
|
||||
else:
|
||||
key_type = f"unknown ({type(key)})"
|
||||
@@ -78,29 +93,34 @@ def _get_cryptography_public_key_info(key):
|
||||
|
||||
|
||||
class PublicKeyParseError(OpenSSLObjectError):
|
||||
def __init__(self, msg, result):
|
||||
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
|
||||
super(PublicKeyParseError, self).__init__(msg)
|
||||
self.error_message = msg
|
||||
self.result = result
|
||||
|
||||
|
||||
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, content=None, key=None):
|
||||
def __init__(
|
||||
self,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> None:
|
||||
# content must be a bytes string
|
||||
self.module = module
|
||||
self.content = content
|
||||
self.key = key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_key_info(self):
|
||||
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
|
||||
pass
|
||||
|
||||
def get_info(self, prefer_one_fingerprint=False):
|
||||
result = dict()
|
||||
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
if self.key is None:
|
||||
try:
|
||||
self.key = load_publickey(content=self.content)
|
||||
@@ -123,27 +143,45 @@ class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
|
||||
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
|
||||
"""Validate the supplied public key, using the cryptography backend"""
|
||||
|
||||
def __init__(self, module, content=None, key=None):
|
||||
def __init__(
|
||||
self,
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> None:
|
||||
super(PublicKeyInfoRetrievalCryptography, self).__init__(
|
||||
module, content=content, key=key
|
||||
)
|
||||
|
||||
def _get_public_key(self, binary):
|
||||
def _get_public_key(self, binary: bool) -> bytes:
|
||||
if self.key is None:
|
||||
raise AssertionError("key must be set")
|
||||
return self.key.public_bytes(
|
||||
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def _get_key_info(self):
|
||||
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
|
||||
if self.key is None:
|
||||
raise AssertionError("key must be set")
|
||||
return _get_cryptography_public_key_info(self.key)
|
||||
|
||||
|
||||
def get_publickey_info(module, content=None, key=None, prefer_one_fingerprint=False):
|
||||
def get_publickey_info(
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
prefer_one_fingerprint: bool = False,
|
||||
) -> dict[str, t.Any]:
|
||||
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
|
||||
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
|
||||
|
||||
|
||||
def select_backend(module, content=None, key=None):
|
||||
def select_backend(
|
||||
module: GeneralAnsibleModule,
|
||||
content: bytes | None = None,
|
||||
key: PublicKeyTypes | None = None,
|
||||
) -> PublicKeyInfoRetrieval:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
@@ -8,3 +8,6 @@ from __future__ import annotations
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import ( # noqa: F401, pylint: disable=unused-import
|
||||
parse_openssh_version,
|
||||
)
|
||||
|
||||
|
||||
# TODO: delete!
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
|
||||
PEM_START = "-----BEGIN "
|
||||
PEM_END_START = "-----END "
|
||||
@@ -12,7 +14,7 @@ PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
|
||||
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
|
||||
|
||||
|
||||
def identify_pem_format(content, encoding="utf-8"):
|
||||
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
|
||||
"""Given the contents of a binary file, tests whether this could be a PEM file."""
|
||||
try:
|
||||
first_pem = extract_first_pem(content.decode(encoding))
|
||||
@@ -30,7 +32,9 @@ def identify_pem_format(content, encoding="utf-8"):
|
||||
return False
|
||||
|
||||
|
||||
def identify_private_key_format(content, encoding="utf-8"):
|
||||
def identify_private_key_format(
|
||||
content: bytes, encoding: str = "utf-8"
|
||||
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
|
||||
"""Given the contents of a private key file, identifies its format."""
|
||||
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
|
||||
# (PEM_read_bio_PrivateKey)
|
||||
@@ -59,12 +63,12 @@ def identify_private_key_format(content, encoding="utf-8"):
|
||||
return "raw"
|
||||
|
||||
|
||||
def split_pem_list(text, keep_inbetween=False):
|
||||
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
|
||||
"""
|
||||
Split concatenated PEM objects into a list of strings, where each is one PEM object.
|
||||
"""
|
||||
result = []
|
||||
current = [] if keep_inbetween else None
|
||||
current: list[str] | None = [] if keep_inbetween else None
|
||||
for line in text.splitlines(True):
|
||||
if line.strip():
|
||||
if not keep_inbetween and line.startswith("-----BEGIN "):
|
||||
@@ -77,7 +81,7 @@ def split_pem_list(text, keep_inbetween=False):
|
||||
return result
|
||||
|
||||
|
||||
def extract_first_pem(text):
|
||||
def extract_first_pem(text: str) -> str | None:
|
||||
"""
|
||||
Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none.
|
||||
"""
|
||||
@@ -87,7 +91,7 @@ def extract_first_pem(text):
|
||||
return all_pems[0]
|
||||
|
||||
|
||||
def _extract_type(line, start=PEM_START):
|
||||
def _extract_type(line: str, start: str = PEM_START) -> str | None:
|
||||
if not line.startswith(start):
|
||||
return None
|
||||
if not line.endswith(PEM_END):
|
||||
@@ -95,7 +99,7 @@ def _extract_type(line, start=PEM_START):
|
||||
return line[len(start) : -len(PEM_END)]
|
||||
|
||||
|
||||
def extract_pem(content, strict=False):
|
||||
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
|
||||
lines = content.splitlines()
|
||||
if len(lines) < 3:
|
||||
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
|
||||
@@ -117,5 +121,4 @@ def extract_pem(content, strict=False):
|
||||
raise ValueError(
|
||||
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
|
||||
)
|
||||
content = lines[1:-1]
|
||||
return header_type, "".join(content)
|
||||
return header_type, "".join(lines[1:-1])
|
||||
|
||||
@@ -8,8 +8,13 @@ import abc
|
||||
import errno
|
||||
import hashlib
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
|
||||
is_potential_certificate_issuer_private_key,
|
||||
is_potential_certificate_private_key,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import (
|
||||
identify_pem_format,
|
||||
)
|
||||
@@ -34,6 +39,17 @@ except ImportError:
|
||||
from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
PublicKeyTypes,
|
||||
)
|
||||
|
||||
from .cryptography_support import CertificatePrivateKeyTypes
|
||||
|
||||
|
||||
# This list of preferred fingerprints is used when prefer_one=True is supplied to the
|
||||
# fingerprinting methods.
|
||||
PREFERRED_FINGERPRINTS = (
|
||||
@@ -48,18 +64,12 @@ PREFERRED_FINGERPRINTS = (
|
||||
)
|
||||
|
||||
|
||||
def get_fingerprint_of_bytes(source, prefer_one=False):
|
||||
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the given bytes."""
|
||||
|
||||
fingerprint = {}
|
||||
|
||||
try:
|
||||
algorithms = hashlib.algorithms
|
||||
except AttributeError:
|
||||
try:
|
||||
algorithms = hashlib.algorithms_guaranteed
|
||||
except AttributeError:
|
||||
return None
|
||||
algorithms: t.Iterable[str] = hashlib.algorithms_guaranteed
|
||||
|
||||
if prefer_one:
|
||||
# Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning
|
||||
@@ -97,7 +107,9 @@ def get_fingerprint_of_bytes(source, prefer_one=False):
|
||||
return fingerprint
|
||||
|
||||
|
||||
def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
|
||||
def get_fingerprint_of_privatekey(
|
||||
privatekey: PrivateKeyTypes, prefer_one: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the public key."""
|
||||
|
||||
publickey = privatekey.public_key().public_bytes(
|
||||
@@ -107,11 +119,16 @@ def get_fingerprint_of_privatekey(privatekey, prefer_one=False):
|
||||
return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
|
||||
|
||||
|
||||
def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
|
||||
def get_fingerprint(
|
||||
path: os.PathLike | str | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
content: bytes | None = None,
|
||||
prefer_one: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Generate the fingerprint of the public key."""
|
||||
|
||||
privatekey = load_privatekey(
|
||||
path,
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
content=content,
|
||||
check_passphrase=False,
|
||||
@@ -121,11 +138,11 @@ def get_fingerprint(path, passphrase=None, content=None, prefer_one=False):
|
||||
|
||||
|
||||
def load_privatekey(
|
||||
path,
|
||||
passphrase=None,
|
||||
check_passphrase=True,
|
||||
content=None,
|
||||
):
|
||||
path: os.PathLike | str | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
content: bytes | None = None,
|
||||
) -> PrivateKeyTypes:
|
||||
"""Load the specified OpenSSL private key.
|
||||
|
||||
The content can also be specified via content; in that case,
|
||||
@@ -134,6 +151,8 @@ def load_privatekey(
|
||||
|
||||
try:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
with open(path, "rb") as b_priv_key_fh:
|
||||
priv_key_detail = b_priv_key_fh.read()
|
||||
else:
|
||||
@@ -154,7 +173,55 @@ def load_privatekey(
|
||||
raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key")
|
||||
|
||||
|
||||
def load_publickey(path=None, content=None):
|
||||
def load_certificate_privatekey(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
) -> CertificatePrivateKeyTypes:
|
||||
"""
|
||||
Load the specified OpenSSL private key that can be used as a private key for certificates.
|
||||
"""
|
||||
private_key = load_privatekey(
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
check_passphrase=check_passphrase,
|
||||
content=content,
|
||||
)
|
||||
if not is_potential_certificate_private_key(private_key):
|
||||
raise OpenSSLObjectError(
|
||||
f"Key of type {type(private_key)} not supported for certificates"
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def load_certificate_issuer_privatekey(
|
||||
*,
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
passphrase: str | bytes | None = None,
|
||||
check_passphrase: bool = True,
|
||||
) -> CertificateIssuerPrivateKeyTypes:
|
||||
"""
|
||||
Load the specified OpenSSL private key that can be used for issuing certificates.
|
||||
"""
|
||||
private_key = load_privatekey(
|
||||
path=path,
|
||||
passphrase=passphrase,
|
||||
check_passphrase=check_passphrase,
|
||||
content=content,
|
||||
)
|
||||
if not is_potential_certificate_issuer_private_key(private_key):
|
||||
raise OpenSSLObjectError(
|
||||
f"Key of type {type(private_key)} not supported for issuing certificates"
|
||||
)
|
||||
return private_key
|
||||
|
||||
|
||||
def load_publickey(
|
||||
path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
) -> PublicKeyTypes:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
@@ -170,11 +237,17 @@ def load_publickey(path=None, content=None):
|
||||
raise OpenSSLObjectError(f"Error while deserializing key: {e}")
|
||||
|
||||
|
||||
def load_certificate(path, content=None, der_support_enabled=False):
|
||||
def load_certificate(
|
||||
path: os.PathLike | str | None = None,
|
||||
content: bytes | None = None,
|
||||
der_support_enabled: bool = False,
|
||||
) -> x509.Certificate:
|
||||
"""Load the specified certificate."""
|
||||
|
||||
try:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
with open(path, "rb") as cert_fh:
|
||||
cert_content = cert_fh.read()
|
||||
else:
|
||||
@@ -193,10 +266,14 @@ def load_certificate(path, content=None, der_support_enabled=False):
|
||||
raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}")
|
||||
|
||||
|
||||
def load_certificate_request(path, content=None):
|
||||
def load_certificate_request(
|
||||
path: os.PathLike | str | None = None, content: bytes | None = None
|
||||
) -> x509.CertificateSigningRequest:
|
||||
"""Load the specified certificate signing request."""
|
||||
try:
|
||||
if content is None:
|
||||
if path is None:
|
||||
raise OpenSSLObjectError("Must provide either path or content")
|
||||
with open(path, "rb") as csr_fh:
|
||||
csr_content = csr_fh.read()
|
||||
else:
|
||||
@@ -209,45 +286,44 @@ def load_certificate_request(path, content=None):
|
||||
raise OpenSSLObjectError(exc)
|
||||
|
||||
|
||||
def parse_name_field(input_dict, name_field_name=None):
|
||||
def parse_name_field(
|
||||
input_dict: dict[str, list[str | bytes] | str | bytes],
|
||||
name_field_name: str | None = None,
|
||||
) -> list[tuple[str, str | bytes]]:
|
||||
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
|
||||
error_str = "{key}" if name_field_name is None else "{key} in {name}"
|
||||
|
||||
def error_str(key: str) -> str:
|
||||
if name_field_name is None:
|
||||
return f"{key}"
|
||||
return f"{key} in {name_field_name}"
|
||||
|
||||
result = []
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, list):
|
||||
for entry in value:
|
||||
if not isinstance(entry, (str, bytes)):
|
||||
raise TypeError(
|
||||
f"Values {error_str} must be strings".format(
|
||||
key=key, name=name_field_name
|
||||
)
|
||||
)
|
||||
raise TypeError(f"Values {error_str(key)} must be strings")
|
||||
if not entry:
|
||||
raise ValueError(
|
||||
f"Values for {error_str} must not be empty strings".format(
|
||||
key=key, name=name_field_name
|
||||
)
|
||||
f"Values for {error_str(key)} must not be empty strings"
|
||||
)
|
||||
result.append((key, entry))
|
||||
elif isinstance(value, (str, bytes)):
|
||||
if not value:
|
||||
raise ValueError(
|
||||
f"Value for {error_str} must not be an empty string".format(
|
||||
key=key, name=name_field_name
|
||||
)
|
||||
f"Value for {error_str(key)} must not be an empty string"
|
||||
)
|
||||
result.append((key, value))
|
||||
else:
|
||||
raise TypeError(
|
||||
(
|
||||
f"Value for {error_str} must be either a string or a list of strings"
|
||||
).format(key=key, name=name_field_name)
|
||||
f"Value for {error_str(key)} must be either a string or a list of strings"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def parse_ordered_name_field(input_list, name_field_name):
|
||||
def parse_ordered_name_field(
|
||||
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
|
||||
) -> list[tuple[str, str | bytes]]:
|
||||
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
|
||||
|
||||
result = []
|
||||
@@ -265,24 +341,39 @@ def parse_ordered_name_field(input_list, name_field_name):
|
||||
return result
|
||||
|
||||
|
||||
def select_message_digest(digest_string):
|
||||
digest = None
|
||||
@t.overload
|
||||
def select_message_digest(
|
||||
digest_string: t.Literal["sha256", "sha384", "sha512", "sha1", "md5"],
|
||||
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def select_message_digest(
|
||||
digest_string: str,
|
||||
) -> (
|
||||
hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None
|
||||
): ...
|
||||
|
||||
|
||||
def select_message_digest(
|
||||
digest_string: str,
|
||||
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None:
|
||||
if digest_string == "sha256":
|
||||
digest = hashes.SHA256()
|
||||
elif digest_string == "sha384":
|
||||
digest = hashes.SHA384()
|
||||
elif digest_string == "sha512":
|
||||
digest = hashes.SHA512()
|
||||
elif digest_string == "sha1":
|
||||
digest = hashes.SHA1()
|
||||
elif digest_string == "md5":
|
||||
digest = hashes.MD5()
|
||||
return digest
|
||||
return hashes.SHA256()
|
||||
if digest_string == "sha384":
|
||||
return hashes.SHA384()
|
||||
if digest_string == "sha512":
|
||||
return hashes.SHA512()
|
||||
if digest_string == "sha1":
|
||||
return hashes.SHA1()
|
||||
if digest_string == "md5":
|
||||
return hashes.MD5()
|
||||
return None
|
||||
|
||||
|
||||
class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, path, state, force, check_mode):
|
||||
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
|
||||
self.path = path
|
||||
self.state = state
|
||||
self.force = force
|
||||
@@ -290,13 +381,13 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
self.changed = False
|
||||
self.check_mode = check_mode
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
def _check_state():
|
||||
def _check_state() -> bool:
|
||||
return os.path.exists(self.path)
|
||||
|
||||
def _check_perms(module):
|
||||
def _check_perms(module: AnsibleModule) -> bool:
|
||||
file_args = module.load_file_common_arguments(module.params)
|
||||
if module.check_file_absent_if_check_mode(file_args["path"]):
|
||||
return False
|
||||
@@ -308,18 +399,14 @@ class OpenSSLObject(metaclass=abc.ABCMeta):
|
||||
return _check_state() and _check_perms(module)
|
||||
|
||||
@abc.abstractmethod
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the resource."""
|
||||
|
||||
pass
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
"""Remove the resource from the filesystem."""
|
||||
if self.check_mode:
|
||||
if os.path.exists(self.path):
|
||||
|
||||
@@ -11,6 +11,7 @@ Must be kept in sync with plugins/doc_fragments/cryptography_dep.py.
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
@@ -18,19 +19,29 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
)
|
||||
|
||||
|
||||
_CRYPTOGRAPHY_IMP_ERR = None
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
from ..plugin_utils.action_module import AnsibleActionModule
|
||||
from ..plugin_utils.filter_module import FilterModuleMock
|
||||
|
||||
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
|
||||
|
||||
|
||||
_CRYPTOGRAPHY_IMP_ERR: str | None = None
|
||||
_CRYPTOGRAPHY_FILE: str | None = None
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography import x509 # noqa: F401, pylint: disable=unused-import
|
||||
|
||||
_CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
|
||||
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
|
||||
_CRYPTOGRAPHY_FILE = cryptography.__file__
|
||||
except ImportError:
|
||||
_CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
|
||||
_CRYPTOGRAPHY_FOUND = False
|
||||
_CRYPTOGRAPHY_FILE = None
|
||||
CRYPTOGRAPHY_FOUND = False
|
||||
CRYPTOGRAPHY_VERSION = LooseVersion("0.0")
|
||||
else:
|
||||
_CRYPTOGRAPHY_FOUND = True
|
||||
CRYPTOGRAPHY_FOUND = True
|
||||
|
||||
|
||||
# Corresponds to the community.crypto.cryptography_dep.minimum doc fragment
|
||||
@@ -38,25 +49,27 @@ COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION = "3.3"
|
||||
|
||||
|
||||
def assert_required_cryptography_version(
|
||||
module,
|
||||
module: GeneralAnsibleModule,
|
||||
*,
|
||||
minimum_cryptography_version: str = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
) -> None:
|
||||
if not _CRYPTOGRAPHY_FOUND:
|
||||
if not CRYPTOGRAPHY_FOUND:
|
||||
module.fail_json(
|
||||
msg=missing_required_lib(f"cryptography >= {minimum_cryptography_version}"),
|
||||
exception=_CRYPTOGRAPHY_IMP_ERR,
|
||||
)
|
||||
if _CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
|
||||
if CRYPTOGRAPHY_VERSION < LooseVersion(minimum_cryptography_version):
|
||||
module.fail_json(
|
||||
msg=(
|
||||
f"Cannot detect the required Python library cryptography (>= {minimum_cryptography_version})."
|
||||
f" Only found a too old version ({_CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
|
||||
f" Only found a too old version ({CRYPTOGRAPHY_VERSION}) at {_CRYPTOGRAPHY_FILE}."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION",
|
||||
"CRYPTOGRAPHY_FOUND",
|
||||
"CRYPTOGRAPHY_VERSION",
|
||||
"assert_required_cryptography_version",
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import typing as t
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import urlencode
|
||||
|
||||
@@ -34,7 +35,7 @@ else:
|
||||
valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$")
|
||||
|
||||
|
||||
def ecs_client_argument_spec():
|
||||
def ecs_client_argument_spec() -> dict[str, t.Any]:
|
||||
return dict(
|
||||
entrust_api_user=dict(type="str", required=True),
|
||||
entrust_api_key=dict(type="str", required=True, no_log=True),
|
||||
@@ -50,19 +51,17 @@ def ecs_client_argument_spec():
|
||||
class SessionConfigurationException(Exception):
|
||||
"""Raised if we cannot configure a session with the API"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RestOperationException(Exception):
|
||||
"""Encapsulate a REST API error"""
|
||||
|
||||
def __init__(self, error):
|
||||
def __init__(self, error: dict[str, t.Any]) -> None:
|
||||
self.status = to_native(error.get("status", None))
|
||||
self.errors = [to_native(err.get("message")) for err in error.get("errors", {})]
|
||||
self.message = " ".join(self.errors)
|
||||
|
||||
|
||||
def generate_docstring(operation_spec):
|
||||
def generate_docstring(operation_spec: dict[str, t.Any]) -> str:
|
||||
"""Generate a docstring for an operation defined in operation_spec (swagger)"""
|
||||
# Description of the operation
|
||||
docs = operation_spec.get("description", "No Description")
|
||||
|
||||
@@ -14,7 +14,9 @@ class GPGError(Exception):
|
||||
|
||||
class GPGRunner(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def run_command(self, command, check_rc=True, data=None):
|
||||
def run_command(
|
||||
self, command: list[str], check_rc: bool = True, data: bytes | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
"""
|
||||
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
|
||||
|
||||
@@ -29,7 +31,7 @@ class GPGRunner(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
|
||||
def get_fingerprint_from_stdout(stdout):
|
||||
def get_fingerprint_from_stdout(stdout: str) -> str:
|
||||
lines = stdout.splitlines(False)
|
||||
for line in lines:
|
||||
if line.startswith("fpr:"):
|
||||
@@ -42,7 +44,7 @@ def get_fingerprint_from_stdout(stdout):
|
||||
raise GPGError(f'Cannot extract fingerprint from stdout "{stdout}"')
|
||||
|
||||
|
||||
def get_fingerprint_from_file(gpg_runner, path):
|
||||
def get_fingerprint_from_file(gpg_runner: GPGRunner, path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
raise GPGError(f"{path} does not exist")
|
||||
stdout = gpg_runner.run_command(
|
||||
@@ -59,7 +61,7 @@ def get_fingerprint_from_file(gpg_runner, path):
|
||||
return get_fingerprint_from_stdout(stdout)
|
||||
|
||||
|
||||
def get_fingerprint_from_bytes(gpg_runner, content):
|
||||
def get_fingerprint_from_bytes(gpg_runner: GPGRunner, content: bytes) -> str:
|
||||
stdout = gpg_runner.run_command(
|
||||
[
|
||||
"--no-keyring",
|
||||
|
||||
@@ -7,9 +7,14 @@ from __future__ import annotations
|
||||
import errno
|
||||
import os
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
|
||||
def load_file(path, module=None):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def load_file(path: str | os.PathLike, module: AnsibleModule | None = None) -> bytes:
|
||||
"""
|
||||
Load the file as a bytes string.
|
||||
"""
|
||||
@@ -22,7 +27,11 @@ def load_file(path, module=None):
|
||||
module.fail_json(f"Error while loading {path} - {exc}")
|
||||
|
||||
|
||||
def load_file_if_exists(path, module=None, ignore_errors=False):
|
||||
def load_file_if_exists(
|
||||
path: str | os.PathLike,
|
||||
module: AnsibleModule | None = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Load the file as a bytes string. If the file does not exist, ``None`` is returned.
|
||||
|
||||
@@ -49,7 +58,12 @@ def load_file_if_exists(path, module=None, ignore_errors=False):
|
||||
module.fail_json(f"Error while loading {path} - {exc}")
|
||||
|
||||
|
||||
def write_file(module, content, default_mode=None, path=None):
|
||||
def write_file(
|
||||
module: AnsibleModule,
|
||||
content: bytes,
|
||||
default_mode: str | int | None = None,
|
||||
path: str | os.PathLike | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Writes content into destination file as securely as possible.
|
||||
Uses file arguments from module.
|
||||
|
||||
@@ -8,14 +8,31 @@ import abc
|
||||
import os
|
||||
import stat
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
|
||||
parse_openssh_version,
|
||||
)
|
||||
|
||||
|
||||
def restore_on_failure(f):
|
||||
def backup_and_restore(module, path, *args, **kwargs):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
from ..certificate import OpensshCertificateTimeParameters
|
||||
|
||||
Param = t.ParamSpec("Param")
|
||||
|
||||
|
||||
def restore_on_failure(
|
||||
f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None],
|
||||
) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]:
|
||||
def backup_and_restore(
|
||||
module: AnsibleModule, path: str | os.PathLike, *args, **kwargs
|
||||
) -> None:
|
||||
backup_file = module.backup_local(path) if os.path.exists(path) else None
|
||||
|
||||
try:
|
||||
@@ -31,12 +48,31 @@ def restore_on_failure(f):
|
||||
|
||||
|
||||
@restore_on_failure
|
||||
def safe_atomic_move(module, path, destination):
|
||||
def safe_atomic_move(
|
||||
module: AnsibleModule, path: str | os.PathLike, destination: str | os.PathLike
|
||||
) -> None:
|
||||
module.atomic_move(os.path.abspath(path), os.path.abspath(destination))
|
||||
|
||||
|
||||
def _restore_all_on_failure(f):
|
||||
def backup_and_restore(self, sources_and_destinations, *args, **kwargs):
|
||||
def _restore_all_on_failure(
|
||||
f: t.Callable[
|
||||
t.Concatenate[
|
||||
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
|
||||
],
|
||||
None,
|
||||
],
|
||||
) -> t.Callable[
|
||||
t.Concatenate[
|
||||
OpensshModule, list[tuple[str | os.PathLike, str | os.PathLike]], Param
|
||||
],
|
||||
None,
|
||||
]:
|
||||
def backup_and_restore(
|
||||
self: OpensshModule,
|
||||
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
backups = [
|
||||
(d, self.module.backup_local(d))
|
||||
for s, d in sources_and_destinations
|
||||
@@ -59,13 +95,13 @@ def _restore_all_on_failure(f):
|
||||
|
||||
|
||||
class OpensshModule(metaclass=abc.ABCMeta):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
|
||||
self.changed = False
|
||||
self.check_mode = self.module.check_mode
|
||||
self.changed: bool = False
|
||||
self.check_mode: bool = self.module.check_mode
|
||||
|
||||
def execute(self):
|
||||
def execute(self) -> t.NoReturn:
|
||||
try:
|
||||
self._execute()
|
||||
except Exception as e:
|
||||
@@ -77,11 +113,11 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
self.module.exit_json(**self.result)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
def result(self) -> dict[str, t.Any]:
|
||||
result = self._result
|
||||
|
||||
result["changed"] = self.changed
|
||||
@@ -93,31 +129,31 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def skip_if_check_mode(f):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
|
||||
def wrapper(self, *args, **kwargs) -> None:
|
||||
if not self.check_mode:
|
||||
f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return wrapper # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def trigger_change(f):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
|
||||
def wrapper(self, *args, **kwargs) -> None:
|
||||
f(self, *args, **kwargs)
|
||||
self.changed = True
|
||||
|
||||
return wrapper
|
||||
return wrapper # type: ignore
|
||||
|
||||
def _check_if_base_dir(self, path):
|
||||
def _check_if_base_dir(self, path: str | os.PathLike) -> None:
|
||||
base_dir = os.path.dirname(path) or "."
|
||||
if not os.path.isdir(base_dir):
|
||||
self.module.fail_json(
|
||||
@@ -125,16 +161,19 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
msg=f"The directory {base_dir} does not exist or the file is not a directory",
|
||||
)
|
||||
|
||||
def _get_ssh_version(self):
|
||||
def _get_ssh_version(self) -> str | None:
|
||||
ssh_bin = self.module.get_bin_path("ssh")
|
||||
if not ssh_bin:
|
||||
return ""
|
||||
return None
|
||||
return parse_openssh_version(
|
||||
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
|
||||
)
|
||||
|
||||
@_restore_all_on_failure
|
||||
def _safe_secure_move(self, sources_and_destinations):
|
||||
def _safe_secure_move(
|
||||
self,
|
||||
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
|
||||
) -> None:
|
||||
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
|
||||
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
|
||||
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
|
||||
@@ -148,7 +187,7 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
else:
|
||||
self.module.preserved_copy(source, destination)
|
||||
|
||||
def _update_permissions(self, path):
|
||||
def _update_permissions(self, path: str | os.PathLike) -> None:
|
||||
file_args = self.module.load_file_common_arguments(self.module.params)
|
||||
file_args["path"] = path
|
||||
|
||||
@@ -161,25 +200,25 @@ class OpensshModule(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class KeygenCommand:
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self._bin_path = module.get_bin_path("ssh-keygen", True)
|
||||
self._run_command = module.run_command
|
||||
|
||||
def generate_certificate(
|
||||
self,
|
||||
certificate_path,
|
||||
identifier,
|
||||
options,
|
||||
pkcs11_provider,
|
||||
principals,
|
||||
serial_number,
|
||||
signature_algorithm,
|
||||
signing_key_path,
|
||||
type,
|
||||
time_parameters,
|
||||
use_agent,
|
||||
certificate_path: str,
|
||||
identifier: str,
|
||||
options: list[str] | None,
|
||||
pkcs11_provider: str | None,
|
||||
principals: list[str] | None,
|
||||
serial_number: int | None,
|
||||
signature_algorithm: str | None,
|
||||
signing_key_path: str,
|
||||
type: t.Literal["host", "user"] | None,
|
||||
time_parameters: OpensshCertificateTimeParameters,
|
||||
use_agent: bool,
|
||||
**kwargs,
|
||||
):
|
||||
) -> tuple[int, str, str]:
|
||||
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
|
||||
|
||||
if options:
|
||||
@@ -203,7 +242,9 @@ class KeygenCommand:
|
||||
|
||||
return self._run_command(args, **kwargs)
|
||||
|
||||
def generate_keypair(self, private_key_path, size, type, comment, **kwargs):
|
||||
def generate_keypair(
|
||||
self, private_key_path: str, size: int, type: str, comment: str | None, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
args = [
|
||||
self._bin_path,
|
||||
"-q",
|
||||
@@ -224,32 +265,40 @@ class KeygenCommand:
|
||||
|
||||
return self._run_command(args, data=data, **kwargs)
|
||||
|
||||
def get_certificate_info(self, certificate_path, **kwargs):
|
||||
def get_certificate_info(
|
||||
self, certificate_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-L", "-f", certificate_path], **kwargs
|
||||
)
|
||||
|
||||
def get_matching_public_key(self, private_key_path, **kwargs):
|
||||
def get_matching_public_key(
|
||||
self, private_key_path: str, **kwargs
|
||||
) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def get_private_key(self, private_key_path, **kwargs):
|
||||
def get_private_key(self, private_key_path: str, **kwargs) -> tuple[int, str, str]:
|
||||
return self._run_command(
|
||||
[self._bin_path, "-l", "-f", private_key_path], **kwargs
|
||||
)
|
||||
|
||||
def update_comment(
|
||||
self, private_key_path, comment, force_new_format=True, **kwargs
|
||||
):
|
||||
self,
|
||||
private_key_path: str,
|
||||
comment: str,
|
||||
force_new_format: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[int, str, str]:
|
||||
if os.path.exists(private_key_path) and not os.access(
|
||||
private_key_path, os.W_OK
|
||||
):
|
||||
try:
|
||||
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
|
||||
except (IOError, OSError) as e:
|
||||
raise e(
|
||||
f"The private key at {private_key_path} is not writeable preventing a comment update"
|
||||
raise ValueError(
|
||||
f"The private key at {private_key_path} is not writeable preventing a comment update ({e})"
|
||||
)
|
||||
|
||||
command = [self._bin_path, "-q"]
|
||||
@@ -259,31 +308,36 @@ class KeygenCommand:
|
||||
return self._run_command(command, **kwargs)
|
||||
|
||||
|
||||
_PrivateKey = t.TypeVar("_PrivateKey", bound="PrivateKey")
|
||||
|
||||
|
||||
class PrivateKey:
|
||||
def __init__(self, size, key_type, fingerprint, format=""):
|
||||
def __init__(
|
||||
self, size: int, key_type: str, fingerprint: str, format: str = ""
|
||||
) -> None:
|
||||
self._size = size
|
||||
self._type = key_type
|
||||
self._fingerprint = fingerprint
|
||||
self._format = format
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
return self._size
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> str:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def fingerprint(self):
|
||||
def fingerprint(self) -> str:
|
||||
return self._fingerprint
|
||||
|
||||
@property
|
||||
def format(self):
|
||||
def format(self) -> str:
|
||||
return self._format
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
def from_string(cls: t.Type[_PrivateKey], string: str) -> _PrivateKey:
|
||||
properties = string.split()
|
||||
|
||||
return cls(
|
||||
@@ -292,7 +346,7 @@ class PrivateKey:
|
||||
fingerprint=properties[1],
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"size": self._size,
|
||||
"type": self._type,
|
||||
@@ -301,13 +355,16 @@ class PrivateKey:
|
||||
}
|
||||
|
||||
|
||||
_PublicKey = t.TypeVar("_PublicKey", bound="PublicKey")
|
||||
|
||||
|
||||
class PublicKey:
|
||||
def __init__(self, type_string, data, comment):
|
||||
def __init__(self, type_string: str, data: str, comment: str | None) -> None:
|
||||
self._type_string = type_string
|
||||
self._data = data
|
||||
self._comment = comment
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
|
||||
@@ -323,30 +380,30 @@ class PublicKey:
|
||||
]
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self._type_string} {self._data}"
|
||||
|
||||
@property
|
||||
def comment(self):
|
||||
def comment(self) -> str | None:
|
||||
return self._comment
|
||||
|
||||
@comment.setter
|
||||
def comment(self, value):
|
||||
def comment(self, value: str | None) -> None:
|
||||
self._comment = value
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> str:
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def type_string(self):
|
||||
def type_string(self) -> str:
|
||||
return self._type_string
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
def from_string(cls: t.Type[_PublicKey], string: str) -> _PublicKey:
|
||||
properties = string.strip("\n").split(" ", 2)
|
||||
|
||||
return cls(
|
||||
@@ -356,7 +413,7 @@ class PublicKey:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
def load(cls: t.Type[_PublicKey], path: str | os.PathLike) -> _PublicKey | None:
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
properties = f.read().strip(" \n").split(" ", 2)
|
||||
@@ -372,14 +429,16 @@ class PublicKey:
|
||||
comment="" if len(properties) <= 2 else properties[2],
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"comment": self._comment,
|
||||
"public_key": self._data,
|
||||
}
|
||||
|
||||
|
||||
def parse_private_key_format(path):
|
||||
def parse_private_key_format(
|
||||
path: str | os.PathLike,
|
||||
) -> t.Literal["SSH", "PKCS8", "PKCS1", ""]:
|
||||
with open(path, "r") as file:
|
||||
header = file.readline().strip()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import missing_required_lib
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
@@ -39,31 +40,43 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
PrivateKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackend, self).__init__(module)
|
||||
|
||||
self.comment = self.module.params["comment"]
|
||||
self.private_key_path = self.module.params["path"]
|
||||
self.comment: str | None = self.module.params["comment"]
|
||||
self.private_key_path: str = self.module.params["path"]
|
||||
self.public_key_path = self.private_key_path + ".pub"
|
||||
self.regenerate = (
|
||||
self.regenerate: t.Literal[
|
||||
"never", "fail", "partial_idempotence", "full_idempotence", "always"
|
||||
] = (
|
||||
self.module.params["regenerate"]
|
||||
if not self.module.params["force"]
|
||||
else "always"
|
||||
)
|
||||
self.state = self.module.params["state"]
|
||||
self.type = self.module.params["type"]
|
||||
self.state: t.Literal["present", "absent"] = self.module.params["state"]
|
||||
self.type: t.Literal["rsa", "dsa", "rsa1", "ecdsa", "ed25519"] = (
|
||||
self.module.params["type"]
|
||||
)
|
||||
|
||||
self.size = self._get_size(self.module.params["size"])
|
||||
self.size: int = self._get_size(self.module.params["size"])
|
||||
self._validate_path()
|
||||
|
||||
self.original_private_key = None
|
||||
self.original_public_key = None
|
||||
self.private_key = None
|
||||
self.public_key = None
|
||||
self.original_private_key: PrivateKey | None = None
|
||||
self.original_public_key: PublicKey | None = None
|
||||
self.private_key: PrivateKey | None = None
|
||||
self.public_key: PublicKey | None = None
|
||||
|
||||
def _get_size(self, size):
|
||||
def _get_size(self, size: int | None) -> int:
|
||||
if self.type in ("rsa", "rsa1"):
|
||||
result = 4096 if size is None else size
|
||||
if result < 1024:
|
||||
@@ -96,7 +109,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
return result
|
||||
|
||||
def _validate_path(self):
|
||||
def _validate_path(self) -> None:
|
||||
self._check_if_base_dir(self.private_key_path)
|
||||
|
||||
if os.path.isdir(self.private_key_path):
|
||||
@@ -104,7 +117,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
msg=f"{self.private_key_path} is a directory. Please specify a path to a file."
|
||||
)
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
self.original_private_key = self._load_private_key()
|
||||
self.original_public_key = self._load_public_key()
|
||||
|
||||
@@ -125,7 +138,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
if self._should_remove():
|
||||
self._remove()
|
||||
|
||||
def _load_private_key(self):
|
||||
def _load_private_key(self) -> PrivateKey | None:
|
||||
result = None
|
||||
if self._private_key_exists():
|
||||
try:
|
||||
@@ -135,14 +148,14 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
return result
|
||||
|
||||
def _private_key_exists(self):
|
||||
def _private_key_exists(self) -> bool:
|
||||
return os.path.exists(self.private_key_path)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_private_key(self):
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
pass
|
||||
|
||||
def _load_public_key(self):
|
||||
def _load_public_key(self) -> PublicKey | None:
|
||||
result = None
|
||||
if self._public_key_exists():
|
||||
try:
|
||||
@@ -151,10 +164,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
pass
|
||||
return result
|
||||
|
||||
def _public_key_exists(self):
|
||||
def _public_key_exists(self) -> bool:
|
||||
return os.path.exists(self.public_key_path)
|
||||
|
||||
def _validate_key_load(self):
|
||||
def _validate_key_load(self) -> None:
|
||||
if (
|
||||
self._private_key_exists()
|
||||
and self.regenerate in ("never", "fail", "partial_idempotence")
|
||||
@@ -167,10 +180,10 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _private_key_readable(self):
|
||||
def _private_key_readable(self) -> bool:
|
||||
pass
|
||||
|
||||
def _should_generate(self):
|
||||
def _should_generate(self) -> bool:
|
||||
if self.original_private_key is None:
|
||||
return True
|
||||
elif self.regenerate == "never":
|
||||
@@ -188,7 +201,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _private_key_valid(self):
|
||||
def _private_key_valid(self) -> bool:
|
||||
if self.original_private_key is None:
|
||||
return False
|
||||
|
||||
@@ -196,17 +209,17 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
[
|
||||
self.size == self.original_private_key.size,
|
||||
self.type == self.original_private_key.type,
|
||||
self._private_key_valid_backend(),
|
||||
self._private_key_valid_backend(self.original_private_key),
|
||||
]
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _private_key_valid_backend(self):
|
||||
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
|
||||
pass
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _generate(self):
|
||||
def _generate(self) -> None:
|
||||
temp_private_key, temp_public_key = self._generate_temp_keypair()
|
||||
|
||||
try:
|
||||
@@ -219,7 +232,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
except OSError as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _generate_temp_keypair(self):
|
||||
def _generate_temp_keypair(self) -> tuple[str, str]:
|
||||
temp_private_key = os.path.join(
|
||||
self.module.tmpdir, os.path.basename(self.private_key_path)
|
||||
)
|
||||
@@ -236,25 +249,26 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
return temp_private_key, temp_public_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _generate_keypair(self, private_key_path):
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
pass
|
||||
|
||||
def _public_key_valid(self):
|
||||
def _public_key_valid(self) -> bool:
|
||||
if self.original_public_key is None:
|
||||
return False
|
||||
|
||||
valid_public_key = self._get_public_key()
|
||||
valid_public_key.comment = self.comment
|
||||
if valid_public_key:
|
||||
valid_public_key.comment = self.comment
|
||||
|
||||
return self.original_public_key == valid_public_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_public_key(self):
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
pass
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _restore_public_key(self):
|
||||
def _restore_public_key(self) -> None:
|
||||
try:
|
||||
temp_public_key = self._create_temp_public_key(
|
||||
str(self._get_public_key()) + "\n"
|
||||
@@ -269,7 +283,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
if self.comment:
|
||||
self._update_comment()
|
||||
|
||||
def _create_temp_public_key(self, content):
|
||||
def _create_temp_public_key(self, content: str | bytes) -> str:
|
||||
temp_public_key = os.path.join(
|
||||
self.module.tmpdir, os.path.basename(self.public_key_path)
|
||||
)
|
||||
@@ -290,15 +304,15 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
return temp_public_key
|
||||
|
||||
@abc.abstractmethod
|
||||
def _update_comment(self):
|
||||
def _update_comment(self) -> None:
|
||||
pass
|
||||
|
||||
def _should_remove(self):
|
||||
def _should_remove(self) -> bool:
|
||||
return self._private_key_exists() or self._public_key_exists()
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _remove(self):
|
||||
def _remove(self) -> None:
|
||||
try:
|
||||
if self._private_key_exists():
|
||||
os.remove(self.private_key_path)
|
||||
@@ -308,7 +322,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
@property
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
private_key = self.private_key or self.original_private_key
|
||||
public_key = self.public_key or self.original_public_key
|
||||
|
||||
@@ -322,7 +336,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
}
|
||||
|
||||
@property
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
before = (
|
||||
self.original_private_key.to_dict() if self.original_private_key else {}
|
||||
)
|
||||
@@ -340,7 +354,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class KeypairBackendOpensshBin(KeypairBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendOpensshBin, self).__init__(module)
|
||||
|
||||
if self.module.params["private_key_format"] != "auto":
|
||||
@@ -350,12 +364,12 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
|
||||
self.ssh_keygen = KeygenCommand(self.module)
|
||||
|
||||
def _generate_keypair(self, private_key_path):
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
self.ssh_keygen.generate_keypair(
|
||||
private_key_path, self.size, self.type, self.comment, check_rc=True
|
||||
)
|
||||
|
||||
def _get_private_key(self):
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
rc, private_key_content, err = self.ssh_keygen.get_private_key(
|
||||
self.private_key_path, check_rc=False
|
||||
)
|
||||
@@ -363,13 +377,13 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
raise ValueError(err)
|
||||
return PrivateKey.from_string(private_key_content)
|
||||
|
||||
def _get_public_key(self):
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
public_key_content = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=True
|
||||
)[1]
|
||||
return PublicKey.from_string(public_key_content)
|
||||
|
||||
def _private_key_readable(self):
|
||||
def _private_key_readable(self) -> bool:
|
||||
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
|
||||
self.private_key_path, check_rc=False
|
||||
)
|
||||
@@ -383,7 +397,7 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def _update_comment(self):
|
||||
def _update_comment(self) -> None:
|
||||
try:
|
||||
ssh_version = self._get_ssh_version() or "7.8"
|
||||
force_new_format = (
|
||||
@@ -391,19 +405,19 @@ class KeypairBackendOpensshBin(KeypairBackend):
|
||||
)
|
||||
self.ssh_keygen.update_comment(
|
||||
self.private_key_path,
|
||||
self.comment,
|
||||
self.comment or "",
|
||||
force_new_format=force_new_format,
|
||||
check_rc=True,
|
||||
)
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _private_key_valid_backend(self):
|
||||
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class KeypairBackendCryptography(KeypairBackend):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(KeypairBackendCryptography, self).__init__(module)
|
||||
|
||||
if self.type == "rsa1":
|
||||
@@ -416,12 +430,15 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
if module.params["passphrase"]
|
||||
else None
|
||||
)
|
||||
self.private_key_format = self._get_key_format(
|
||||
module.params["private_key_format"]
|
||||
)
|
||||
key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"] = module.params[
|
||||
"private_key_format"
|
||||
]
|
||||
self.private_key_format = self._get_key_format(key_format)
|
||||
|
||||
def _get_key_format(self, key_format):
|
||||
result = "SSH"
|
||||
def _get_key_format(
|
||||
self, key_format: t.Literal["auto", "pkcs1", "pkcs8", "ssh"]
|
||||
) -> t.Literal["SSH", "PKCS1", "PKCS8"]:
|
||||
result: t.Literal["SSH", "PKCS1", "PKCS8"] = "SSH"
|
||||
|
||||
if key_format == "auto":
|
||||
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
|
||||
@@ -435,11 +452,12 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
# but still defaulted to PKCS1 format with the exception of ed25519 keys
|
||||
result = "PKCS1"
|
||||
else:
|
||||
result = key_format.upper()
|
||||
result = key_format.upper() # type: ignore
|
||||
|
||||
return result
|
||||
|
||||
def _generate_keypair(self, private_key_path):
|
||||
def _generate_keypair(self, private_key_path: str) -> None:
|
||||
assert self.type != "rsa1"
|
||||
keypair = OpensshKeypair.generate(
|
||||
keytype=self.type,
|
||||
size=self.size,
|
||||
@@ -455,7 +473,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
public_key_path = private_key_path + ".pub"
|
||||
secure_write(public_key_path, 0o644, keypair.public_key)
|
||||
|
||||
def _get_private_key(self):
|
||||
def _get_private_key(self) -> PrivateKey:
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
|
||||
)
|
||||
@@ -467,7 +485,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
format=parse_private_key_format(self.private_key_path),
|
||||
)
|
||||
|
||||
def _get_public_key(self):
|
||||
def _get_public_key(self) -> PublicKey | t.Literal[""]:
|
||||
try:
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path,
|
||||
@@ -480,7 +498,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
return PublicKey.from_string(to_text(keypair.public_key))
|
||||
|
||||
def _private_key_readable(self):
|
||||
def _private_key_readable(self) -> bool:
|
||||
try:
|
||||
OpensshKeypair.load(
|
||||
path=self.private_key_path,
|
||||
@@ -504,7 +522,7 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
|
||||
return True
|
||||
|
||||
def _update_comment(self):
|
||||
def _update_comment(self) -> None:
|
||||
keypair = OpensshKeypair.load(
|
||||
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
|
||||
)
|
||||
@@ -519,16 +537,18 @@ class KeypairBackendCryptography(KeypairBackend):
|
||||
except (IOError, OSError) as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _private_key_valid_backend(self):
|
||||
def _private_key_valid_backend(self, original_private_key: PrivateKey) -> bool:
|
||||
# avoids breaking behavior and prevents
|
||||
# automatic conversions with OpenSSH upgrades
|
||||
if self.module.params["private_key_format"] == "auto":
|
||||
return True
|
||||
|
||||
return self.private_key_format == self.original_private_key.format
|
||||
return self.private_key_format == original_private_key.format
|
||||
|
||||
|
||||
def select_backend(module, backend):
|
||||
def select_backend(
|
||||
module: AnsibleModule, backend: t.Literal["auto", "opensshbin", "cryptography"]
|
||||
) -> KeypairBackend:
|
||||
can_use_cryptography = HAS_OPENSSH_SUPPORT and LooseVersion(
|
||||
CRYPTOGRAPHY_VERSION
|
||||
) >= LooseVersion(COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION)
|
||||
|
||||
@@ -8,6 +8,7 @@ import abc
|
||||
import binascii
|
||||
import datetime as _datetime
|
||||
import os
|
||||
import typing as t
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
@@ -26,6 +27,16 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .cryptography import KeyType
|
||||
|
||||
DateFormat = t.Literal["human_readable", "openssh", "timestamp"]
|
||||
DateFormatStr = t.Literal["human_readable", "openssh"]
|
||||
DateFormatInt = t.Literal["timestamp"]
|
||||
else:
|
||||
KeyType = None
|
||||
|
||||
|
||||
# Protocol References
|
||||
# -------------------
|
||||
# https://datatracker.ietf.org/doc/html/rfc4251
|
||||
@@ -44,7 +55,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
_USER_TYPE = 1
|
||||
_HOST_TYPE = 2
|
||||
|
||||
_SSH_TYPE_STRINGS = {
|
||||
_SSH_TYPE_STRINGS: dict[KeyType | str, bytes] = {
|
||||
"rsa": b"ssh-rsa",
|
||||
"dsa": b"ssh-dss",
|
||||
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
|
||||
@@ -94,16 +105,18 @@ _EXTENSIONS = (
|
||||
|
||||
|
||||
class OpensshCertificateTimeParameters:
|
||||
def __init__(self, valid_from, valid_to):
|
||||
def __init__(
|
||||
self, valid_from: str | bytes | int, valid_to: str | bytes | int
|
||||
) -> None:
|
||||
self._valid_from = self.to_datetime(valid_from)
|
||||
self._valid_to = self.to_datetime(valid_to)
|
||||
|
||||
if self._valid_from > self._valid_to:
|
||||
raise ValueError(
|
||||
f"Valid from: {valid_from} must not be greater than Valid to: {valid_to}"
|
||||
f"Valid from: {valid_from!r} must not be greater than Valid to: {valid_to!r}"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
else:
|
||||
@@ -112,55 +125,83 @@ class OpensshCertificateTimeParameters:
|
||||
and self._valid_to == other._valid_to
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
@property
|
||||
def validity_string(self):
|
||||
def validity_string(self) -> str:
|
||||
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
|
||||
return f"{self.valid_from(date_format='openssh')}:{self.valid_to(date_format='openssh')}"
|
||||
return ""
|
||||
|
||||
def valid_from(self, date_format):
|
||||
@t.overload
|
||||
def valid_from(self, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
def valid_from(self, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
def valid_from(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_from(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_from, date_format)
|
||||
|
||||
def valid_to(self, date_format):
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
def valid_to(self, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
def valid_to(self, date_format: DateFormat) -> str | int:
|
||||
return self.format_datetime(self._valid_to, date_format)
|
||||
|
||||
def within_range(self, valid_at):
|
||||
def within_range(self, valid_at: str | bytes | int | None) -> bool:
|
||||
if valid_at is not None:
|
||||
valid_at_datetime = self.to_datetime(valid_at)
|
||||
return self._valid_from <= valid_at_datetime <= self._valid_to
|
||||
return True
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt, date_format):
|
||||
def format_datetime(dt: datetime, date_format: DateFormatStr) -> str: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormatInt) -> int: ...
|
||||
|
||||
@t.overload
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int: ...
|
||||
|
||||
@staticmethod
|
||||
def format_datetime(dt: datetime, date_format: DateFormat) -> str | int:
|
||||
if date_format in ("human_readable", "openssh"):
|
||||
if dt == _ALWAYS:
|
||||
result = "always"
|
||||
elif dt == _FOREVER:
|
||||
result = "forever"
|
||||
return "always"
|
||||
if dt == _FOREVER:
|
||||
return "forever"
|
||||
else:
|
||||
result = (
|
||||
return (
|
||||
dt.isoformat().replace("+00:00", "")
|
||||
if date_format == "human_readable"
|
||||
else dt.strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
elif date_format == "timestamp":
|
||||
if date_format == "timestamp":
|
||||
td = dt - _ALWAYS
|
||||
result = int(
|
||||
return int(
|
||||
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{date_format} is not a valid format")
|
||||
return result
|
||||
raise ValueError(f"{date_format} is not a valid format")
|
||||
|
||||
@staticmethod
|
||||
def to_datetime(time_string_or_timestamp):
|
||||
def to_datetime(time_string_or_timestamp: str | bytes | int) -> datetime:
|
||||
try:
|
||||
if isinstance(time_string_or_timestamp, (str, bytes)):
|
||||
result = OpensshCertificateTimeParameters._time_string_to_datetime(
|
||||
time_string_or_timestamp.strip()
|
||||
to_text(time_string_or_timestamp.strip())
|
||||
)
|
||||
elif isinstance(time_string_or_timestamp, int):
|
||||
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
|
||||
@@ -175,43 +216,53 @@ class OpensshCertificateTimeParameters:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _timestamp_to_datetime(timestamp):
|
||||
def _timestamp_to_datetime(timestamp: int) -> datetime:
|
||||
if timestamp == 0x0:
|
||||
result = _ALWAYS
|
||||
elif timestamp == 0xFFFFFFFFFFFFFFFF:
|
||||
result = _FOREVER
|
||||
else:
|
||||
try:
|
||||
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
|
||||
except OverflowError:
|
||||
raise ValueError
|
||||
return result
|
||||
return _ALWAYS
|
||||
if timestamp == 0xFFFFFFFFFFFFFFFF:
|
||||
return _FOREVER
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
|
||||
except OverflowError:
|
||||
raise ValueError
|
||||
|
||||
@staticmethod
|
||||
def _time_string_to_datetime(time_string):
|
||||
result = None
|
||||
def _time_string_to_datetime(time_string: str) -> datetime:
|
||||
if time_string == "always":
|
||||
result = _ALWAYS
|
||||
elif time_string == "forever":
|
||||
result = _FOREVER
|
||||
elif is_relative_time_string(time_string):
|
||||
return _ALWAYS
|
||||
if time_string == "forever":
|
||||
return _FOREVER
|
||||
if is_relative_time_string(time_string):
|
||||
result = convert_relative_to_datetime(time_string, with_timezone=True)
|
||||
else:
|
||||
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
|
||||
try:
|
||||
result = _add_or_remove_timezone(
|
||||
datetime.strptime(time_string, time_format),
|
||||
with_timezone=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if result is None:
|
||||
raise ValueError
|
||||
return result
|
||||
result = None
|
||||
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
|
||||
try:
|
||||
result = _add_or_remove_timezone(
|
||||
datetime.strptime(time_string, time_format),
|
||||
with_timezone=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if result is None:
|
||||
raise ValueError
|
||||
return result
|
||||
|
||||
|
||||
_OpensshCertificateOption = t.TypeVar(
|
||||
"_OpensshCertificateOption", bound="OpensshCertificateOption"
|
||||
)
|
||||
|
||||
|
||||
class OpensshCertificateOption:
|
||||
def __init__(self, option_type, name, data):
|
||||
def __init__(
|
||||
self,
|
||||
option_type: t.Literal["critical", "extension"],
|
||||
name: str | bytes,
|
||||
data: str | bytes,
|
||||
):
|
||||
if option_type not in ("critical", "extension"):
|
||||
raise ValueError("type must be either 'critical' or 'extension'")
|
||||
|
||||
@@ -225,7 +276,7 @@ class OpensshCertificateOption:
|
||||
self._name = name.lower()
|
||||
self._data = data
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
|
||||
@@ -237,32 +288,34 @@ class OpensshCertificateOption:
|
||||
]
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._option_type, self._name, self._data))
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self._data:
|
||||
return f"{self._name}={self._data}"
|
||||
return self._name
|
||||
return f"{self._name!r}={self._data!r}"
|
||||
return f"{self._name!r}"
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> str | bytes:
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str | bytes:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> t.Literal["critical", "extension"]:
|
||||
return self._option_type
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, option_string):
|
||||
if not isinstance(option_string, (str, bytes)):
|
||||
def from_string(
|
||||
cls: t.Type[_OpensshCertificateOption], option_string: str
|
||||
) -> _OpensshCertificateOption:
|
||||
if not isinstance(option_string, str):
|
||||
raise ValueError(
|
||||
f"option_string must be a string not {type(option_string)}"
|
||||
)
|
||||
@@ -280,7 +333,8 @@ class OpensshCertificateOption:
|
||||
name, data = option_string.strip(), ""
|
||||
|
||||
return cls(
|
||||
option_type=option_type or get_option_type(name.lower()),
|
||||
# We have str, but we're expecting a specific literal:
|
||||
option_type=option_type or get_option_type(name.lower()), # type: ignore
|
||||
name=name,
|
||||
data=data,
|
||||
)
|
||||
@@ -291,21 +345,21 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nonce=None,
|
||||
serial=None,
|
||||
cert_type=None,
|
||||
key_id=None,
|
||||
principals=None,
|
||||
valid_after=None,
|
||||
valid_before=None,
|
||||
critical_options=None,
|
||||
extensions=None,
|
||||
reserved=None,
|
||||
signing_key=None,
|
||||
nonce: bytes | None = None,
|
||||
serial: int | None = None,
|
||||
cert_type: int | None = None,
|
||||
key_id: bytes | None = None,
|
||||
principals: list[bytes] | None = None,
|
||||
valid_after: int | None = None,
|
||||
valid_before: int | None = None,
|
||||
critical_options: list[tuple[bytes, bytes]] | None = None,
|
||||
extensions: list[tuple[bytes, bytes]] | None = None,
|
||||
reserved: bytes | None = None,
|
||||
signing_key: bytes | None = None,
|
||||
):
|
||||
self.nonce = nonce
|
||||
self.serial = serial
|
||||
self._cert_type = cert_type
|
||||
self._cert_type: int | None = cert_type
|
||||
self.key_id = key_id
|
||||
self.principals = principals
|
||||
self.valid_after = valid_after
|
||||
@@ -315,10 +369,10 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
self.reserved = reserved
|
||||
self.signing_key = signing_key
|
||||
|
||||
self.type_string = None
|
||||
self.type_string: bytes | None = None
|
||||
|
||||
@property
|
||||
def cert_type(self):
|
||||
def cert_type(self) -> t.Literal["user", "host", ""]:
|
||||
if self._cert_type == _USER_TYPE:
|
||||
return "user"
|
||||
elif self._cert_type == _HOST_TYPE:
|
||||
@@ -327,7 +381,7 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
return ""
|
||||
|
||||
@cert_type.setter
|
||||
def cert_type(self, cert_type):
|
||||
def cert_type(self, cert_type: t.Literal["user", "host"] | int) -> None:
|
||||
if cert_type == "user" or cert_type == _USER_TYPE:
|
||||
self._cert_type = _USER_TYPE
|
||||
elif cert_type == "host" or cert_type == _HOST_TYPE:
|
||||
@@ -335,28 +389,30 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
|
||||
else:
|
||||
raise ValueError(f"{cert_type} is not a valid certificate type")
|
||||
|
||||
def signing_key_fingerprint(self):
|
||||
def signing_key_fingerprint(self) -> bytes:
|
||||
if self.signing_key is None:
|
||||
raise ValueError("signing_key not present")
|
||||
return fingerprint(self.signing_key)
|
||||
|
||||
@abc.abstractmethod
|
||||
def public_key_fingerprint(self):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, e=None, n=None, **kwargs):
|
||||
def __init__(self, e: int | None = None, n: int | None = None, **kwargs) -> None:
|
||||
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
|
||||
self.e = e
|
||||
self.n = n
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.e is None, self.n is None]):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.e is None or self.n is None:
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
@@ -366,13 +422,20 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.e = parser.mpint()
|
||||
self.n = parser.mpint()
|
||||
|
||||
|
||||
class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
p: int | None = None,
|
||||
q: int | None = None,
|
||||
g: int | None = None,
|
||||
y: int | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
|
||||
self.p = p
|
||||
@@ -381,8 +444,8 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
self.y = y
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.p is None or self.q is None or self.g is None or self.y is None:
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
@@ -394,7 +457,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.p = parser.mpint()
|
||||
self.q = parser.mpint()
|
||||
self.g = parser.mpint()
|
||||
@@ -402,7 +465,9 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
|
||||
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, curve=None, public_key=None, **kwargs):
|
||||
def __init__(
|
||||
self, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
|
||||
):
|
||||
super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
|
||||
self._curve = None
|
||||
if curve is not None:
|
||||
@@ -411,11 +476,11 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
self.public_key = public_key
|
||||
|
||||
@property
|
||||
def curve(self):
|
||||
def curve(self) -> bytes | None:
|
||||
return self._curve
|
||||
|
||||
@curve.setter
|
||||
def curve(self, curve):
|
||||
def curve(self, curve: bytes) -> None:
|
||||
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
|
||||
self._curve = curve
|
||||
self.type_string = (
|
||||
@@ -428,8 +493,8 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
)
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
|
||||
def public_key_fingerprint(self):
|
||||
if any([self.curve is None, self.public_key is None]):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.curve is None or self.public_key is None:
|
||||
return b""
|
||||
|
||||
writer = _OpensshWriter()
|
||||
@@ -439,18 +504,18 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.curve = parser.string()
|
||||
self.public_key = parser.string()
|
||||
|
||||
|
||||
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
def __init__(self, pk=None, **kwargs):
|
||||
def __init__(self, pk: bytes | None = None, **kwargs) -> None:
|
||||
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
|
||||
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
|
||||
self.pk = pk
|
||||
|
||||
def public_key_fingerprint(self):
|
||||
def public_key_fingerprint(self) -> bytes:
|
||||
if self.pk is None:
|
||||
return b""
|
||||
|
||||
@@ -460,21 +525,26 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
|
||||
|
||||
return fingerprint(writer.bytes())
|
||||
|
||||
def parse_public_numbers(self, parser):
|
||||
def parse_public_numbers(self, parser: OpensshParser) -> None:
|
||||
self.pk = parser.string()
|
||||
|
||||
|
||||
_OpensshCertificate = t.TypeVar("_OpensshCertificate", bound="OpensshCertificate")
|
||||
|
||||
|
||||
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
|
||||
class OpensshCertificate:
|
||||
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
|
||||
|
||||
def __init__(self, cert_info, signature):
|
||||
def __init__(self, cert_info: OpensshCertificateInfo, signature: bytes):
|
||||
|
||||
self._cert_info = cert_info
|
||||
self.signature = signature
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
def load(
|
||||
cls: t.Type[_OpensshCertificate], path: str | os.PathLike
|
||||
) -> _OpensshCertificate:
|
||||
if not os.path.exists(path):
|
||||
raise ValueError(f"{path} is not a valid path.")
|
||||
|
||||
@@ -492,11 +562,11 @@ class OpensshCertificate:
|
||||
|
||||
for key_type, string in _SSH_TYPE_STRINGS.items():
|
||||
if format_identifier == string + _CERT_SUFFIX_V01:
|
||||
pub_key_type = key_type
|
||||
pub_key_type = t.cast(KeyType, key_type)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid certificate format identifier: {format_identifier}"
|
||||
f"Invalid certificate format identifier: {format_identifier!r}"
|
||||
)
|
||||
|
||||
parser = OpensshParser(cert)
|
||||
@@ -521,75 +591,97 @@ class OpensshCertificate:
|
||||
)
|
||||
|
||||
@property
|
||||
def type_string(self):
|
||||
def type_string(self) -> str:
|
||||
return to_text(self._cert_info.type_string)
|
||||
|
||||
@property
|
||||
def nonce(self):
|
||||
def nonce(self) -> bytes:
|
||||
if self._cert_info.nonce is None:
|
||||
raise ValueError
|
||||
return self._cert_info.nonce
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> str:
|
||||
return to_text(self._cert_info.public_key_fingerprint())
|
||||
|
||||
@property
|
||||
def serial(self):
|
||||
def serial(self) -> int:
|
||||
if self._cert_info.serial is None:
|
||||
raise ValueError
|
||||
return self._cert_info.serial
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self._cert_info.cert_type
|
||||
def type(self) -> t.Literal["user", "host"]:
|
||||
result = self._cert_info.cert_type
|
||||
if result == "":
|
||||
raise ValueError
|
||||
return result
|
||||
|
||||
@property
|
||||
def key_id(self):
|
||||
def key_id(self) -> str:
|
||||
return to_text(self._cert_info.key_id)
|
||||
|
||||
@property
|
||||
def principals(self):
|
||||
def principals(self) -> list[str]:
|
||||
if self._cert_info.principals is None:
|
||||
raise ValueError
|
||||
return [to_text(p) for p in self._cert_info.principals]
|
||||
|
||||
@property
|
||||
def valid_after(self):
|
||||
def valid_after(self) -> int:
|
||||
if self._cert_info.valid_after is None:
|
||||
raise ValueError
|
||||
return self._cert_info.valid_after
|
||||
|
||||
@property
|
||||
def valid_before(self):
|
||||
def valid_before(self) -> int:
|
||||
if self._cert_info.valid_before is None:
|
||||
raise ValueError
|
||||
return self._cert_info.valid_before
|
||||
|
||||
@property
|
||||
def critical_options(self):
|
||||
def critical_options(self) -> list[OpensshCertificateOption]:
|
||||
if self._cert_info.critical_options is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("critical", to_text(n), to_text(d))
|
||||
for n, d in self._cert_info.critical_options
|
||||
]
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
def extensions(self) -> list[OpensshCertificateOption]:
|
||||
if self._cert_info.extensions is None:
|
||||
raise ValueError
|
||||
return [
|
||||
OpensshCertificateOption("extension", to_text(n), to_text(d))
|
||||
for n, d in self._cert_info.extensions
|
||||
]
|
||||
|
||||
@property
|
||||
def reserved(self):
|
||||
def reserved(self) -> bytes:
|
||||
if self._cert_info.reserved is None:
|
||||
raise ValueError
|
||||
return self._cert_info.reserved
|
||||
|
||||
@property
|
||||
def signing_key(self):
|
||||
def signing_key(self) -> str:
|
||||
return to_text(self._cert_info.signing_key_fingerprint())
|
||||
|
||||
@property
|
||||
def signature_type(self):
|
||||
def signature_type(self) -> str:
|
||||
signature_data = OpensshParser.signature_data(self.signature)
|
||||
return to_text(signature_data["signature_type"])
|
||||
|
||||
@staticmethod
|
||||
def _parse_cert_info(pub_key_type, parser):
|
||||
def _parse_cert_info(
|
||||
pub_key_type: KeyType, parser: OpensshParser
|
||||
) -> OpensshCertificateInfo:
|
||||
cert_info = get_cert_info_object(pub_key_type)
|
||||
cert_info.nonce = parser.string()
|
||||
cert_info.parse_public_numbers(parser)
|
||||
cert_info.serial = parser.uint64()
|
||||
cert_info.cert_type = parser.uint32()
|
||||
# mypy doesn't understand that the setter accepts other types than the getter:
|
||||
cert_info.cert_type = parser.uint32() # type: ignore
|
||||
cert_info.key_id = parser.string()
|
||||
cert_info.principals = parser.string_list()
|
||||
cert_info.valid_after = parser.uint64()
|
||||
@@ -601,7 +693,7 @@ class OpensshCertificate:
|
||||
|
||||
return cert_info
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, t.Any]:
|
||||
time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.valid_after, valid_to=self.valid_before
|
||||
)
|
||||
@@ -624,7 +716,7 @@ class OpensshCertificate:
|
||||
}
|
||||
|
||||
|
||||
def apply_directives(directives):
|
||||
def apply_directives(directives: t.Iterable[str]) -> list[OpensshCertificateOption]:
|
||||
if any(d not in _DIRECTIVES for d in directives):
|
||||
raise ValueError(f"directives must be one of {', '.join(_DIRECTIVES)}")
|
||||
|
||||
@@ -650,50 +742,47 @@ def apply_directives(directives):
|
||||
)
|
||||
|
||||
|
||||
def default_options():
|
||||
def default_options() -> list[OpensshCertificateOption]:
|
||||
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
|
||||
|
||||
|
||||
def fingerprint(public_key):
|
||||
def fingerprint(public_key: bytes) -> bytes:
|
||||
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
|
||||
h = sha256()
|
||||
h.update(public_key)
|
||||
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
|
||||
|
||||
|
||||
def get_cert_info_object(key_type):
|
||||
def get_cert_info_object(key_type: KeyType) -> OpensshCertificateInfo:
|
||||
if key_type == "rsa":
|
||||
cert_info = OpensshRSACertificateInfo()
|
||||
elif key_type == "dsa":
|
||||
cert_info = OpensshDSACertificateInfo()
|
||||
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
|
||||
cert_info = OpensshECDSACertificateInfo()
|
||||
elif key_type == "ed25519":
|
||||
cert_info = OpensshED25519CertificateInfo()
|
||||
else:
|
||||
raise ValueError(f"{key_type} is not a valid key type")
|
||||
|
||||
return cert_info
|
||||
return OpensshRSACertificateInfo()
|
||||
if key_type == "dsa":
|
||||
return OpensshDSACertificateInfo()
|
||||
if key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
|
||||
return OpensshECDSACertificateInfo()
|
||||
if key_type == "ed25519":
|
||||
return OpensshED25519CertificateInfo()
|
||||
raise ValueError(f"{key_type} is not a valid key type")
|
||||
|
||||
|
||||
def get_option_type(name):
|
||||
def get_option_type(name: str) -> t.Literal["critical", "extension"]:
|
||||
if name in _CRITICAL_OPTIONS:
|
||||
result = "critical"
|
||||
elif name in _EXTENSIONS:
|
||||
result = "extension"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{name} is not a valid option. "
|
||||
"Custom options must start with 'critical:' or 'extension:' to indicate type"
|
||||
)
|
||||
return result
|
||||
return "critical"
|
||||
if name in _EXTENSIONS:
|
||||
return "extension"
|
||||
raise ValueError(
|
||||
f"{name} is not a valid option. "
|
||||
"Custom options must start with 'critical:' or 'extension:' to indicate type"
|
||||
)
|
||||
|
||||
|
||||
def is_relative_time_string(time_string):
|
||||
def is_relative_time_string(time_string: str) -> bool:
|
||||
return time_string.startswith("+") or time_string.startswith("-")
|
||||
|
||||
|
||||
def parse_option_list(option_list):
|
||||
def parse_option_list(
|
||||
option_list: t.Iterable[str],
|
||||
) -> tuple[list[OpensshCertificateOption], list[OpensshCertificateOption]]:
|
||||
critical_options = []
|
||||
directives = []
|
||||
extensions = []
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
from base64 import b64decode, b64encode
|
||||
from getpass import getuser
|
||||
from socket import gethostname
|
||||
@@ -64,6 +65,27 @@ except ImportError:
|
||||
CRYPTOGRAPHY_VERSION = "0.0"
|
||||
_ALGORITHM_PARAMETERS = {}
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"]
|
||||
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"]
|
||||
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"]
|
||||
|
||||
PrivateKeyTypes = t.Union[
|
||||
rsa.RSAPrivateKey,
|
||||
dsa.DSAPrivateKey,
|
||||
ec.EllipticCurvePrivateKey,
|
||||
Ed25519PrivateKey,
|
||||
]
|
||||
PublicKeyTypes = t.Union[
|
||||
rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey
|
||||
]
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PublicKeyTypes as AllPublicKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
_TEXT_ENCODING = "UTF-8"
|
||||
|
||||
|
||||
@@ -111,11 +133,19 @@ class InvalidSignatureError(OpenSSHError):
|
||||
pass
|
||||
|
||||
|
||||
_AsymmetricKeypair = t.TypeVar("_AsymmetricKeypair", bound="AsymmetricKeypair")
|
||||
|
||||
|
||||
class AsymmetricKeypair:
|
||||
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
|
||||
|
||||
@classmethod
|
||||
def generate(cls, keytype="rsa", size=None, passphrase=None):
|
||||
def generate(
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
) -> _AsymmetricKeypair:
|
||||
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
|
||||
or defaults to an unencrypted RSA-2048 key
|
||||
|
||||
@@ -124,19 +154,21 @@ class AsymmetricKeypair:
|
||||
:passphrase: Secret of type Bytes used to encrypt the private key being generated
|
||||
"""
|
||||
|
||||
if keytype not in _ALGORITHM_PARAMETERS.keys():
|
||||
if keytype not in _ALGORITHM_PARAMETERS:
|
||||
raise InvalidKeyTypeError(
|
||||
f"{keytype} is not a valid keytype. Valid keytypes are {', '.join(_ALGORITHM_PARAMETERS)}"
|
||||
)
|
||||
|
||||
if not size:
|
||||
size = _ALGORITHM_PARAMETERS[keytype]["default_size"]
|
||||
size = _ALGORITHM_PARAMETERS[keytype]["default_size"] # type: ignore
|
||||
else:
|
||||
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]:
|
||||
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]: # type: ignore
|
||||
raise InvalidKeySizeError(
|
||||
f"{size} is not a valid key size for {keytype} keys"
|
||||
)
|
||||
size = t.cast(int, size)
|
||||
|
||||
privatekey: PrivateKeyTypes
|
||||
if passphrase:
|
||||
encryption_algorithm = get_encryption_algorithm(passphrase)
|
||||
else:
|
||||
@@ -157,7 +189,7 @@ class AsymmetricKeypair:
|
||||
privatekey = Ed25519PrivateKey.generate()
|
||||
elif keytype == "ecdsa":
|
||||
privatekey = ec.generate_private_key(
|
||||
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size],
|
||||
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size], # type: ignore
|
||||
)
|
||||
|
||||
publickey = privatekey.public_key()
|
||||
@@ -172,13 +204,13 @@ class AsymmetricKeypair:
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
path,
|
||||
passphrase=None,
|
||||
private_key_format="PEM",
|
||||
public_key_format="PEM",
|
||||
no_public_key=False,
|
||||
):
|
||||
cls: t.Type[_AsymmetricKeypair],
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
private_key_format: KeySerializationFormat = "PEM",
|
||||
public_key_format: KeySerializationFormat = "PEM",
|
||||
no_public_key: bool = False,
|
||||
) -> _AsymmetricKeypair:
|
||||
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
|
||||
|
||||
:path: A path to an existing private key to be loaded
|
||||
@@ -197,14 +229,17 @@ class AsymmetricKeypair:
|
||||
if no_public_key:
|
||||
publickey = privatekey.public_key()
|
||||
else:
|
||||
publickey = load_publickey(path + ".pub", public_key_format)
|
||||
# TODO: BUG: load_publickey() can return unsupported key types
|
||||
# (Also we should check whether the public key fits the private key...)
|
||||
publickey = load_publickey(path + ".pub", public_key_format) # type: ignore
|
||||
|
||||
# Ed25519 keys are always of size 256 and do not have a key_size attribute
|
||||
if isinstance(privatekey, Ed25519PrivateKey):
|
||||
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"]
|
||||
size: int = _ALGORITHM_PARAMETERS["ed25519"]["default_size"] # type: ignore
|
||||
else:
|
||||
size = privatekey.key_size
|
||||
|
||||
keytype: KeyType
|
||||
if isinstance(privatekey, rsa.RSAPrivateKey):
|
||||
keytype = "rsa"
|
||||
elif isinstance(privatekey, dsa.DSAPrivateKey):
|
||||
@@ -224,7 +259,14 @@ class AsymmetricKeypair:
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
|
||||
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm):
|
||||
def __init__(
|
||||
self,
|
||||
keytype: KeyType,
|
||||
size: int,
|
||||
privatekey: PrivateKeyTypes,
|
||||
publickey: PublicKeyTypes,
|
||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
||||
) -> None:
|
||||
"""
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
:size: The key length for the private key of this key pair
|
||||
@@ -246,7 +288,7 @@ class AsymmetricKeypair:
|
||||
"The private key and public key of this keypair do not match"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, AsymmetricKeypair):
|
||||
return NotImplemented
|
||||
|
||||
@@ -256,55 +298,53 @@ class AsymmetricKeypair:
|
||||
self.encryption_algorithm, other.encryption_algorithm
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
@property
|
||||
def private_key(self):
|
||||
def private_key(self) -> PrivateKeyTypes:
|
||||
"""Returns the private key of this key pair"""
|
||||
|
||||
return self.__privatekey
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> PublicKeyTypes:
|
||||
"""Returns the public key of this key pair"""
|
||||
|
||||
return self.__publickey
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
"""Returns the size of the private key of this key pair"""
|
||||
|
||||
return self.__size
|
||||
|
||||
@property
|
||||
def key_type(self):
|
||||
def key_type(self) -> KeyType:
|
||||
"""Returns the key type of this key pair"""
|
||||
|
||||
return self.__keytype
|
||||
|
||||
@property
|
||||
def encryption_algorithm(self):
|
||||
def encryption_algorithm(self) -> serialization.KeySerializationEncryption:
|
||||
"""Returns the key encryption algorithm of this key pair"""
|
||||
|
||||
return self.__encryption_algorithm
|
||||
|
||||
def sign(self, data):
|
||||
def sign(self, data: bytes) -> bytes:
|
||||
"""Returns signature of data signed with the private key of this key pair
|
||||
|
||||
:data: byteslike data to sign
|
||||
"""
|
||||
|
||||
try:
|
||||
signature = self.__privatekey.sign(
|
||||
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
|
||||
return self.__privatekey.sign(
|
||||
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"] # type: ignore
|
||||
)
|
||||
except TypeError as e:
|
||||
raise InvalidDataError(e)
|
||||
|
||||
return signature
|
||||
|
||||
def verify(self, signature, data):
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
"""Verifies that the signature associated with the provided data was signed
|
||||
by the private key of this key pair.
|
||||
|
||||
@@ -312,15 +352,15 @@ class AsymmetricKeypair:
|
||||
:data: byteslike data signed by the provided signature
|
||||
"""
|
||||
try:
|
||||
return self.__publickey.verify(
|
||||
self.__publickey.verify(
|
||||
signature,
|
||||
data,
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"],
|
||||
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"], # type: ignore
|
||||
)
|
||||
except InvalidSignature:
|
||||
raise InvalidSignatureError
|
||||
|
||||
def update_passphrase(self, passphrase=None):
|
||||
def update_passphrase(self, passphrase: bytes | None = None) -> None:
|
||||
"""Updates the encryption algorithm of this key pair
|
||||
|
||||
:passphrase: Byte secret used to encrypt this key pair
|
||||
@@ -332,11 +372,20 @@ class AsymmetricKeypair:
|
||||
self.__encryption_algorithm = serialization.NoEncryption()
|
||||
|
||||
|
||||
_OpensshKeypair = t.TypeVar("_OpensshKeypair", bound="OpensshKeypair")
|
||||
|
||||
|
||||
class OpensshKeypair:
|
||||
"""Container for OpenSSH encoded asymmetric key pairs"""
|
||||
|
||||
@classmethod
|
||||
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None):
|
||||
def generate(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
keytype: KeyType = "rsa",
|
||||
size: int | None = None,
|
||||
passphrase: bytes | None = None,
|
||||
comment: str | None = None,
|
||||
) -> _OpensshKeypair:
|
||||
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
|
||||
|
||||
:keytype: One of rsa, dsa, ecdsa, ed25519
|
||||
@@ -362,7 +411,12 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, passphrase=None, no_public_key=False):
|
||||
def load(
|
||||
cls: t.Type[_OpensshKeypair],
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None = None,
|
||||
no_public_key: bool = False,
|
||||
) -> _OpensshKeypair:
|
||||
"""Returns an Openssh_Keypair object loaded from the supplied file path
|
||||
|
||||
:path: A path to an existing private key to be loaded
|
||||
@@ -373,7 +427,7 @@ class OpensshKeypair:
|
||||
if no_public_key:
|
||||
comment = ""
|
||||
else:
|
||||
comment = extract_comment(path + ".pub")
|
||||
comment = extract_comment(str(path) + ".pub")
|
||||
|
||||
asym_keypair = AsymmetricKeypair.load(
|
||||
path, passphrase, "SSH", "SSH", no_public_key
|
||||
@@ -391,7 +445,9 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_privatekey(asym_keypair, key_format):
|
||||
def encode_openssh_privatekey(
|
||||
asym_keypair: AsymmetricKeypair, key_format: KeyFormat
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded private key for a given keypair
|
||||
|
||||
:asym_keypair: Asymmetric_Keypair from the private key is extracted
|
||||
@@ -422,7 +478,9 @@ class OpensshKeypair:
|
||||
return encoded_privatekey
|
||||
|
||||
@staticmethod
|
||||
def encode_openssh_publickey(asym_keypair, comment):
|
||||
def encode_openssh_publickey(
|
||||
asym_keypair: AsymmetricKeypair, comment: str
|
||||
) -> bytes:
|
||||
"""Returns an OpenSSH encoded public key for a given keypair
|
||||
|
||||
:asym_keypair: Asymmetric_Keypair from the public key is extracted
|
||||
@@ -436,14 +494,19 @@ class OpensshKeypair:
|
||||
validate_comment(comment)
|
||||
|
||||
encoded_publickey += (
|
||||
f" {comment}".encode(encoding=_TEXT_ENCODING) if comment else b""
|
||||
(b" " + comment.encode(encoding=_TEXT_ENCODING)) if comment else b""
|
||||
)
|
||||
|
||||
return encoded_publickey
|
||||
|
||||
def __init__(
|
||||
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment
|
||||
):
|
||||
self,
|
||||
asym_keypair: AsymmetricKeypair,
|
||||
openssh_privatekey: bytes,
|
||||
openssh_publickey: bytes,
|
||||
fingerprint: str,
|
||||
comment: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
|
||||
:openssh_privatekey: An OpenSSH encoded private key
|
||||
@@ -458,7 +521,7 @@ class OpensshKeypair:
|
||||
self.__fingerprint = fingerprint
|
||||
self.__comment = comment
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, OpensshKeypair):
|
||||
return NotImplemented
|
||||
|
||||
@@ -468,49 +531,49 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
@property
|
||||
def asymmetric_keypair(self):
|
||||
def asymmetric_keypair(self) -> AsymmetricKeypair:
|
||||
"""Returns the underlying asymmetric key pair of this OpenSSH encoded key pair"""
|
||||
|
||||
return self.__asym_keypair
|
||||
|
||||
@property
|
||||
def private_key(self):
|
||||
def private_key(self) -> bytes:
|
||||
"""Returns the OpenSSH formatted private key of this key pair"""
|
||||
|
||||
return self.__openssh_privatekey
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> bytes:
|
||||
"""Returns the OpenSSH formatted public key of this key pair"""
|
||||
|
||||
return self.__openssh_publickey
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
"""Returns the size of the private key of this key pair"""
|
||||
|
||||
return self.__asym_keypair.size
|
||||
|
||||
@property
|
||||
def key_type(self):
|
||||
def key_type(self) -> KeyType:
|
||||
"""Returns the key type of this key pair"""
|
||||
|
||||
return self.__asym_keypair.key_type
|
||||
|
||||
@property
|
||||
def fingerprint(self):
|
||||
def fingerprint(self) -> str:
|
||||
"""Returns the fingerprint (SHA256 Hash) of the public key of this key pair"""
|
||||
|
||||
return self.__fingerprint
|
||||
|
||||
@property
|
||||
def comment(self):
|
||||
def comment(self) -> str | None:
|
||||
"""Returns the comment applied to the OpenSSH formatted public key of this key pair"""
|
||||
|
||||
return self.__comment
|
||||
|
||||
@comment.setter
|
||||
def comment(self, comment):
|
||||
def comment(self, comment: str) -> bytes:
|
||||
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
|
||||
|
||||
:comment: Text to update the OpenSSH public key comment
|
||||
@@ -529,7 +592,7 @@ class OpensshKeypair:
|
||||
)
|
||||
return self.__openssh_publickey
|
||||
|
||||
def update_passphrase(self, passphrase):
|
||||
def update_passphrase(self, passphrase: bytes | None) -> None:
|
||||
"""Updates the passphrase used to encrypt the private key of this keypair
|
||||
|
||||
:passphrase: Text secret used for encryption
|
||||
@@ -541,18 +604,17 @@ class OpensshKeypair:
|
||||
)
|
||||
|
||||
|
||||
def load_privatekey(path, passphrase, key_format):
|
||||
def load_privatekey(
|
||||
path: str | os.PathLike,
|
||||
passphrase: bytes | None,
|
||||
key_format: KeySerializationFormat,
|
||||
) -> PrivateKeyTypes:
|
||||
privatekey_loaders = {
|
||||
"PEM": serialization.load_pem_private_key,
|
||||
"DER": serialization.load_der_private_key,
|
||||
"SSH": serialization.load_ssh_private_key,
|
||||
}
|
||||
|
||||
# OpenSSH formatted private keys are not available in Cryptography <3.0
|
||||
if hasattr(serialization, "load_ssh_private_key"):
|
||||
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
|
||||
else:
|
||||
privatekey_loaders["SSH"] = serialization.load_pem_private_key
|
||||
|
||||
try:
|
||||
privatekey_loader = privatekey_loaders[key_format]
|
||||
except KeyError:
|
||||
@@ -567,16 +629,16 @@ def load_privatekey(path, passphrase, key_format):
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
privatekey = privatekey_loader(
|
||||
privatekey = privatekey_loader( # type: ignore
|
||||
data=content,
|
||||
password=passphrase,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
except ValueError as exc:
|
||||
# Revert to PEM if key could not be loaded in SSH format
|
||||
if key_format == "SSH":
|
||||
try:
|
||||
privatekey = privatekey_loaders["PEM"](
|
||||
privatekey = privatekey_loaders["PEM"]( # type: ignore
|
||||
data=content,
|
||||
password=passphrase,
|
||||
)
|
||||
@@ -587,7 +649,7 @@ def load_privatekey(path, passphrase, key_format):
|
||||
except UnsupportedAlgorithm as e:
|
||||
raise InvalidAlgorithmError(e)
|
||||
else:
|
||||
raise InvalidPrivateKeyFileError(e)
|
||||
raise InvalidPrivateKeyFileError(exc)
|
||||
except TypeError as e:
|
||||
raise InvalidPassphraseError(e)
|
||||
except UnsupportedAlgorithm as e:
|
||||
@@ -596,7 +658,9 @@ def load_privatekey(path, passphrase, key_format):
|
||||
return privatekey
|
||||
|
||||
|
||||
def load_publickey(path, key_format):
|
||||
def load_publickey(
|
||||
path: str | os.PathLike, key_format: KeySerializationFormat
|
||||
) -> AllPublicKeyTypes:
|
||||
publickey_loaders = {
|
||||
"PEM": serialization.load_pem_public_key,
|
||||
"DER": serialization.load_der_public_key,
|
||||
@@ -628,20 +692,27 @@ def load_publickey(path, key_format):
|
||||
return publickey
|
||||
|
||||
|
||||
def compare_publickeys(pk1, pk2):
|
||||
def compare_publickeys(pk1: PublicKeyTypes, pk2: PublicKeyTypes) -> bool:
|
||||
a = isinstance(pk1, Ed25519PublicKey)
|
||||
b = isinstance(pk2, Ed25519PublicKey)
|
||||
if a or b:
|
||||
if not a or not b:
|
||||
return False
|
||||
a = pk1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
b = pk2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
|
||||
return a == b
|
||||
a_bytes = pk1.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
b_bytes = pk2.public_bytes(
|
||||
serialization.Encoding.Raw, serialization.PublicFormat.Raw
|
||||
)
|
||||
return a_bytes == b_bytes
|
||||
else:
|
||||
return pk1.public_numbers() == pk2.public_numbers()
|
||||
return pk1.public_numbers() == pk2.public_numbers() # type: ignore
|
||||
|
||||
|
||||
def compare_encryption_algorithms(ea1, ea2):
|
||||
def compare_encryption_algorithms(
|
||||
ea1: serialization.KeySerializationEncryption,
|
||||
ea2: serialization.KeySerializationEncryption,
|
||||
) -> bool:
|
||||
if isinstance(ea1, serialization.NoEncryption) and isinstance(
|
||||
ea2, serialization.NoEncryption
|
||||
):
|
||||
@@ -654,19 +725,21 @@ def compare_encryption_algorithms(ea1, ea2):
|
||||
return False
|
||||
|
||||
|
||||
def get_encryption_algorithm(passphrase):
|
||||
def get_encryption_algorithm(
|
||||
passphrase: bytes,
|
||||
) -> serialization.KeySerializationEncryption:
|
||||
try:
|
||||
return serialization.BestAvailableEncryption(passphrase)
|
||||
except ValueError as e:
|
||||
raise InvalidPassphraseError(e)
|
||||
|
||||
|
||||
def validate_comment(comment):
|
||||
def validate_comment(comment: str) -> None:
|
||||
if not hasattr(comment, "encode"):
|
||||
raise InvalidCommentError(f"{comment} cannot be encoded to text")
|
||||
|
||||
|
||||
def extract_comment(path):
|
||||
def extract_comment(path: str | os.PathLike) -> str:
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise InvalidPublicKeyFileError(f"No file was found at {path}")
|
||||
@@ -684,7 +757,7 @@ def extract_comment(path):
|
||||
return comment
|
||||
|
||||
|
||||
def calculate_fingerprint(openssh_publickey):
|
||||
def calculate_fingerprint(openssh_publickey: bytes) -> str:
|
||||
digest = hashes.Hash(hashes.SHA256())
|
||||
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
|
||||
digest.update(decoded_pubkey)
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from struct import Struct
|
||||
|
||||
@@ -38,17 +39,20 @@ _UINT64 = Struct(b"!Q")
|
||||
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
|
||||
def any_in(sequence, *elements):
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
def any_in(sequence: t.Iterable[_T], *elements: _T) -> bool:
|
||||
return any(e in sequence for e in elements)
|
||||
|
||||
|
||||
def file_mode(path):
|
||||
def file_mode(path: str | os.PathLike) -> int:
|
||||
if not os.path.exists(path):
|
||||
return 0o000
|
||||
return os.stat(path).st_mode & 0o777
|
||||
|
||||
|
||||
def parse_openssh_version(version_string):
|
||||
def parse_openssh_version(version_string: str) -> str | None:
|
||||
"""Parse the version output of ssh -V and return version numbers that can be compared"""
|
||||
|
||||
parsed_result = re.match(
|
||||
@@ -63,7 +67,7 @@ def parse_openssh_version(version_string):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def secure_open(path, mode):
|
||||
def secure_open(path: str | os.PathLike, mode: int) -> t.Iterator[int]:
|
||||
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
|
||||
try:
|
||||
yield fd
|
||||
@@ -71,7 +75,7 @@ def secure_open(path, mode):
|
||||
os.close(fd)
|
||||
|
||||
|
||||
def secure_write(path, mode, content):
|
||||
def secure_write(path: str | os.PathLike, mode: int, content: bytes) -> None:
|
||||
with secure_open(path, mode) as fd:
|
||||
os.write(fd, content)
|
||||
|
||||
@@ -84,35 +88,35 @@ class OpensshParser:
|
||||
UINT32_OFFSET = 4
|
||||
UINT64_OFFSET = 8
|
||||
|
||||
def __init__(self, data):
|
||||
def __init__(self, data: bytes | bytearray) -> None:
|
||||
if not isinstance(data, (bytes, bytearray)):
|
||||
raise TypeError(f"Data must be bytes-like not {type(data)}")
|
||||
|
||||
self._data = memoryview(data)
|
||||
self._pos = 0
|
||||
|
||||
def boolean(self):
|
||||
def boolean(self) -> bool:
|
||||
next_pos = self._check_position(self.BOOLEAN_OFFSET)
|
||||
|
||||
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def uint32(self):
|
||||
def uint32(self) -> int:
|
||||
next_pos = self._check_position(self.UINT32_OFFSET)
|
||||
|
||||
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def uint64(self):
|
||||
def uint64(self) -> int:
|
||||
next_pos = self._check_position(self.UINT64_OFFSET)
|
||||
|
||||
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
|
||||
self._pos = next_pos
|
||||
return value
|
||||
|
||||
def string(self):
|
||||
def string(self) -> bytes:
|
||||
length = self.uint32()
|
||||
|
||||
next_pos = self._check_position(length)
|
||||
@@ -122,15 +126,15 @@ class OpensshParser:
|
||||
# Cast to bytes is required as a memoryview slice is itself a memoryview
|
||||
return bytes(value)
|
||||
|
||||
def mpint(self):
|
||||
def mpint(self) -> int:
|
||||
return self._big_int(self.string(), "big", signed=True)
|
||||
|
||||
def name_list(self):
|
||||
def name_list(self) -> list[str]:
|
||||
raw_string = self.string()
|
||||
return raw_string.decode("ASCII").split(",")
|
||||
|
||||
# Convenience function, but not an official data type from SSH
|
||||
def string_list(self):
|
||||
def string_list(self) -> list[bytes]:
|
||||
result = []
|
||||
raw_string = self.string()
|
||||
|
||||
@@ -142,7 +146,7 @@ class OpensshParser:
|
||||
return result
|
||||
|
||||
# Convenience function, but not an official data type from SSH
|
||||
def option_list(self):
|
||||
def option_list(self) -> list[tuple[bytes, bytes]]:
|
||||
result = []
|
||||
raw_string = self.string()
|
||||
|
||||
@@ -159,15 +163,15 @@ class OpensshParser:
|
||||
|
||||
return result
|
||||
|
||||
def seek(self, offset):
|
||||
def seek(self, offset: int) -> int:
|
||||
self._pos = self._check_position(offset)
|
||||
|
||||
return self._pos
|
||||
|
||||
def remaining_bytes(self):
|
||||
def remaining_bytes(self) -> int:
|
||||
return len(self._data) - self._pos
|
||||
|
||||
def _check_position(self, offset):
|
||||
def _check_position(self, offset: int) -> int:
|
||||
if self._pos + offset > len(self._data):
|
||||
raise ValueError(f"Insufficient data remaining at position: {self._pos}")
|
||||
elif self._pos + offset < 0:
|
||||
@@ -176,8 +180,8 @@ class OpensshParser:
|
||||
return self._pos + offset
|
||||
|
||||
@classmethod
|
||||
def signature_data(cls, signature_string):
|
||||
signature_data = {}
|
||||
def signature_data(cls, signature_string: bytes) -> dict[str, bytes | int]:
|
||||
signature_data: dict[str, bytes | int] = {}
|
||||
|
||||
parser = cls(signature_string)
|
||||
signature_type = parser.string()
|
||||
@@ -205,14 +209,19 @@ class OpensshParser:
|
||||
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
|
||||
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
|
||||
else:
|
||||
raise ValueError(f"{signature_type} is not a valid signature type")
|
||||
raise ValueError(f"{signature_type!r} is not a valid signature type")
|
||||
|
||||
signature_data["signature_type"] = signature_type
|
||||
|
||||
return signature_data
|
||||
|
||||
@classmethod
|
||||
def _big_int(cls, raw_string, byte_order, signed=False):
|
||||
def _big_int(
|
||||
cls,
|
||||
raw_string: bytes,
|
||||
byte_order: t.Literal["big", "little"],
|
||||
signed: bool = False,
|
||||
) -> int:
|
||||
if byte_order not in ("big", "little"):
|
||||
raise ValueError(
|
||||
f"Byte_order must be one of (big, little) not {byte_order}"
|
||||
@@ -230,18 +239,16 @@ class _OpensshWriter:
|
||||
in validating parsed material.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer=None):
|
||||
def __init__(self, buffer: bytearray | None = None):
|
||||
if buffer is not None:
|
||||
if not isinstance(buffer, (bytes, bytearray)):
|
||||
raise TypeError(
|
||||
f"Buffer must be a bytes-like object not {type(buffer)}"
|
||||
)
|
||||
if not isinstance(buffer, bytearray):
|
||||
raise TypeError(f"Buffer must be a bytearray, not {type(buffer)}")
|
||||
else:
|
||||
buffer = bytearray()
|
||||
|
||||
self._buff = buffer
|
||||
self._buff: bytearray = buffer
|
||||
|
||||
def boolean(self, value):
|
||||
def boolean(self, value: bool) -> t.Self:
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError(f"Value must be of type bool not {type(value)}")
|
||||
|
||||
@@ -249,7 +256,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def uint32(self, value):
|
||||
def uint32(self, value: int) -> t.Self:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"Value must be of type int not {type(value)}")
|
||||
if value < 0 or value > _UINT32_MAX:
|
||||
@@ -261,7 +268,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def uint64(self, value):
|
||||
def uint64(self, value: int) -> t.Self:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"Value must be of type int not {type(value)}")
|
||||
if value < 0 or value > _UINT64_MAX:
|
||||
@@ -273,7 +280,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def string(self, value):
|
||||
def string(self, value: bytes | bytearray) -> t.Self:
|
||||
if not isinstance(value, (bytes, bytearray)):
|
||||
raise TypeError(f"Value must be bytes-like not {type(value)}")
|
||||
self.uint32(len(value))
|
||||
@@ -281,7 +288,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def mpint(self, value):
|
||||
def mpint(self, value: int) -> t.Self:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"Value must be of type int not {type(value)}")
|
||||
|
||||
@@ -289,7 +296,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def name_list(self, value):
|
||||
def name_list(self, value: list[str]) -> t.Self:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Value must be a list of byte strings not {type(value)}")
|
||||
|
||||
@@ -300,7 +307,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def string_list(self, value):
|
||||
def string_list(self, value: list[bytes]) -> t.Self:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Value must be a list of byte string not {type(value)}")
|
||||
|
||||
@@ -312,7 +319,7 @@ class _OpensshWriter:
|
||||
|
||||
return self
|
||||
|
||||
def option_list(self, value):
|
||||
def option_list(self, value: list[tuple[bytes, bytes]]) -> t.Self:
|
||||
if not isinstance(value, list) or (value and not isinstance(value[0], tuple)):
|
||||
raise TypeError("Value must be a list of tuples")
|
||||
|
||||
@@ -327,7 +334,7 @@ class _OpensshWriter:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _int_to_mpint(num):
|
||||
def _int_to_mpint(num: int) -> bytes:
|
||||
byte_length = (num.bit_length() + 7) // 8
|
||||
try:
|
||||
return num.to_bytes(byte_length, "big", signed=True)
|
||||
@@ -335,5 +342,5 @@ class _OpensshWriter:
|
||||
except OverflowError:
|
||||
return num.to_bytes(byte_length + 1, "big", signed=True)
|
||||
|
||||
def bytes(self):
|
||||
def bytes(self) -> bytes:
|
||||
return bytes(self._buff)
|
||||
|
||||
@@ -10,7 +10,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.math impor
|
||||
)
|
||||
|
||||
|
||||
def th(number):
|
||||
def th(number: int) -> str:
|
||||
abs_number = abs(number)
|
||||
mod_10 = abs_number % 10
|
||||
mod_100 = abs_number % 100
|
||||
@@ -24,13 +24,13 @@ def th(number):
|
||||
return "th"
|
||||
|
||||
|
||||
def parse_serial(value):
|
||||
def parse_serial(value: str | bytes) -> int:
|
||||
"""
|
||||
Given a colon-separated string of hexadecimal byte values, converts it to an integer.
|
||||
"""
|
||||
value = to_native(value)
|
||||
value_str = to_native(value)
|
||||
result = 0
|
||||
for i, part in enumerate(value.split(":")):
|
||||
for i, part in enumerate(value_str.split(":")):
|
||||
try:
|
||||
part_value = int(part, 16)
|
||||
if part_value < 0 or part_value > 255:
|
||||
@@ -43,11 +43,11 @@ def parse_serial(value):
|
||||
return result
|
||||
|
||||
|
||||
def to_serial(value):
|
||||
def to_serial(value: int) -> str:
|
||||
"""
|
||||
Given an integer, converts its absolute value to a colon-separated string of hexadecimal byte values.
|
||||
"""
|
||||
value = convert_int_to_hex(value).upper()
|
||||
if len(value) % 2 != 0:
|
||||
value = "0" + value
|
||||
return ":".join(value[i : i + 2] for i in range(0, len(value), 2))
|
||||
value_str = convert_int_to_hex(value).upper()
|
||||
if len(value_str) % 2 != 0:
|
||||
value_str = f"0{value_str}"
|
||||
return ":".join(value_str[i : i + 2] for i in range(0, len(value_str), 2))
|
||||
|
||||
@@ -13,37 +13,16 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.basic impo
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
UTC = datetime.timezone.utc
|
||||
except AttributeError:
|
||||
_DURATION_ZERO = datetime.timedelta(0)
|
||||
|
||||
class _UTCClass(datetime.tzinfo):
|
||||
def utcoffset(self, dt):
|
||||
return _DURATION_ZERO
|
||||
|
||||
def dst(self, dt):
|
||||
return _DURATION_ZERO
|
||||
|
||||
def tzname(self, dt):
|
||||
return "UTC"
|
||||
|
||||
def fromutc(self, dt):
|
||||
return dt
|
||||
|
||||
def __repr__(self):
|
||||
return "UTC"
|
||||
|
||||
UTC = _UTCClass()
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
def get_now_datetime(with_timezone):
|
||||
def get_now_datetime(with_timezone: bool) -> datetime.datetime:
|
||||
if with_timezone:
|
||||
return datetime.datetime.now(tz=UTC)
|
||||
return datetime.datetime.utcnow()
|
||||
|
||||
|
||||
def ensure_utc_timezone(timestamp):
|
||||
def ensure_utc_timezone(timestamp: datetime.datetime) -> datetime.datetime:
|
||||
if timestamp.tzinfo is UTC:
|
||||
return timestamp
|
||||
if timestamp.tzinfo is None:
|
||||
@@ -52,7 +31,7 @@ def ensure_utc_timezone(timestamp):
|
||||
return timestamp.astimezone(UTC)
|
||||
|
||||
|
||||
def remove_timezone(timestamp):
|
||||
def remove_timezone(timestamp: datetime.datetime) -> datetime.datetime:
|
||||
# Convert to native datetime object
|
||||
if timestamp.tzinfo is None:
|
||||
return timestamp
|
||||
@@ -61,26 +40,34 @@ def remove_timezone(timestamp):
|
||||
return timestamp.replace(tzinfo=None)
|
||||
|
||||
|
||||
def add_or_remove_timezone(timestamp, with_timezone):
|
||||
def add_or_remove_timezone(
|
||||
timestamp: datetime.datetime, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
return (
|
||||
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
|
||||
)
|
||||
|
||||
|
||||
def get_epoch_seconds(timestamp):
|
||||
def get_epoch_seconds(timestamp: datetime.datetime) -> float:
|
||||
if timestamp.tzinfo is None:
|
||||
# timestamp.timestamp() is offset by the local timezone if timestamp has no timezone
|
||||
timestamp = ensure_utc_timezone(timestamp)
|
||||
return timestamp.timestamp()
|
||||
|
||||
|
||||
def from_epoch_seconds(timestamp, with_timezone):
|
||||
def from_epoch_seconds(
|
||||
timestamp: int | float, with_timezone: bool
|
||||
) -> datetime.datetime:
|
||||
if with_timezone:
|
||||
return datetime.datetime.fromtimestamp(timestamp, UTC)
|
||||
return datetime.datetime.utcfromtimestamp(timestamp)
|
||||
|
||||
|
||||
def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=None):
|
||||
def convert_relative_to_datetime(
|
||||
relative_time_string: str,
|
||||
with_timezone: bool = False,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime | None:
|
||||
"""Get a datetime.datetime or None from a string in the time format described in sshd_config(5)"""
|
||||
|
||||
parsed_result = re.match(
|
||||
@@ -115,7 +102,12 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
|
||||
return now - offset
|
||||
|
||||
|
||||
def get_relative_time_option(input_string, input_name, with_timezone=False, now=None):
|
||||
def get_relative_time_option(
|
||||
input_string: str,
|
||||
input_name: str,
|
||||
with_timezone: bool = False,
|
||||
now: datetime.datetime | None = None,
|
||||
) -> datetime.datetime:
|
||||
"""
|
||||
Return an absolute timespec if a relative timespec or an ASN1 formatted
|
||||
string is provided.
|
||||
@@ -129,9 +121,12 @@ def get_relative_time_option(input_string, input_name, with_timezone=False, now=
|
||||
)
|
||||
# Relative time
|
||||
if result.startswith("+") or result.startswith("-"):
|
||||
return convert_relative_to_datetime(
|
||||
result, with_timezone=with_timezone, now=now
|
||||
)
|
||||
res = convert_relative_to_datetime(result, with_timezone=with_timezone, now=now)
|
||||
if res is None:
|
||||
raise OpenSSLObjectError(
|
||||
f'The timespec "{input_string}" for {input_name} is invalid'
|
||||
)
|
||||
return res
|
||||
# Absolute time
|
||||
for date_fmt, length in [
|
||||
(
|
||||
|
||||
@@ -165,6 +165,7 @@ account_uri:
|
||||
"""
|
||||
|
||||
import base64
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
@@ -180,7 +181,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec()
|
||||
argument_spec.update_argspec(
|
||||
terms_agreed=dict(type="bool", default=False),
|
||||
@@ -204,24 +205,24 @@ def main():
|
||||
),
|
||||
)
|
||||
argument_spec.update(
|
||||
mutually_exclusive=(["new_account_key_src", "new_account_key_content"],),
|
||||
required_if=(
|
||||
mutually_exclusive=[("new_account_key_src", "new_account_key_content")],
|
||||
required_if=[
|
||||
# Make sure that for state == changed_key, one of
|
||||
# new_account_key_src and new_account_key_content are specified
|
||||
[
|
||||
(
|
||||
"state",
|
||||
"changed_key",
|
||||
["new_account_key_src", "new_account_key_content"],
|
||||
True,
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
module = argument_spec.create_ansible_module(supports_check_mode=True)
|
||||
backend = create_backend(module, True)
|
||||
|
||||
if module.params["external_account_binding"]:
|
||||
# Make sure padding is there
|
||||
key = module.params["external_account_binding"]["key"]
|
||||
key: str = module.params["external_account_binding"]["key"]
|
||||
if len(key) % 4 != 0:
|
||||
key = key + ("=" * (4 - (len(key) % 4)))
|
||||
# Make sure key is Base64 encoded
|
||||
@@ -237,24 +238,25 @@ def main():
|
||||
client = ACMEClient(module, backend)
|
||||
account = ACMEAccount(client)
|
||||
changed = False
|
||||
state = module.params.get("state")
|
||||
diff_before = {}
|
||||
diff_after = {}
|
||||
state: t.Literal["present", "absent", "changed_key"] = module.params["state"]
|
||||
diff_before: dict[str, t.Any] = {}
|
||||
diff_after: dict[str, t.Any] = {}
|
||||
if state == "absent":
|
||||
created, account_data = account.setup_account(allow_creation=False)
|
||||
if account_data:
|
||||
diff_before = dict(account_data)
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if created:
|
||||
raise AssertionError("Unwanted account creation")
|
||||
if account_data is not None:
|
||||
# Account is not yet deactivated
|
||||
if not module.check_mode:
|
||||
# Deactivate it
|
||||
payload = {"status": "deactivated"}
|
||||
deactivate_payload = {"status": "deactivated"}
|
||||
result, info = client.send_signed_request(
|
||||
client.account_uri,
|
||||
payload,
|
||||
t.cast(str, client.account_uri),
|
||||
deactivate_payload,
|
||||
error_msg="Failed to deactivate account",
|
||||
expected_status_codes=[200],
|
||||
)
|
||||
@@ -278,13 +280,15 @@ def main():
|
||||
diff_before = {}
|
||||
else:
|
||||
diff_before = dict(account_data)
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
updated = False
|
||||
if not created:
|
||||
updated, account_data = account.update_account(account_data, contact)
|
||||
changed = created or updated
|
||||
diff_after = dict(account_data)
|
||||
diff_after["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_after["public_account_key"] = client.account_key_data["jwk"]
|
||||
elif state == "changed_key":
|
||||
# Parse new account key
|
||||
try:
|
||||
@@ -306,7 +310,8 @@ def main():
|
||||
msg="Account does not exist or is deactivated."
|
||||
)
|
||||
diff_before = dict(account_data)
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
diff_before["public_account_key"] = client.account_key_data["jwk"]
|
||||
# Now we can start the account key rollover
|
||||
if not module.check_mode:
|
||||
# Compose inner signed message
|
||||
@@ -317,12 +322,12 @@ def main():
|
||||
"jwk": new_key_data["jwk"],
|
||||
"url": url,
|
||||
}
|
||||
payload = {
|
||||
change_key_payload = {
|
||||
"account": client.account_uri,
|
||||
"newKey": new_key_data["jwk"], # specified in draft 12 and older
|
||||
"oldKey": client.account_jwk, # specified in draft 13 and newer
|
||||
}
|
||||
data = client.sign_request(protected, payload, new_key_data)
|
||||
data = client.sign_request(protected, change_key_payload, new_key_data)
|
||||
# Send request and verify result
|
||||
result, info = client.send_signed_request(
|
||||
url,
|
||||
@@ -332,8 +337,9 @@ def main():
|
||||
)
|
||||
if module._diff:
|
||||
client.account_key_data = new_key_data
|
||||
client.account_jws_header["alg"] = new_key_data["alg"]
|
||||
diff_after = account.get_account_data()
|
||||
if client.account_jws_header:
|
||||
client.account_jws_header["alg"] = new_key_data["alg"]
|
||||
diff_after = account.get_account_data() or {}
|
||||
elif module._diff:
|
||||
# Kind of fake diff_after
|
||||
diff_after = dict(diff_before)
|
||||
|
||||
@@ -204,6 +204,8 @@ order_uris:
|
||||
version_added: 1.5.0
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
)
|
||||
@@ -220,48 +222,55 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def get_orders_list(module, client, orders_url):
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
def get_orders_list(
|
||||
module: AnsibleModule, client: ACMEClient, orders_url: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Retrieves orders list (handles pagination).
|
||||
Retrieves order URL list (handles pagination).
|
||||
"""
|
||||
orders = []
|
||||
while orders_url:
|
||||
orders: list[str] = []
|
||||
next_orders_url: str | None = orders_url
|
||||
while next_orders_url:
|
||||
# Get part of orders list
|
||||
res, info = client.get_request(
|
||||
orders_url, parse_json_result=True, fail_on_error=True
|
||||
next_orders_url, parse_json_result=True, fail_on_error=True
|
||||
)
|
||||
if not res.get("orders"):
|
||||
if orders:
|
||||
module.warn(
|
||||
f"When retrieving orders list part {orders_url}, got empty result list"
|
||||
f"When retrieving orders list part {next_orders_url}, got empty result list"
|
||||
)
|
||||
break
|
||||
# Add order URLs to result list
|
||||
orders.extend(res["orders"])
|
||||
# Extract URL of next part of results list
|
||||
new_orders_url = []
|
||||
new_orders_url: list[str | None] = []
|
||||
|
||||
def f(link, relation):
|
||||
def f(link: str, relation: str) -> None:
|
||||
if relation == "next":
|
||||
new_orders_url.append(link)
|
||||
|
||||
process_links(info, f)
|
||||
new_orders_url.append(None)
|
||||
previous_orders_url, orders_url = orders_url, new_orders_url.pop(0)
|
||||
if orders_url == previous_orders_url:
|
||||
previous_orders_url, next_orders_url = next_orders_url, new_orders_url.pop(0)
|
||||
if next_orders_url == previous_orders_url:
|
||||
# Prevent infinite loop
|
||||
orders_url = None
|
||||
next_orders_url = None
|
||||
return orders
|
||||
|
||||
|
||||
def get_order(client, order_url):
|
||||
def get_order(client: ACMEClient, order_url: str) -> dict[str, t.Any]:
|
||||
"""
|
||||
Retrieve order data.
|
||||
"""
|
||||
return client.get_request(order_url, parse_json_result=True, fail_on_error=True)[0]
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec()
|
||||
argument_spec.update_argspec(
|
||||
retrieve_orders=dict(
|
||||
@@ -282,16 +291,19 @@ def main():
|
||||
)
|
||||
if created:
|
||||
raise AssertionError("Unwanted account creation")
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": False,
|
||||
"exists": client.account_uri is not None,
|
||||
"account_uri": client.account_uri,
|
||||
"exists": False,
|
||||
"account_uri": None,
|
||||
}
|
||||
if client.account_uri is not None:
|
||||
if client.account_uri is not None and account_data:
|
||||
result["account_uri"] = client.account_uri
|
||||
result["exists"] = True
|
||||
# Make sure promised data is there
|
||||
if "contact" not in account_data:
|
||||
account_data["contact"] = []
|
||||
account_data["public_account_key"] = client.account_key_data["jwk"]
|
||||
if client.account_key_data:
|
||||
account_data["public_account_key"] = client.account_key_data["jwk"]
|
||||
result["account"] = account_data
|
||||
# Retrieve orders list
|
||||
if (
|
||||
|
||||
@@ -94,6 +94,8 @@ renewal_info:
|
||||
sample: '2024-04-29T01:17:10.236921+00:00'
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
ACMEClient,
|
||||
create_backend,
|
||||
@@ -104,15 +106,15 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_account=False)
|
||||
argument_spec.update_argspec(
|
||||
certificate_path=dict(type="path"),
|
||||
certificate_content=dict(type="str"),
|
||||
)
|
||||
argument_spec.update(
|
||||
required_one_of=(["certificate_path", "certificate_content"],),
|
||||
mutually_exclusive=(["certificate_path", "certificate_content"],),
|
||||
required_one_of=[("certificate_path", "certificate_content")],
|
||||
mutually_exclusive=[("certificate_path", "certificate_content")],
|
||||
)
|
||||
module = argument_spec.create_ansible_module(supports_check_mode=True)
|
||||
backend = create_backend(module, True)
|
||||
|
||||
@@ -562,6 +562,7 @@ all_chains:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
@@ -592,6 +593,17 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.backends import (
|
||||
CertificateInformation,
|
||||
CryptoBackend,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
|
||||
Authorization,
|
||||
)
|
||||
|
||||
|
||||
NO_CHALLENGE = "no challenge"
|
||||
|
||||
|
||||
@@ -602,7 +614,7 @@ class ACMECertificateClient:
|
||||
certificates.
|
||||
"""
|
||||
|
||||
def __init__(self, module, backend):
|
||||
def __init__(self, module: AnsibleModule, backend: CryptoBackend):
|
||||
self.module = module
|
||||
self.version = module.params["acme_version"]
|
||||
self.challenge = module.params["challenge"]
|
||||
@@ -618,9 +630,9 @@ class ACMECertificateClient:
|
||||
self.account = ACMEAccount(self.client)
|
||||
self.directory = self.client.directory
|
||||
self.data = module.params["data"]
|
||||
self.authorizations = None
|
||||
self.authorizations: dict[str, Authorization] | None = None
|
||||
self.cert_days = -1
|
||||
self.order = None
|
||||
self.order: Order | None = None
|
||||
self.order_uri = self.data.get("order_uri") if self.data else None
|
||||
self.all_chains = None
|
||||
self.select_chain_matcher = []
|
||||
@@ -662,7 +674,6 @@ class ACMECertificateClient:
|
||||
contact.append("mailto:" + module.params["account_email"])
|
||||
created, account_data = self.account.setup_account(
|
||||
contact,
|
||||
agreement=module.params.get("agreement"),
|
||||
terms_agreed=module.params.get("terms_agreed"),
|
||||
allow_creation=modify_account,
|
||||
)
|
||||
@@ -681,7 +692,7 @@ class ACMECertificateClient:
|
||||
csr_filename=self.csr, csr_content=self.csr_content
|
||||
)
|
||||
|
||||
def is_first_step(self):
|
||||
def is_first_step(self) -> bool:
|
||||
"""
|
||||
Return True if this is the first execution of this module, i.e. if a
|
||||
sufficient data object from a first run has not been provided.
|
||||
@@ -692,7 +703,7 @@ class ACMECertificateClient:
|
||||
# stored in self.order_uri by the constructor).
|
||||
return self.order_uri is None
|
||||
|
||||
def _get_cert_info_or_none(self):
|
||||
def _get_cert_info_or_none(self) -> CertificateInformation | None:
|
||||
if self.module.params.get("dest"):
|
||||
filename = self.module.params["dest"]
|
||||
else:
|
||||
@@ -701,7 +712,7 @@ class ACMECertificateClient:
|
||||
return None
|
||||
return self.client.backend.get_cert_information(cert_filename=filename)
|
||||
|
||||
def start_challenges(self):
|
||||
def start_challenges(self) -> None:
|
||||
"""
|
||||
Create new authorizations for all identifiers of the CSR,
|
||||
respectively start a new order for ACME v2.
|
||||
@@ -733,13 +744,16 @@ class ACMECertificateClient:
|
||||
self.authorizations.update(self.order.authorizations)
|
||||
self.changed = True
|
||||
|
||||
def get_challenges_data(self, first_step):
|
||||
def get_challenges_data(
|
||||
self, first_step: bool
|
||||
) -> tuple[dict[str, t.Any], dict[str, list[str]]]:
|
||||
"""
|
||||
Get challenge details for the chosen challenge type.
|
||||
Return a tuple of generic challenge details, and specialized DNS challenge details.
|
||||
"""
|
||||
# Get general challenge data
|
||||
data = {}
|
||||
assert self.authorizations is not None
|
||||
data: dict[str, t.Any] = {}
|
||||
data_dns: dict[str, list[str]] = {}
|
||||
for type_identifier, authz in self.authorizations.items():
|
||||
identifier_type, identifier = split_identifier(type_identifier)
|
||||
# Skip valid authentications: their challenges are already valid
|
||||
@@ -747,7 +761,9 @@ class ACMECertificateClient:
|
||||
if authz.status == "valid":
|
||||
continue
|
||||
# We drop the type from the key to preserve backwards compatibility
|
||||
data[authz.identifier] = authz.get_challenge_data(self.client)
|
||||
challenges = authz.get_challenge_data(self.client)
|
||||
assert authz.identifier is not None
|
||||
data[authz.identifier] = challenges
|
||||
if (
|
||||
first_step
|
||||
and self.challenge is not None
|
||||
@@ -756,10 +772,7 @@ class ACMECertificateClient:
|
||||
raise ModuleFailException(
|
||||
f"Found no challenge of type '{self.challenge}' for identifier {type_identifier}!"
|
||||
)
|
||||
# Get DNS challenge data
|
||||
data_dns = {}
|
||||
if self.challenge == "dns-01":
|
||||
for identifier, challenges in data.items():
|
||||
if self.challenge == "dns-01":
|
||||
if self.challenge in challenges:
|
||||
values = data_dns.get(challenges[self.challenge]["record"])
|
||||
if values is None:
|
||||
@@ -768,7 +781,7 @@ class ACMECertificateClient:
|
||||
values.append(challenges[self.challenge]["resource_value"])
|
||||
return data, data_dns
|
||||
|
||||
def finish_challenges(self):
|
||||
def finish_challenges(self) -> None:
|
||||
"""
|
||||
Verify challenges for all identifiers of the CSR.
|
||||
"""
|
||||
@@ -777,6 +790,7 @@ class ACMECertificateClient:
|
||||
# Step 1: obtain challenge information
|
||||
# For ACME v2, we obtain the order object by fetching the
|
||||
# order URI, and extract the information from there.
|
||||
assert self.order_uri is not None
|
||||
self.order = Order.from_url(self.client, self.order_uri)
|
||||
self.order.load_authorizations(self.client)
|
||||
self.authorizations.update(self.order.authorizations)
|
||||
@@ -799,7 +813,9 @@ class ACMECertificateClient:
|
||||
# Step 3: wait for authzs to validate
|
||||
wait_for_validation(authzs_to_wait_for, self.client)
|
||||
|
||||
def download_alternate_chains(self, cert):
|
||||
def download_alternate_chains(
|
||||
self, cert: CertificateChain
|
||||
) -> list[CertificateChain]:
|
||||
alternate_chains = []
|
||||
for alternate in cert.alternates:
|
||||
try:
|
||||
@@ -812,7 +828,9 @@ class ACMECertificateClient:
|
||||
alternate_chains.append(alt_cert)
|
||||
return alternate_chains
|
||||
|
||||
def find_matching_chain(self, chains):
|
||||
def find_matching_chain(
|
||||
self, chains: t.Iterable[CertificateChain]
|
||||
) -> CertificateChain | None:
|
||||
for criterium_idx, matcher in enumerate(self.select_chain_matcher):
|
||||
for chain in chains:
|
||||
if matcher.match(chain):
|
||||
@@ -822,12 +840,13 @@ class ACMECertificateClient:
|
||||
return chain
|
||||
return None
|
||||
|
||||
def get_certificate(self):
|
||||
def get_certificate(self) -> None:
|
||||
"""
|
||||
Request a new certificate and write it to the destination file.
|
||||
First verifies whether all authorizations are valid; if not, aborts
|
||||
with an error.
|
||||
"""
|
||||
assert self.authorizations is not None
|
||||
for identifier_type, identifier in self.identifiers:
|
||||
authz = self.authorizations.get(
|
||||
normalize_combined_identifier(
|
||||
@@ -844,7 +863,9 @@ class ACMECertificateClient:
|
||||
module=self.module,
|
||||
)
|
||||
|
||||
assert self.order is not None
|
||||
self.order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
|
||||
assert self.order.certificate_uri is not None
|
||||
cert = CertificateChain.download(self.client, self.order.certificate_uri)
|
||||
if self.module.params["retrieve_all_alternates"] or self.select_chain_matcher:
|
||||
# Retrieve alternate chains
|
||||
@@ -887,12 +908,13 @@ class ACMECertificateClient:
|
||||
):
|
||||
self.changed = True
|
||||
|
||||
def deactivate_authzs(self):
|
||||
def deactivate_authzs(self) -> None:
|
||||
"""
|
||||
Deactivates all valid authz's. Does not raise exceptions.
|
||||
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
|
||||
https://tools.ietf.org/html/rfc8555#section-7.5.2
|
||||
"""
|
||||
assert self.authorizations is not None
|
||||
for authz in self.authorizations.values():
|
||||
try:
|
||||
authz.deactivate(self.client)
|
||||
@@ -905,7 +927,7 @@ class ACMECertificateClient:
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=True)
|
||||
argument_spec.argument_spec["csr"]["aliases"] = ["src"]
|
||||
argument_spec.update_argspec(
|
||||
@@ -981,7 +1003,7 @@ def main():
|
||||
else:
|
||||
client = ACMECertificateClient(module, backend)
|
||||
client.cert_days = cert_days
|
||||
other = dict()
|
||||
other: dict[str, t.Any] = {}
|
||||
is_first_step = client.is_first_step()
|
||||
if is_first_step:
|
||||
# First run: start challenges / start new order
|
||||
@@ -998,6 +1020,7 @@ def main():
|
||||
client.deactivate_authzs()
|
||||
data, data_dns = client.get_challenges_data(first_step=is_first_step)
|
||||
auths = dict()
|
||||
assert client.authorizations is not None
|
||||
for k, v in client.authorizations.items():
|
||||
# Remove "type:" from key
|
||||
auths[v.identifier] = v.to_json()
|
||||
|
||||
@@ -51,6 +51,8 @@ EXAMPLES = r"""
|
||||
|
||||
RETURN = """#"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
)
|
||||
@@ -65,7 +67,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.orders import Order
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec()
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
|
||||
@@ -371,6 +371,8 @@ account_uri:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -383,7 +385,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=True)
|
||||
argument_spec.update_argspec(
|
||||
deactivate_authzs=dict(type="bool", default=True),
|
||||
|
||||
@@ -317,6 +317,8 @@ selected_chain:
|
||||
returned: always
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -329,7 +331,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.certificates import (
|
||||
CertificateChain,
|
||||
)
|
||||
|
||||
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=True)
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
@@ -375,6 +383,7 @@ def main():
|
||||
or module.params["retrieve_all_alternates"]
|
||||
)
|
||||
changed = False
|
||||
alternate_chains: list[CertificateChain] | None
|
||||
if order.status == "valid":
|
||||
# Step 2 and 3: download certificate(s) and chain(s)
|
||||
cert, alternate_chains = client.download_certificate(
|
||||
|
||||
@@ -357,6 +357,8 @@ authorizations_by_status:
|
||||
returned: always
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -369,7 +371,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=False)
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
@@ -381,8 +383,8 @@ def main():
|
||||
try:
|
||||
client = ACMECertificateClient(module, backend)
|
||||
order = client.load_order()
|
||||
authorizations_by_identifier = dict()
|
||||
authorizations_by_status = {
|
||||
authorizations_by_identifier: dict[str, dict[str, t.Any]] = {}
|
||||
authorizations_by_status: dict[str, list[str]] = {
|
||||
"pending": [],
|
||||
"invalid": [],
|
||||
"valid": [],
|
||||
@@ -392,7 +394,8 @@ def main():
|
||||
}
|
||||
for identifier, authz in order.authorizations.items():
|
||||
authorizations_by_identifier[identifier] = authz.to_json()
|
||||
authorizations_by_status[authz.status].append(identifier)
|
||||
if authz.status is not None:
|
||||
authorizations_by_status[authz.status].append(identifier)
|
||||
module.exit_json(
|
||||
changed=False,
|
||||
account_uri=client.client.account_uri,
|
||||
|
||||
@@ -229,6 +229,8 @@ validating_challenges:
|
||||
returned: always
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
create_backend,
|
||||
create_default_argspec,
|
||||
@@ -241,7 +243,13 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.challenges import (
|
||||
Authorization,
|
||||
)
|
||||
|
||||
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_certificate=False)
|
||||
argument_spec.update_argspec(
|
||||
order_uri=dict(type="str", required=True),
|
||||
@@ -271,10 +279,12 @@ def main():
|
||||
|
||||
missing_challenge_authzs = [k for k, v in challenges.items() if v is None]
|
||||
if missing_challenge_authzs:
|
||||
missing_challenge_authzs = ", ".join(sorted(missing_challenge_authzs))
|
||||
missing_challenge_authzs_str = ", ".join(
|
||||
sorted(missing_challenge_authzs)
|
||||
)
|
||||
raise ModuleFailException(
|
||||
"The challenge parameter must be supplied if there are pending authorizations."
|
||||
f" The following authorizations are pending: {missing_challenge_authzs}"
|
||||
f" The following authorizations are pending: {missing_challenge_authzs_str}"
|
||||
)
|
||||
|
||||
bad_challenge_authzs = [
|
||||
@@ -293,11 +303,13 @@ def main():
|
||||
f"The following authorizations do not support the selected challenges: {authz_challenges_pairs}"
|
||||
)
|
||||
|
||||
def is_pending(authz: Authorization) -> bool:
|
||||
challenge_name = challenges[authz.combined_identifier]
|
||||
challenge_obj = authz.find_challenge(challenge_name)
|
||||
return challenge_obj is not None and challenge_obj.status == "pending"
|
||||
|
||||
really_pending_authzs = [
|
||||
authz
|
||||
for authz in pending_authzs
|
||||
if authz.find_challenge(challenges[authz.combined_identifier]).status
|
||||
== "pending"
|
||||
authz for authz in pending_authzs if is_pending(authz)
|
||||
]
|
||||
|
||||
# Step 4: validate pending authorizations
|
||||
@@ -320,7 +332,7 @@ def main():
|
||||
identifier_type=authz.identifier_type,
|
||||
authz_url=authz.url,
|
||||
challenge_type=challenge_type,
|
||||
challenge_url=challenge.url,
|
||||
challenge_url=challenge.url if challenge else None,
|
||||
)
|
||||
for authz, challenge_type, challenge in authzs_with_challenges_to_wait_for
|
||||
],
|
||||
|
||||
@@ -160,6 +160,7 @@ cert_id:
|
||||
|
||||
import os
|
||||
import random
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
ACMEClient,
|
||||
@@ -175,7 +176,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(with_account=False)
|
||||
argument_spec.update_argspec(
|
||||
certificate_path=dict(type="path"),
|
||||
@@ -190,7 +191,7 @@ def main():
|
||||
treat_parsing_error_as_non_existing=dict(type="bool", default=False),
|
||||
)
|
||||
argument_spec.update(
|
||||
mutually_exclusive=(["certificate_path", "certificate_content"],),
|
||||
mutually_exclusive=[("certificate_path", "certificate_content")],
|
||||
)
|
||||
module = argument_spec.create_ansible_module(supports_check_mode=True)
|
||||
backend = create_backend(module, True)
|
||||
@@ -203,7 +204,7 @@ def main():
|
||||
supports_ari=False,
|
||||
)
|
||||
|
||||
def complete(should_renew, **kwargs):
|
||||
def complete(should_renew: bool, **kwargs) -> t.NoReturn:
|
||||
result["should_renew"] = should_renew
|
||||
result.update(kwargs)
|
||||
module.exit_json(**result)
|
||||
|
||||
@@ -110,6 +110,8 @@ EXAMPLES = r"""
|
||||
|
||||
RETURN = """#"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.account import (
|
||||
ACMEAccount,
|
||||
)
|
||||
@@ -129,7 +131,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(require_account_key=False)
|
||||
argument_spec.update_argspec(
|
||||
private_key_src=dict(type="path"),
|
||||
@@ -139,22 +141,22 @@ def main():
|
||||
revoke_reason=dict(type="int"),
|
||||
)
|
||||
argument_spec.update(
|
||||
required_one_of=(
|
||||
[
|
||||
required_one_of=[
|
||||
(
|
||||
"account_key_src",
|
||||
"account_key_content",
|
||||
"private_key_src",
|
||||
"private_key_content",
|
||||
],
|
||||
),
|
||||
mutually_exclusive=(
|
||||
[
|
||||
),
|
||||
],
|
||||
mutually_exclusive=[
|
||||
(
|
||||
"account_key_src",
|
||||
"account_key_content",
|
||||
"private_key_src",
|
||||
"private_key_content",
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
module = argument_spec.create_ansible_module()
|
||||
backend = create_backend(module, False)
|
||||
@@ -164,9 +166,9 @@ def main():
|
||||
account = ACMEAccount(client)
|
||||
# Load certificate
|
||||
certificate = pem_to_der(module.params.get("certificate"))
|
||||
certificate = nopad_b64(certificate)
|
||||
certificate_b64 = nopad_b64(certificate)
|
||||
# Construct payload
|
||||
payload = {"certificate": certificate}
|
||||
payload = {"certificate": certificate_b64}
|
||||
if module.params.get("revoke_reason") is not None:
|
||||
payload["reason"] = module.params.get("revoke_reason")
|
||||
endpoint = client.directory["revokeCert"]
|
||||
|
||||
@@ -149,6 +149,7 @@ regular_certificate:
|
||||
import base64
|
||||
import datetime
|
||||
import ipaddress
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
@@ -173,10 +174,13 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
try:
|
||||
import cryptography
|
||||
import cryptography.hazmat.backends
|
||||
import cryptography.hazmat.primitives.asymmetric.dh
|
||||
import cryptography.hazmat.primitives.asymmetric.ec
|
||||
import cryptography.hazmat.primitives.asymmetric.padding
|
||||
import cryptography.hazmat.primitives.asymmetric.rsa
|
||||
import cryptography.hazmat.primitives.asymmetric.utils
|
||||
import cryptography.hazmat.primitives.asymmetric.x448
|
||||
import cryptography.hazmat.primitives.asymmetric.x25519
|
||||
import cryptography.hazmat.primitives.hashes
|
||||
import cryptography.hazmat.primitives.serialization
|
||||
import cryptography.x509
|
||||
@@ -186,7 +190,7 @@ except ImportError:
|
||||
|
||||
|
||||
# Convert byte string to ASN1 encoded octet string
|
||||
def encode_octet_string(octet_string):
|
||||
def encode_octet_string(octet_string: bytes) -> bytes:
|
||||
if len(octet_string) >= 128:
|
||||
raise ModuleFailException(
|
||||
"Cannot handle octet strings with more than 128 bytes"
|
||||
@@ -194,7 +198,7 @@ def encode_octet_string(octet_string):
|
||||
return bytes([0x4, len(octet_string)]) + octet_string
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
challenge=dict(type="str", required=True, choices=["tls-alpn-01"]),
|
||||
@@ -213,16 +217,16 @@ def main():
|
||||
|
||||
try:
|
||||
# Get parameters
|
||||
challenge = module.params["challenge"]
|
||||
challenge_data = module.params["challenge_data"]
|
||||
challenge: t.Literal["tls-alpn-01"] = module.params["challenge"]
|
||||
challenge_data: dict[str, t.Any] = module.params["challenge_data"]
|
||||
|
||||
# Get hold of private key
|
||||
private_key_content = module.params.get("private_key_content")
|
||||
private_key_passphrase = module.params.get("private_key_passphrase")
|
||||
if private_key_content is None:
|
||||
private_key_content_str: str | None = module.params["private_key_content"]
|
||||
private_key_passphrase: str | None = module.params["private_key_passphrase"]
|
||||
if private_key_content_str is None:
|
||||
private_key_content = read_file(module.params["private_key_src"])
|
||||
else:
|
||||
private_key_content = to_bytes(private_key_content)
|
||||
private_key_content = to_bytes(private_key_content_str)
|
||||
try:
|
||||
private_key = (
|
||||
cryptography.hazmat.primitives.serialization.load_pem_private_key(
|
||||
@@ -236,6 +240,17 @@ def main():
|
||||
)
|
||||
except Exception as e:
|
||||
raise ModuleFailException(f"Error while loading private key: {e}")
|
||||
if isinstance(
|
||||
private_key,
|
||||
(
|
||||
cryptography.hazmat.primitives.asymmetric.dh.DHPrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
|
||||
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
|
||||
),
|
||||
):
|
||||
raise ModuleFailException(
|
||||
f"Cannot use private key type {type(private_key)}"
|
||||
)
|
||||
|
||||
# Some common attributes
|
||||
domain = to_text(challenge_data["resource"])
|
||||
@@ -246,6 +261,7 @@ def main():
|
||||
now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
|
||||
not_valid_before = now
|
||||
not_valid_after = now + datetime.timedelta(days=10)
|
||||
san: cryptography.x509.GeneralName
|
||||
if identifier_type == "dns":
|
||||
san = cryptography.x509.DNSName(identifier)
|
||||
elif identifier_type == "ip":
|
||||
|
||||
@@ -223,6 +223,8 @@ output_json:
|
||||
- '...'
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.acme.acme import (
|
||||
ACMEClient,
|
||||
@@ -235,7 +237,7 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = create_default_argspec(require_account_key=False)
|
||||
argument_spec.update_argspec(
|
||||
url=dict(type="str"),
|
||||
@@ -246,17 +248,17 @@ def main():
|
||||
fail_on_acme_error=dict(type="bool", default=True),
|
||||
)
|
||||
argument_spec.update(
|
||||
required_if=(
|
||||
["method", "get", ["url"]],
|
||||
["method", "post", ["url", "content"]],
|
||||
["method", "get", ["account_key_src", "account_key_content"], True],
|
||||
["method", "post", ["account_key_src", "account_key_content"], True],
|
||||
),
|
||||
required_if=[
|
||||
("method", "get", ["url"]),
|
||||
("method", "post", ["url", "content"]),
|
||||
("method", "get", ["account_key_src", "account_key_content"], True),
|
||||
("method", "post", ["account_key_src", "account_key_content"], True),
|
||||
],
|
||||
)
|
||||
module = argument_spec.create_ansible_module()
|
||||
backend = create_backend(module, False)
|
||||
|
||||
result = dict()
|
||||
result: dict[str, t.Any] = {}
|
||||
changed = False
|
||||
try:
|
||||
# Get hold of ACMEClient and ACMEAccount objects (includes directory)
|
||||
|
||||
@@ -121,6 +121,7 @@ complete_chain:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -153,14 +154,18 @@ class Certificate:
|
||||
Stores PEM with parsed certificate.
|
||||
"""
|
||||
|
||||
def __init__(self, pem, cert):
|
||||
def __init__(self, pem: str, cert: cryptography.x509.Certificate) -> None:
|
||||
if not (pem.endswith("\n") or pem.endswith("\r")):
|
||||
pem = pem + "\n"
|
||||
self.pem = pem
|
||||
self.cert = cert
|
||||
|
||||
|
||||
def is_parent(module, cert, potential_parent):
|
||||
def is_parent(
|
||||
module: AnsibleModule,
|
||||
cert: Certificate,
|
||||
potential_parent: Certificate,
|
||||
) -> bool:
|
||||
"""
|
||||
Tests whether the given certificate has been issued by the potential parent certificate.
|
||||
"""
|
||||
@@ -173,6 +178,10 @@ def is_parent(module, cert, potential_parent):
|
||||
if isinstance(
|
||||
public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey
|
||||
):
|
||||
if cert.cert.signature_hash_algorithm is None:
|
||||
raise AssertionError(
|
||||
"signature_hash_algorithm should be present for RSA certificates"
|
||||
)
|
||||
public_key.verify(
|
||||
cert.cert.signature,
|
||||
cert.cert.tbs_certificate_bytes,
|
||||
@@ -183,6 +192,10 @@ def is_parent(module, cert, potential_parent):
|
||||
public_key,
|
||||
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
|
||||
):
|
||||
if cert.cert.signature_hash_algorithm is None:
|
||||
raise AssertionError(
|
||||
"signature_hash_algorithm should be present for EC certificates"
|
||||
)
|
||||
public_key.verify(
|
||||
cert.cert.signature,
|
||||
cert.cert.tbs_certificate_bytes,
|
||||
@@ -213,11 +226,16 @@ def is_parent(module, cert, potential_parent):
|
||||
module.fail_json(msg=f"Unknown error on signature validation: {e}")
|
||||
|
||||
|
||||
def parse_PEM_list(module, text, source, fail_on_error=True):
|
||||
def parse_PEM_list(
|
||||
module: AnsibleModule,
|
||||
text: str,
|
||||
source: str | os.PathLike,
|
||||
fail_on_error: bool = True,
|
||||
) -> list[Certificate]:
|
||||
"""
|
||||
Parse concatenated PEM certificates. Return list of ``Certificate`` objects.
|
||||
"""
|
||||
result = []
|
||||
result: list[Certificate] = []
|
||||
for cert_pem in split_pem_list(text):
|
||||
# Try to load PEM certificate
|
||||
try:
|
||||
@@ -232,7 +250,9 @@ def parse_PEM_list(module, text, source, fail_on_error=True):
|
||||
return result
|
||||
|
||||
|
||||
def load_PEM_list(module, path, fail_on_error=True):
|
||||
def load_PEM_list(
|
||||
module: AnsibleModule, path: str | os.PathLike, fail_on_error: bool = True
|
||||
) -> list[Certificate]:
|
||||
"""
|
||||
Load concatenated PEM certificates from file. Return list of ``Certificate`` objects.
|
||||
"""
|
||||
@@ -258,13 +278,15 @@ class CertificateSet:
|
||||
Stores a set of certificates. Allows to search for parent (issuer of a certificate).
|
||||
"""
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.module = module
|
||||
self.certificates = set()
|
||||
self.certificates_by_issuer = dict()
|
||||
self.certificate_by_cert = dict()
|
||||
self.certificates: set[Certificate] = set()
|
||||
self.certificates_by_issuer: dict[cryptography.x509.Name, list[Certificate]] = (
|
||||
{}
|
||||
)
|
||||
self.certificate_by_cert: dict[cryptography.x509.Certificate, Certificate] = {}
|
||||
|
||||
def _load_file(self, path):
|
||||
def _load_file(self, path: str | os.PathLike) -> None:
|
||||
certs = load_PEM_list(self.module, path, fail_on_error=False)
|
||||
for cert in certs:
|
||||
self.certificates.add(cert)
|
||||
@@ -273,7 +295,7 @@ class CertificateSet:
|
||||
self.certificates_by_issuer[cert.cert.subject].append(cert)
|
||||
self.certificate_by_cert[cert.cert] = cert
|
||||
|
||||
def load(self, path):
|
||||
def load(self, path: str | os.PathLike) -> None:
|
||||
"""
|
||||
Load lists of PEM certificates from a file or a directory.
|
||||
"""
|
||||
@@ -285,7 +307,7 @@ class CertificateSet:
|
||||
else:
|
||||
self._load_file(b_path)
|
||||
|
||||
def find_parent(self, cert):
|
||||
def find_parent(self, cert: Certificate) -> Certificate | None:
|
||||
"""
|
||||
Search for the parent (issuer) of a certificate. Return ``None`` if none was found.
|
||||
"""
|
||||
@@ -296,14 +318,18 @@ class CertificateSet:
|
||||
return None
|
||||
|
||||
|
||||
def format_cert(cert):
|
||||
def format_cert(cert: Certificate) -> str:
|
||||
"""
|
||||
Return human readable representation of certificate for error messages.
|
||||
"""
|
||||
return str(cert.cert)
|
||||
|
||||
|
||||
def check_cycle(module, occured_certificates, next):
|
||||
def check_cycle(
|
||||
module: AnsibleModule,
|
||||
occured_certificates: set[cryptography.x509.Certificate],
|
||||
next: Certificate,
|
||||
) -> None:
|
||||
"""
|
||||
Make sure that next is not in occured_certificates so far, and add it.
|
||||
"""
|
||||
@@ -313,7 +339,7 @@ def check_cycle(module, occured_certificates, next):
|
||||
occured_certificates.add(next_cert)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
input_chain=dict(type="str", required=True),
|
||||
@@ -354,10 +380,10 @@ def main():
|
||||
roots.load(path)
|
||||
|
||||
# Try to complete chain
|
||||
current = chain[-1]
|
||||
current: Certificate | None = chain[-1]
|
||||
completed = []
|
||||
occured_certificates = set([cert.cert for cert in chain])
|
||||
if current.cert in roots.certificate_by_cert:
|
||||
if current and current.cert in roots.certificate_by_cert:
|
||||
# Do not try to complete the chain when it is already ending with a root certificate
|
||||
current = None
|
||||
while current:
|
||||
|
||||
@@ -152,10 +152,13 @@ openssl:
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
CRYPTOGRAPHY_VERSION: str | None
|
||||
CRYPTOGRAPHY_IMP_ERR: str | None
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography.exceptions import UnsupportedAlgorithm
|
||||
@@ -165,10 +168,10 @@ try:
|
||||
# only got added in 0.2, so let's guard the import
|
||||
from cryptography.exceptions import InternalError as CryptographyInternalError
|
||||
except ImportError:
|
||||
CryptographyInternalError = Exception
|
||||
CryptographyInternalError = Exception # type: ignore
|
||||
except ImportError:
|
||||
UnsupportedAlgorithm = Exception
|
||||
CryptographyInternalError = Exception
|
||||
UnsupportedAlgorithm = Exception # type: ignore
|
||||
CryptographyInternalError = Exception # type: ignore
|
||||
HAS_CRYPTOGRAPHY = False
|
||||
CRYPTOGRAPHY_VERSION = None
|
||||
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
|
||||
@@ -201,8 +204,8 @@ CURVES = (
|
||||
)
|
||||
|
||||
|
||||
def add_crypto_information(module):
|
||||
result = {}
|
||||
def add_crypto_information(module: AnsibleModule) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
result["python_cryptography_installed"] = HAS_CRYPTOGRAPHY
|
||||
if not HAS_CRYPTOGRAPHY:
|
||||
result["python_cryptography_import_error"] = CRYPTOGRAPHY_IMP_ERR
|
||||
@@ -397,9 +400,9 @@ def add_crypto_information(module):
|
||||
return result
|
||||
|
||||
|
||||
def add_openssl_information(module):
|
||||
def add_openssl_information(module: AnsibleModule) -> dict[str, t.Any]:
|
||||
openssl_binary = module.get_bin_path("openssl")
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"openssl_present": openssl_binary is not None,
|
||||
}
|
||||
if openssl_binary is None:
|
||||
@@ -426,9 +429,9 @@ INFO_FUNCTIONS = (
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(argument_spec={}, supports_check_mode=True)
|
||||
result = {}
|
||||
result: dict[str, t.Any] = {}
|
||||
for fn in INFO_FUNCTIONS:
|
||||
result.update(fn(module))
|
||||
module.exit_json(**result)
|
||||
|
||||
@@ -550,6 +550,7 @@ import datetime
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes
|
||||
@@ -572,7 +573,7 @@ from ansible_collections.community.crypto.plugins.module_utils.io import write_f
|
||||
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
|
||||
|
||||
|
||||
def validate_cert_expiry(cert_expiry):
|
||||
def validate_cert_expiry(cert_expiry: str) -> bool:
|
||||
search_string_partial = re.compile(
|
||||
r"^([0-9]+)-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\Z"
|
||||
)
|
||||
@@ -587,7 +588,7 @@ def validate_cert_expiry(cert_expiry):
|
||||
return False
|
||||
|
||||
|
||||
def calculate_cert_days(expires_after):
|
||||
def calculate_cert_days(expires_after: str | None) -> int:
|
||||
cert_days = 0
|
||||
if expires_after:
|
||||
expires_after_datetime = datetime.datetime.strptime(
|
||||
@@ -600,7 +601,9 @@ def calculate_cert_days(expires_after):
|
||||
# Populate the value of body[dict_param_name] with the JSON equivalent of
|
||||
# module parameter of param_name if that parameter is present, otherwise leave field
|
||||
# out of resulting dict
|
||||
def convert_module_param_to_json_bool(module, dict_param_name, param_name):
|
||||
def convert_module_param_to_json_bool(
|
||||
module: AnsibleModule, dict_param_name: str, param_name: str
|
||||
) -> dict[str, str]:
|
||||
body = {}
|
||||
if module.params[param_name] is not None:
|
||||
if module.params[param_name]:
|
||||
@@ -886,7 +889,7 @@ class EcsCertificate:
|
||||
return result
|
||||
|
||||
|
||||
def custom_fields_spec():
|
||||
def custom_fields_spec() -> dict[str, dict[str, str]]:
|
||||
return dict(
|
||||
text1=dict(type="str"),
|
||||
text2=dict(type="str"),
|
||||
@@ -926,7 +929,7 @@ def custom_fields_spec():
|
||||
)
|
||||
|
||||
|
||||
def ecs_certificate_argument_spec():
|
||||
def ecs_certificate_argument_spec() -> dict[str, dict[str, t.Any]]:
|
||||
return dict(
|
||||
backup=dict(type="bool", default=False),
|
||||
force=dict(type="bool", default=False),
|
||||
@@ -979,7 +982,7 @@ def ecs_certificate_argument_spec():
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
ecs_argument_spec = ecs_client_argument_spec()
|
||||
ecs_argument_spec.update(ecs_certificate_argument_spec())
|
||||
module = AnsibleModule(
|
||||
|
||||
@@ -218,6 +218,7 @@ ev_days_remaining:
|
||||
|
||||
import datetime
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
|
||||
@@ -228,7 +229,7 @@ from ansible_collections.community.crypto.plugins.module_utils.ecs.api import (
|
||||
)
|
||||
|
||||
|
||||
def calculate_days_remaining(expiry_date):
|
||||
def calculate_days_remaining(expiry_date: str | None) -> int | None:
|
||||
days_remaining = None
|
||||
if expiry_date:
|
||||
expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ")
|
||||
@@ -403,8 +404,8 @@ class EcsDomain:
|
||||
msg=f"Failed to request domain validation from Entrust (ECS) {e.message}"
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
result = {
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": self.changed,
|
||||
"client_id": self.client_id,
|
||||
"domain_status": self.domain_status,
|
||||
@@ -436,7 +437,7 @@ class EcsDomain:
|
||||
return result
|
||||
|
||||
|
||||
def ecs_domain_argument_spec():
|
||||
def ecs_domain_argument_spec() -> dict[str, dict[str, t.Any]]:
|
||||
return dict(
|
||||
client_id=dict(type="int", default=1),
|
||||
domain_name=dict(type="str", required=True),
|
||||
@@ -447,7 +448,7 @@ def ecs_domain_argument_spec():
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
ecs_argument_spec = ecs_client_argument_spec()
|
||||
ecs_argument_spec.update(ecs_domain_argument_spec())
|
||||
module = AnsibleModule(
|
||||
|
||||
@@ -268,6 +268,7 @@ import atexit
|
||||
import base64
|
||||
import ssl
|
||||
import sys
|
||||
import typing as t
|
||||
from os.path import isfile
|
||||
from socket import create_connection, setdefaulttimeout, socket
|
||||
from ssl import (
|
||||
@@ -305,7 +306,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def send_starttls_packet(sock, server_type):
|
||||
def send_starttls_packet(sock: socket, server_type: t.Literal["mysql"]) -> None:
|
||||
if server_type == "mysql":
|
||||
ssl_request_packet = (
|
||||
b"\x20\x00\x00\x01\x85\xae\x7f\x00"
|
||||
@@ -321,7 +322,7 @@ def send_starttls_packet(sock, server_type):
|
||||
sock.send(ssl_request_packet)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
ca_cert=dict(type="path"),
|
||||
@@ -342,18 +343,18 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
ca_cert = module.params.get("ca_cert")
|
||||
host = module.params.get("host")
|
||||
port = module.params.get("port")
|
||||
proxy_host = module.params.get("proxy_host")
|
||||
proxy_port = module.params.get("proxy_port")
|
||||
timeout = module.params.get("timeout")
|
||||
server_name = module.params.get("server_name")
|
||||
start_tls_server_type = module.params.get("starttls")
|
||||
ciphers = module.params.get("ciphers")
|
||||
asn1_base64 = module.params["asn1_base64"]
|
||||
tls_ctx_options = module.params["tls_ctx_options"]
|
||||
get_certificate_chain = module.params["get_certificate_chain"]
|
||||
ca_cert: str | None = module.params.get("ca_cert")
|
||||
host: str = module.params.get("host")
|
||||
port: int = module.params.get("port")
|
||||
proxy_host: str | None = module.params.get("proxy_host")
|
||||
proxy_port: int | None = module.params.get("proxy_port")
|
||||
timeout: int = module.params.get("timeout")
|
||||
server_name: str | None = module.params.get("server_name")
|
||||
start_tls_server_type: t.Literal["mysql"] | None = module.params.get("starttls")
|
||||
ciphers: list[str] | None = module.params.get("ciphers")
|
||||
asn1_base64: bool = module.params["asn1_base64"]
|
||||
tls_ctx_options: list[str | bytes | int] | None = module.params["tls_ctx_options"]
|
||||
get_certificate_chain: bool = module.params["get_certificate_chain"]
|
||||
|
||||
if get_certificate_chain and sys.version_info < (3, 10):
|
||||
module.fail_json(
|
||||
@@ -365,9 +366,9 @@ def main():
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
|
||||
result = dict(
|
||||
changed=False,
|
||||
)
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": False,
|
||||
}
|
||||
|
||||
if timeout:
|
||||
setdefaulttimeout(timeout)
|
||||
@@ -409,7 +410,7 @@ def main():
|
||||
|
||||
if tls_ctx_options is not None:
|
||||
# Clear default ctx options
|
||||
ctx.options = 0
|
||||
ctx.options = 0 # type: ignore
|
||||
|
||||
# For each item in the tls_ctx_options list
|
||||
for tls_ctx_option in tls_ctx_options:
|
||||
@@ -450,8 +451,10 @@ def main():
|
||||
)
|
||||
|
||||
tls_sock = ctx.wrap_socket(sock, server_hostname=server_name or host)
|
||||
cert = tls_sock.getpeercert(True)
|
||||
cert = DER_cert_to_PEM_cert(cert)
|
||||
cert_der = tls_sock.getpeercert(True)
|
||||
if cert_der is None:
|
||||
raise Exception("Unexpected error: no peer certificate has been returned")
|
||||
cert: str = DER_cert_to_PEM_cert(cert_der)
|
||||
|
||||
if get_certificate_chain:
|
||||
if sys.version_info < (3, 13):
|
||||
@@ -474,7 +477,7 @@ def main():
|
||||
# Python 3.13 do not return lists of byte strings, but lists of _ssl.Certificate objects. This is going to
|
||||
# be fixed by https://github.com/python/cpython/pull/118669. For now we convert the certificates ourselves
|
||||
# if they are not byte strings to work around this.
|
||||
def _convert_chain(chain):
|
||||
def _convert_chain(chain: list[bytes]) -> list[bytes]:
|
||||
return [
|
||||
(
|
||||
c
|
||||
@@ -514,13 +517,13 @@ def main():
|
||||
result["extensions"] = []
|
||||
for dotted_number, entry in cryptography_get_extensions_from_cert(x509).items():
|
||||
oid = cryptography.x509.oid.ObjectIdentifier(dotted_number)
|
||||
ext = {
|
||||
ext: dict[str, t.Any] = {
|
||||
"critical": entry["critical"],
|
||||
"asn1_data": entry["value"],
|
||||
"name": cryptography_oid_to_name(oid, short=True),
|
||||
}
|
||||
if not asn1_base64:
|
||||
ext["asn1_data"] = base64.b64decode(ext["asn1_data"])
|
||||
ext["asn1_data"] = base64.b64decode(entry["value"]) # type: ignore
|
||||
result["extensions"].append(ext)
|
||||
|
||||
result["issuer"] = {}
|
||||
|
||||
@@ -420,16 +420,13 @@ name:
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
import typing as t
|
||||
from base64 import b64decode
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
|
||||
|
||||
RETURN_CODE = 0
|
||||
STDOUT = 1
|
||||
STDERR = 2
|
||||
|
||||
# used to get <luks-name> out of lsblk output in format 'crypt <luks-name>'
|
||||
# regex takes care of any possible blank characters
|
||||
LUKS_NAME_REGEX = re.compile(r"^crypt\s+([^\s]*)\s*$")
|
||||
@@ -456,7 +453,7 @@ LUKS2_HEADER_OFFSETS = [
|
||||
LUKS2_HEADER2 = b"SKUL\xba\xbe"
|
||||
|
||||
|
||||
def wipe_luks_headers(device):
|
||||
def wipe_luks_headers(device: str) -> None:
|
||||
wipe_offsets = []
|
||||
with open(device, "rb") as f:
|
||||
# f.seek(0)
|
||||
@@ -478,12 +475,12 @@ def wipe_luks_headers(device):
|
||||
|
||||
class Handler:
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self._module = module
|
||||
self._lsblk_bin = self._module.get_bin_path("lsblk", True)
|
||||
self._passphrase_encoding = module.params["passphrase_encoding"]
|
||||
|
||||
def get_passphrase_from_module_params(self, parameter_name):
|
||||
def get_passphrase_from_module_params(self, parameter_name: str) -> bytes | None:
|
||||
passphrase = self._module.params[parameter_name]
|
||||
if passphrase is None:
|
||||
return None
|
||||
@@ -496,88 +493,91 @@ class Handler:
|
||||
f"Error while base64-decoding '{parameter_name}': {exc}"
|
||||
)
|
||||
|
||||
def _run_command(self, command, data=None):
|
||||
def _run_command(
|
||||
self, command: list[str], data: bytes | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
return self._module.run_command(command, data=data, binary_data=True)
|
||||
|
||||
def get_device_by_uuid(self, uuid):
|
||||
def get_device_by_uuid(self, uuid: str | None) -> str | None:
|
||||
"""Returns the device that holds UUID passed by user"""
|
||||
self._blkid_bin = self._module.get_bin_path("blkid", True)
|
||||
uuid = self._module.params["uuid"]
|
||||
if uuid is None:
|
||||
return None
|
||||
result = self._run_command([self._blkid_bin, "--uuid", uuid])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command([self._blkid_bin, "--uuid", uuid])
|
||||
if rc != 0:
|
||||
return None
|
||||
return result[STDOUT].strip()
|
||||
return stdout.strip()
|
||||
|
||||
def get_device_by_label(self, label):
|
||||
def get_device_by_label(self, label: str) -> str | None:
|
||||
"""Returns the device that holds label passed by user"""
|
||||
self._blkid_bin = self._module.get_bin_path("blkid", True)
|
||||
label = self._module.params["label"]
|
||||
if label is None:
|
||||
return None
|
||||
result = self._run_command([self._blkid_bin, "--label", label])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command([self._blkid_bin, "--label", label])
|
||||
if rc != 0:
|
||||
return None
|
||||
return result[STDOUT].strip()
|
||||
return stdout.strip()
|
||||
|
||||
def generate_luks_name(self, device):
|
||||
def generate_luks_name(self, device: str) -> str:
|
||||
"""Generate name for luks based on device UUID ('luks-<UUID>').
|
||||
Raises ValueError when obtaining of UUID fails.
|
||||
"""
|
||||
result = self._run_command([self._lsblk_bin, "-n", device, "-o", "UUID"])
|
||||
rc, stdout, stderr = self._run_command(
|
||||
[self._lsblk_bin, "-n", device, "-o", "UUID"]
|
||||
)
|
||||
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(
|
||||
f"Error while generating LUKS name for {device}: {result[STDERR]}"
|
||||
)
|
||||
dev_uuid = result[STDOUT].strip()
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while generating LUKS name for {device}: {stderr}")
|
||||
dev_uuid = stdout.strip()
|
||||
return f"luks-{dev_uuid}"
|
||||
|
||||
|
||||
class CryptHandler(Handler):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CryptHandler, self).__init__(module)
|
||||
self._cryptsetup_bin = self._module.get_bin_path("cryptsetup", True)
|
||||
|
||||
def get_container_name_by_device(self, device):
|
||||
def get_container_name_by_device(self, device: str) -> str | None:
|
||||
"""obtain LUKS container name based on the device where it is located
|
||||
return None if not found
|
||||
raise ValueError if lsblk command fails
|
||||
"""
|
||||
result = self._run_command([self._lsblk_bin, device, "-nlo", "type,name"])
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(
|
||||
f"Error while obtaining LUKS name for {device}: {result[STDERR]}"
|
||||
)
|
||||
rc, stdout, stderr = self._run_command(
|
||||
[self._lsblk_bin, device, "-nlo", "type,name"]
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while obtaining LUKS name for {device}: {stderr}")
|
||||
|
||||
for line in result[STDOUT].splitlines(False):
|
||||
for line in stdout.splitlines(False):
|
||||
m = LUKS_NAME_REGEX.match(line)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return None
|
||||
|
||||
def get_container_device_by_name(self, name):
|
||||
def get_container_device_by_name(self, name: str) -> str | None:
|
||||
"""obtain device name based on the LUKS container name
|
||||
return None if not found
|
||||
raise ValueError if lsblk command fails
|
||||
"""
|
||||
# apparently each device can have only one LUKS container on it
|
||||
result = self._run_command([self._cryptsetup_bin, "status", name])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command([self._cryptsetup_bin, "status", name])
|
||||
if rc != 0:
|
||||
return None
|
||||
|
||||
m = LUKS_DEVICE_REGEX.search(result[STDOUT])
|
||||
m = LUKS_DEVICE_REGEX.search(stdout)
|
||||
if not m:
|
||||
return None
|
||||
device = m.group(1)
|
||||
return device
|
||||
|
||||
def is_luks(self, device):
|
||||
def is_luks(self, device: str) -> bool:
|
||||
"""check if the LUKS container does exist"""
|
||||
result = self._run_command([self._cryptsetup_bin, "isLuks", device])
|
||||
return result[RETURN_CODE] == 0
|
||||
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "isLuks", device])
|
||||
return rc == 0
|
||||
|
||||
def get_luks_type(self, device):
|
||||
def get_luks_type(self, device: str) -> t.Literal["luks1", "luks2"] | None:
|
||||
"""get the luks type of a device"""
|
||||
if self.is_luks(device):
|
||||
with open(device, "rb") as f:
|
||||
@@ -589,16 +589,18 @@ class CryptHandler(Handler):
|
||||
return "luks1"
|
||||
return None
|
||||
|
||||
def is_luks_slot_set(self, device, keyslot):
|
||||
def is_luks_slot_set(self, device: str, keyslot: int) -> bool:
|
||||
"""check if a keyslot is set"""
|
||||
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command(
|
||||
[self._cryptsetup_bin, "luksDump", device]
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while dumping LUKS header from {device}")
|
||||
result_luks1 = f"Key Slot {keyslot}: ENABLED" in result[STDOUT]
|
||||
result_luks2 = f" {keyslot}: luks2" in result[STDOUT]
|
||||
result_luks1 = f"Key Slot {keyslot}: ENABLED" in stdout
|
||||
result_luks2 = f" {keyslot}: luks2" in stdout
|
||||
return result_luks1 or result_luks2
|
||||
|
||||
def _add_pbkdf_options(self, options, pbkdf):
|
||||
def _add_pbkdf_options(self, options: list[str], pbkdf: dict[str, t.Any]) -> None:
|
||||
if pbkdf["iteration_time"] is not None:
|
||||
options.extend(["--iter-time", str(int(pbkdf["iteration_time"] * 1000))])
|
||||
if pbkdf["iteration_count"] is not None:
|
||||
@@ -612,16 +614,16 @@ class CryptHandler(Handler):
|
||||
|
||||
def run_luks_create(
|
||||
self,
|
||||
device,
|
||||
keyfile,
|
||||
passphrase,
|
||||
keyslot,
|
||||
keysize,
|
||||
cipher,
|
||||
hash_,
|
||||
sector_size,
|
||||
pbkdf,
|
||||
):
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
keyslot: int | None,
|
||||
keysize: int | None,
|
||||
cipher: str | None,
|
||||
hash_: str | None,
|
||||
sector_size: str | None,
|
||||
pbkdf: dict[str, t.Any] | None,
|
||||
) -> None:
|
||||
# create a new luks container; use batch mode to auto confirm
|
||||
luks_type = self._module.params["type"]
|
||||
label = self._module.params["label"]
|
||||
@@ -653,23 +655,23 @@ class CryptHandler(Handler):
|
||||
else:
|
||||
args.append("-")
|
||||
|
||||
result = self._run_command(args, data=passphrase)
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(f"Error while creating LUKS on {device}: {result[STDERR]}")
|
||||
rc, dummy, stderr = self._run_command(args, data=passphrase)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while creating LUKS on {device}: {stderr}")
|
||||
|
||||
def run_luks_open(
|
||||
self,
|
||||
device,
|
||||
keyfile,
|
||||
passphrase,
|
||||
perf_same_cpu_crypt,
|
||||
perf_submit_from_crypt_cpus,
|
||||
perf_no_read_workqueue,
|
||||
perf_no_write_workqueue,
|
||||
persistent,
|
||||
allow_discards,
|
||||
name,
|
||||
):
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
perf_same_cpu_crypt: bool,
|
||||
perf_submit_from_crypt_cpus: bool,
|
||||
perf_no_read_workqueue: bool,
|
||||
perf_no_write_workqueue: bool,
|
||||
persistent: bool,
|
||||
allow_discards: bool,
|
||||
name: str,
|
||||
) -> None:
|
||||
args = [self._cryptsetup_bin]
|
||||
if keyfile:
|
||||
args.extend(["--key-file", keyfile])
|
||||
@@ -689,27 +691,27 @@ class CryptHandler(Handler):
|
||||
args.extend(["--allow-discards"])
|
||||
args.extend(["open", "--type", "luks", device, name])
|
||||
|
||||
result = self._run_command(args, data=passphrase)
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, dummy, stderr = self._run_command(args, data=passphrase)
|
||||
if rc != 0:
|
||||
raise ValueError(
|
||||
f"Error while opening LUKS container on {device}: {result[STDERR]}"
|
||||
f"Error while opening LUKS container on {device}: {stderr}"
|
||||
)
|
||||
|
||||
def run_luks_close(self, name):
|
||||
result = self._run_command([self._cryptsetup_bin, "close", name])
|
||||
if result[RETURN_CODE] != 0:
|
||||
def run_luks_close(self, name: str) -> None:
|
||||
rc, dummy, dummy2 = self._run_command([self._cryptsetup_bin, "close", name])
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while closing LUKS container {name}")
|
||||
|
||||
def run_luks_remove(self, device):
|
||||
def run_luks_remove(self, device: str) -> None:
|
||||
wipefs_bin = self._module.get_bin_path("wipefs", True)
|
||||
|
||||
name = self.get_container_name_by_device(device)
|
||||
if name is not None:
|
||||
self.run_luks_close(name)
|
||||
result = self._run_command([wipefs_bin, "--all", device])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, dummy, stderr = self._run_command([wipefs_bin, "--all", device])
|
||||
if rc != 0:
|
||||
raise ValueError(
|
||||
f"Error while wiping LUKS container signatures for {device}: {result[STDERR]}"
|
||||
f"Error while wiping LUKS container signatures for {device}: {stderr}"
|
||||
)
|
||||
|
||||
# For LUKS2, sometimes both `cryptsetup erase` and `wipefs` do **not**
|
||||
@@ -724,14 +726,14 @@ class CryptHandler(Handler):
|
||||
|
||||
def run_luks_add_key(
|
||||
self,
|
||||
device,
|
||||
keyfile,
|
||||
passphrase,
|
||||
new_keyfile,
|
||||
new_passphrase,
|
||||
new_keyslot,
|
||||
pbkdf,
|
||||
):
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
new_keyfile: str | None,
|
||||
new_passphrase: bytes | None,
|
||||
new_keyslot: int | None,
|
||||
pbkdf: dict[str, t.Any] | None,
|
||||
) -> None:
|
||||
"""Add new key from a keyfile or passphrase to given 'device';
|
||||
authentication done using 'keyfile' or 'passphrase'.
|
||||
Raises ValueError when command fails.
|
||||
@@ -746,36 +748,47 @@ class CryptHandler(Handler):
|
||||
|
||||
if keyfile:
|
||||
args.extend(["--key-file", keyfile])
|
||||
else:
|
||||
elif passphrase is not None:
|
||||
args.extend(["--key-file", "-", "--keyfile-size", str(len(passphrase))])
|
||||
data.append(passphrase)
|
||||
else:
|
||||
raise ValueError("Need passphrase or keyfile")
|
||||
|
||||
if new_keyfile:
|
||||
args.append(new_keyfile)
|
||||
else:
|
||||
elif new_passphrase is not None:
|
||||
args.append("-")
|
||||
data.append(new_passphrase)
|
||||
else:
|
||||
raise ValueError("Need new passphrase or new keyfile")
|
||||
|
||||
result = self._run_command(args, data=b"".join(data) or None)
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, dummy, stderr = self._run_command(args, data=b"".join(data) or None)
|
||||
if rc != 0:
|
||||
raise ValueError(
|
||||
f"Error while adding new LUKS keyslot to {device}: {result[STDERR]}"
|
||||
f"Error while adding new LUKS keyslot to {device}: {stderr}"
|
||||
)
|
||||
|
||||
def run_luks_remove_key(
|
||||
self, device, keyfile, passphrase, keyslot, force_remove_last_key=False
|
||||
):
|
||||
self,
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
keyslot: int | None,
|
||||
force_remove_last_key: bool = False,
|
||||
) -> None:
|
||||
"""Remove key from given device
|
||||
Raises ValueError when command fails
|
||||
"""
|
||||
if not force_remove_last_key:
|
||||
result = self._run_command([self._cryptsetup_bin, "luksDump", device])
|
||||
if result[RETURN_CODE] != 0:
|
||||
rc, stdout, dummy = self._run_command(
|
||||
[self._cryptsetup_bin, "luksDump", device]
|
||||
)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while dumping LUKS header from {device}")
|
||||
keyslot_count = 0
|
||||
keyslot_area = False
|
||||
keyslot_re = re.compile(r"^Key Slot [0-9]+: ENABLED")
|
||||
for line in result[STDOUT].splitlines():
|
||||
for line in stdout.splitlines():
|
||||
if line.startswith("Keyslots:"):
|
||||
keyslot_area = True
|
||||
elif line.startswith(" "):
|
||||
@@ -808,13 +821,17 @@ class CryptHandler(Handler):
|
||||
# Since we supply -q no passphrase is needed
|
||||
args = [self._cryptsetup_bin, "luksKillSlot", device, "-q", str(keyslot)]
|
||||
passphrase = None
|
||||
result = self._run_command(args, data=passphrase)
|
||||
if result[RETURN_CODE] != 0:
|
||||
raise ValueError(
|
||||
f"Error while removing LUKS key from {device}: {result[STDERR]}"
|
||||
)
|
||||
rc, dummy, stderr = self._run_command(args, data=passphrase)
|
||||
if rc != 0:
|
||||
raise ValueError(f"Error while removing LUKS key from {device}: {stderr}")
|
||||
|
||||
def luks_test_key(self, device, keyfile, passphrase, keyslot=None):
|
||||
def luks_test_key(
|
||||
self,
|
||||
device: str,
|
||||
keyfile: str | None,
|
||||
passphrase: bytes | None,
|
||||
keyslot: int | None = None,
|
||||
) -> bool:
|
||||
"""Check whether the keyfile or passphrase works.
|
||||
Raises ValueError when command fails.
|
||||
"""
|
||||
@@ -830,42 +847,37 @@ class CryptHandler(Handler):
|
||||
if keyslot is not None:
|
||||
args.extend(["--key-slot", str(keyslot)])
|
||||
|
||||
result = self._run_command(args, data=data)
|
||||
if result[RETURN_CODE] == 0:
|
||||
rc, stdout, stderr = self._run_command(args, data=data)
|
||||
if rc == 0:
|
||||
return True
|
||||
for output in (STDOUT, STDERR):
|
||||
if "No key available with this passphrase" in result[output]:
|
||||
for output in (stdout, stderr):
|
||||
if "No key available with this passphrase" in output:
|
||||
return False
|
||||
if "No usable keyslot is available." in result[output]:
|
||||
if "No usable keyslot is available." in output:
|
||||
return False
|
||||
|
||||
# This check is necessary due to cryptsetup in version 2.0.3 not printing 'No usable keyslot is available'
|
||||
# when using the --key-slot parameter in combination with --test-passphrase
|
||||
if (
|
||||
result[RETURN_CODE] == 1
|
||||
and keyslot is not None
|
||||
and result[STDOUT] == ""
|
||||
and result[STDERR] == ""
|
||||
):
|
||||
if rc == 1 and keyslot is not None and stdout == "" and stderr == "":
|
||||
return False
|
||||
|
||||
raise ValueError(
|
||||
f"Error while testing whether keyslot exists on {device}: {result[STDERR]}"
|
||||
f"Error while testing whether keyslot exists on {device}: {stderr}"
|
||||
)
|
||||
|
||||
|
||||
class ConditionsHandler(Handler):
|
||||
|
||||
def __init__(self, module, crypthandler):
|
||||
def __init__(self, module: AnsibleModule, crypthandler: CryptHandler) -> None:
|
||||
super(ConditionsHandler, self).__init__(module)
|
||||
self._crypthandler = crypthandler
|
||||
self.device = self.get_device_name()
|
||||
|
||||
def get_device_name(self):
|
||||
device = self._module.params.get("device")
|
||||
label = self._module.params.get("label")
|
||||
uuid = self._module.params.get("uuid")
|
||||
name = self._module.params.get("name")
|
||||
def get_device_name(self) -> str | None:
|
||||
device: str | None = self._module.params.get("device")
|
||||
label: str | None = self._module.params.get("label")
|
||||
uuid: str | None = self._module.params.get("uuid")
|
||||
name: str | None = self._module.params.get("name")
|
||||
|
||||
if device is None and label is not None:
|
||||
device = self.get_device_by_label(label)
|
||||
@@ -876,7 +888,7 @@ class ConditionsHandler(Handler):
|
||||
|
||||
return device
|
||||
|
||||
def luks_create(self):
|
||||
def luks_create(self) -> bool:
|
||||
return (
|
||||
self.device is not None
|
||||
and (
|
||||
@@ -887,7 +899,7 @@ class ConditionsHandler(Handler):
|
||||
and not self._crypthandler.is_luks(self.device)
|
||||
)
|
||||
|
||||
def opened_luks_name(self):
|
||||
def opened_luks_name(self, device: str) -> str | None:
|
||||
"""If luks is already opened, return its name.
|
||||
If 'name' parameter is specified and differs
|
||||
from obtained value, fail.
|
||||
@@ -897,7 +909,7 @@ class ConditionsHandler(Handler):
|
||||
return None
|
||||
|
||||
# try to obtain luks name - it may be already opened
|
||||
name = self._crypthandler.get_container_name_by_device(self.device)
|
||||
name = self._crypthandler.get_container_name_by_device(device)
|
||||
|
||||
if name is None:
|
||||
# container is not open
|
||||
@@ -917,7 +929,7 @@ class ConditionsHandler(Handler):
|
||||
# container is opened and the names match
|
||||
return name
|
||||
|
||||
def luks_open(self):
|
||||
def luks_open(self) -> bool:
|
||||
if (
|
||||
(
|
||||
self._module.params["keyfile"] is None
|
||||
@@ -929,13 +941,13 @@ class ConditionsHandler(Handler):
|
||||
# conditions for open not fulfilled
|
||||
return False
|
||||
|
||||
name = self.opened_luks_name()
|
||||
name = self.opened_luks_name(self.device)
|
||||
|
||||
if name is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def luks_close(self):
|
||||
def luks_close(self) -> bool:
|
||||
if (
|
||||
self._module.params["name"] is None and self.device is None
|
||||
) or self._module.params["state"] != "closed":
|
||||
@@ -948,15 +960,17 @@ class ConditionsHandler(Handler):
|
||||
luks_is_open = name is not None
|
||||
|
||||
if self._module.params["name"] is not None:
|
||||
self.device = self._crypthandler.get_container_device_by_name(
|
||||
device = self._crypthandler.get_container_device_by_name(
|
||||
self._module.params["name"]
|
||||
)
|
||||
# successfully getting device based on name means that luks is open
|
||||
luks_is_open = self.device is not None
|
||||
luks_is_open = device is not None
|
||||
if device is not None:
|
||||
self.device = device
|
||||
|
||||
return luks_is_open
|
||||
|
||||
def luks_add_key(self):
|
||||
def luks_add_key(self) -> bool:
|
||||
if (
|
||||
self.device is None
|
||||
or (
|
||||
@@ -995,7 +1009,7 @@ class ConditionsHandler(Handler):
|
||||
|
||||
return not key_present
|
||||
|
||||
def luks_remove_key(self):
|
||||
def luks_remove_key(self) -> bool:
|
||||
if self.device is None or (
|
||||
self._module.params["remove_keyfile"] is None
|
||||
and self._module.params["remove_passphrase"] is None
|
||||
@@ -1037,14 +1051,16 @@ class ConditionsHandler(Handler):
|
||||
self.get_passphrase_from_module_params("remove_passphrase"),
|
||||
)
|
||||
|
||||
def luks_remove(self):
|
||||
def luks_remove(self) -> bool:
|
||||
return (
|
||||
self.device is not None
|
||||
and self._module.params["state"] == "absent"
|
||||
and self._crypthandler.is_luks(self.device)
|
||||
)
|
||||
|
||||
def validate_keyslot(self, param, luks_type):
|
||||
def validate_keyslot(
|
||||
self, param: str, luks_type: t.Literal["luks1", "luks2"] | None
|
||||
) -> None:
|
||||
if self._module.params[param] is not None:
|
||||
if luks_type is None and param == "keyslot":
|
||||
if 8 <= self._module.params[param] <= 31:
|
||||
@@ -1066,7 +1082,7 @@ class ConditionsHandler(Handler):
|
||||
)
|
||||
|
||||
|
||||
def run_module():
|
||||
def run_module() -> t.NoReturn:
|
||||
# available arguments/parameters that a user can pass
|
||||
module_args = dict(
|
||||
state=dict(
|
||||
@@ -1122,7 +1138,7 @@ def run_module():
|
||||
]
|
||||
|
||||
# seed the result dict in the object
|
||||
result = dict(changed=False, name=None)
|
||||
result: dict[str, t.Any] = {"changed": False, "name": None}
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=module_args,
|
||||
@@ -1142,19 +1158,26 @@ def run_module():
|
||||
except Exception as e:
|
||||
module.fail_json(msg=str(e))
|
||||
|
||||
crypt = CryptHandler(module)
|
||||
conditions = ConditionsHandler(module, crypt)
|
||||
|
||||
# conditions not allowed to run
|
||||
if module.params["label"] is not None and module.params["type"] == "luks1":
|
||||
module.fail_json(msg="You cannot combine type luks1 with the label option.")
|
||||
|
||||
crypt = CryptHandler(module)
|
||||
try:
|
||||
conditions = ConditionsHandler(module, crypt)
|
||||
except ValueError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
if (
|
||||
module.params["keyslot"] is not None
|
||||
or module.params["new_keyslot"] is not None
|
||||
or module.params["remove_keyslot"] is not None
|
||||
):
|
||||
luks_type = crypt.get_luks_type(conditions.get_device_name())
|
||||
luks_type = (
|
||||
crypt.get_luks_type(conditions.device)
|
||||
if conditions.device is not None
|
||||
else None
|
||||
)
|
||||
if luks_type is None and module.params["type"] is not None:
|
||||
luks_type = module.params["type"]
|
||||
for param in ["keyslot", "new_keyslot", "remove_keyslot"]:
|
||||
@@ -1175,6 +1198,7 @@ def run_module():
|
||||
|
||||
# luks create
|
||||
if conditions.luks_create():
|
||||
assert conditions.device # ensured in conditions.luks_create()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_create(
|
||||
@@ -1196,11 +1220,13 @@ def run_module():
|
||||
|
||||
# luks open
|
||||
|
||||
name = conditions.opened_luks_name()
|
||||
if name is not None:
|
||||
result["name"] = name
|
||||
if conditions.device is not None:
|
||||
name = conditions.opened_luks_name(conditions.device)
|
||||
if name is not None:
|
||||
result["name"] = name
|
||||
|
||||
if conditions.luks_open():
|
||||
assert conditions.device # ensured in conditions.luks_open()
|
||||
name = module.params["name"]
|
||||
if name is None:
|
||||
try:
|
||||
@@ -1237,6 +1263,8 @@ def run_module():
|
||||
module.fail_json(msg=f"luks_device error: {e}")
|
||||
else:
|
||||
name = module.params["name"]
|
||||
if name is None:
|
||||
module.fail_json(msg="Cannot determine name to close device")
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_close(name)
|
||||
@@ -1249,6 +1277,7 @@ def run_module():
|
||||
|
||||
# luks add key
|
||||
if conditions.luks_add_key():
|
||||
assert conditions.device # ensured in conditions.luks_add_key()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_add_key(
|
||||
@@ -1268,6 +1297,7 @@ def run_module():
|
||||
|
||||
# luks remove key
|
||||
if conditions.luks_remove_key():
|
||||
assert conditions.device # ensured in conditions.luks_remove_key()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
last_key = module.params["force_remove_last_key"]
|
||||
@@ -1286,6 +1316,7 @@ def run_module():
|
||||
|
||||
# luks remove
|
||||
if conditions.luks_remove():
|
||||
assert conditions.device # ensured in conditions.luks_remove()
|
||||
if not module.check_mode:
|
||||
try:
|
||||
crypt.run_luks_remove(conditions.device)
|
||||
@@ -1299,7 +1330,7 @@ def run_module():
|
||||
module.exit_json(**result)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
run_module()
|
||||
|
||||
|
||||
|
||||
@@ -284,6 +284,7 @@ info:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.common import (
|
||||
@@ -302,47 +303,58 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
|
||||
|
||||
|
||||
class Certificate(OpensshModule):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(Certificate, self).__init__(module)
|
||||
self.ssh_keygen = KeygenCommand(self.module)
|
||||
|
||||
self.identifier = self.module.params["identifier"] or ""
|
||||
self.options = self.module.params["options"] or []
|
||||
self.path = self.module.params["path"]
|
||||
self.pkcs11_provider = self.module.params["pkcs11_provider"]
|
||||
self.principals = self.module.params["principals"] or []
|
||||
self.public_key = self.module.params["public_key"]
|
||||
self.regenerate = (
|
||||
self.identifier: str = self.module.params["identifier"] or ""
|
||||
self.options: list[str] = self.module.params["options"] or []
|
||||
self.path: str = self.module.params["path"]
|
||||
self.pkcs11_provider: str | None = self.module.params["pkcs11_provider"]
|
||||
self.principals: list[str] = self.module.params["principals"] or []
|
||||
self.public_key: str | None = self.module.params["public_key"]
|
||||
self.regenerate: t.Literal[
|
||||
"never",
|
||||
"fail",
|
||||
"partial_idempotence",
|
||||
"full_idempotence",
|
||||
"always",
|
||||
] = (
|
||||
self.module.params["regenerate"]
|
||||
if not self.module.params["force"]
|
||||
else "always"
|
||||
)
|
||||
self.serial_number = self.module.params["serial_number"]
|
||||
self.signature_algorithm = self.module.params["signature_algorithm"]
|
||||
self.signing_key = self.module.params["signing_key"]
|
||||
self.state = self.module.params["state"]
|
||||
self.type = self.module.params["type"]
|
||||
self.use_agent = self.module.params["use_agent"]
|
||||
self.valid_at = self.module.params["valid_at"]
|
||||
self.ignore_timestamps = self.module.params["ignore_timestamps"]
|
||||
self.serial_number: int | None = self.module.params["serial_number"]
|
||||
self.signature_algorithm: (
|
||||
t.Literal["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] | None
|
||||
) = self.module.params["signature_algorithm"]
|
||||
self.signing_key: str | None = self.module.params["signing_key"]
|
||||
self.state: t.Literal["absent", "present"] = self.module.params["state"]
|
||||
self.type: t.Literal["host", "user"] | None = self.module.params["type"]
|
||||
self.use_agent: bool = self.module.params["use_agent"]
|
||||
self.valid_at: str | None = self.module.params["valid_at"]
|
||||
self.ignore_timestamps: bool = self.module.params["ignore_timestamps"]
|
||||
|
||||
self._check_if_base_dir(self.path)
|
||||
|
||||
if self.state == "present":
|
||||
self._validate_parameters()
|
||||
|
||||
self.data = None
|
||||
self.original_data = None
|
||||
self.data: OpensshCertificate | None = None
|
||||
self.original_data: OpensshCertificate | None = None
|
||||
if self._exists():
|
||||
self._load_certificate()
|
||||
|
||||
self.time_parameters = None
|
||||
self.time_parameters: OpensshCertificateTimeParameters | None = None
|
||||
if self.state == "present":
|
||||
self._set_time_parameters()
|
||||
|
||||
def _validate_parameters(self):
|
||||
def _validate_parameters(self) -> None:
|
||||
for path in (self.public_key, self.signing_key):
|
||||
self._check_if_base_dir(path)
|
||||
if (
|
||||
path is not None
|
||||
): # should never be None, but the type checker doesn't know
|
||||
self._check_if_base_dir(path)
|
||||
|
||||
if self.options and self.type == "host":
|
||||
self.module.fail_json(
|
||||
@@ -352,7 +364,7 @@ class Certificate(OpensshModule):
|
||||
if self.use_agent:
|
||||
self._use_agent_available()
|
||||
|
||||
def _use_agent_available(self):
|
||||
def _use_agent_available(self) -> None:
|
||||
ssh_version = self._get_ssh_version()
|
||||
if not ssh_version:
|
||||
self.module.fail_json(msg="Failed to determine ssh version")
|
||||
@@ -362,10 +374,10 @@ class Certificate(OpensshModule):
|
||||
+ f" Your version is: {ssh_version}"
|
||||
)
|
||||
|
||||
def _exists(self):
|
||||
def _exists(self) -> bool:
|
||||
return os.path.exists(self.path)
|
||||
|
||||
def _load_certificate(self):
|
||||
def _load_certificate(self) -> None:
|
||||
try:
|
||||
self.original_data = OpensshCertificate.load(self.path)
|
||||
except (TypeError, ValueError) as e:
|
||||
@@ -373,7 +385,7 @@ class Certificate(OpensshModule):
|
||||
self.module.fail_json(msg=f"Unable to read existing certificate: {e}")
|
||||
self.module.warn(f"Unable to read existing certificate: {e}")
|
||||
|
||||
def _set_time_parameters(self):
|
||||
def _set_time_parameters(self) -> None:
|
||||
try:
|
||||
self.time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.module.params["valid_from"],
|
||||
@@ -382,7 +394,7 @@ class Certificate(OpensshModule):
|
||||
except ValueError as e:
|
||||
self.module.fail_json(msg=str(e))
|
||||
|
||||
def _execute(self):
|
||||
def _execute(self) -> None:
|
||||
if self.state == "present":
|
||||
if self._should_generate():
|
||||
self._generate()
|
||||
@@ -391,7 +403,7 @@ class Certificate(OpensshModule):
|
||||
if self._exists():
|
||||
self._remove()
|
||||
|
||||
def _should_generate(self):
|
||||
def _should_generate(self) -> bool:
|
||||
if self.regenerate == "never":
|
||||
return self.original_data is None
|
||||
elif self.regenerate == "fail":
|
||||
@@ -408,7 +420,13 @@ class Certificate(OpensshModule):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _is_fully_valid(self):
|
||||
def _is_fully_valid(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
if self.public_key is None:
|
||||
raise AssertionError("Contract violation public_key not provided")
|
||||
if self.signing_key is None:
|
||||
raise AssertionError("Contract violation signing_key not provided")
|
||||
return self._is_partially_valid() and all(
|
||||
[
|
||||
self._compare_options() if self.original_data.type == "user" else True,
|
||||
@@ -420,7 +438,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _is_partially_valid(self):
|
||||
def _is_partially_valid(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
return all(
|
||||
[
|
||||
set(self.original_data.principals) == set(self.principals),
|
||||
@@ -439,7 +459,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _compare_time_parameters(self):
|
||||
def _compare_time_parameters(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
try:
|
||||
original_time_parameters = OpensshCertificateTimeParameters(
|
||||
valid_from=self.original_data.valid_after,
|
||||
@@ -458,7 +480,9 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _compare_options(self):
|
||||
def _compare_options(self) -> bool:
|
||||
if self.original_data is None:
|
||||
raise AssertionError("Contract violation original_data not provided")
|
||||
try:
|
||||
critical_options, extensions = parse_option_list(self.options)
|
||||
except ValueError as e:
|
||||
@@ -471,13 +495,13 @@ class Certificate(OpensshModule):
|
||||
]
|
||||
)
|
||||
|
||||
def _get_key_fingerprint(self, path):
|
||||
def _get_key_fingerprint(self, path: str) -> str:
|
||||
private_key_content = self.ssh_keygen.get_private_key(path, check_rc=True)[1]
|
||||
return PrivateKey.from_string(private_key_content).fingerprint
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _generate(self):
|
||||
def _generate(self) -> None:
|
||||
try:
|
||||
temp_certificate = self._generate_temp_certificate()
|
||||
self._safe_secure_move([(temp_certificate, self.path)])
|
||||
@@ -491,7 +515,14 @@ class Certificate(OpensshModule):
|
||||
except (TypeError, ValueError) as e:
|
||||
self.module.fail_json(msg=f"Unable to read new certificate: {e}")
|
||||
|
||||
def _generate_temp_certificate(self):
|
||||
def _generate_temp_certificate(self) -> str:
|
||||
if self.public_key is None:
|
||||
raise AssertionError("Contract violation public_key not provided")
|
||||
if self.signing_key is None:
|
||||
raise AssertionError("Contract violation signing_key not provided")
|
||||
if self.time_parameters is None:
|
||||
raise AssertionError("Contract violation time_parameters not provided")
|
||||
|
||||
key_copy = os.path.join(self.module.tmpdir, os.path.basename(self.public_key))
|
||||
|
||||
try:
|
||||
@@ -523,14 +554,14 @@ class Certificate(OpensshModule):
|
||||
|
||||
@OpensshModule.trigger_change
|
||||
@OpensshModule.skip_if_check_mode
|
||||
def _remove(self):
|
||||
def _remove(self) -> None:
|
||||
try:
|
||||
os.remove(self.path)
|
||||
except OSError as e:
|
||||
self.module.fail_json(msg=f"Unable to remove existing certificate: {e}")
|
||||
|
||||
@property
|
||||
def _result(self):
|
||||
def _result(self) -> dict[str, t.Any]:
|
||||
if self.state != "present":
|
||||
return {}
|
||||
|
||||
@@ -546,14 +577,14 @@ class Certificate(OpensshModule):
|
||||
}
|
||||
|
||||
@property
|
||||
def diff(self):
|
||||
def diff(self) -> dict[str, t.Any]:
|
||||
return {
|
||||
"before": get_cert_dict(self.original_data),
|
||||
"after": get_cert_dict(self.data),
|
||||
}
|
||||
|
||||
|
||||
def format_cert_info(cert_info):
|
||||
def format_cert_info(cert_info: str) -> list[str]:
|
||||
result = []
|
||||
string = ""
|
||||
|
||||
@@ -579,7 +610,7 @@ def format_cert_info(cert_info):
|
||||
return result
|
||||
|
||||
|
||||
def get_cert_dict(data):
|
||||
def get_cert_dict(data: OpensshCertificate | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return {}
|
||||
|
||||
@@ -590,7 +621,7 @@ def get_cert_dict(data):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
force=dict(type="bool", default=False),
|
||||
|
||||
@@ -198,13 +198,15 @@ comment:
|
||||
sample: test@comment
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.openssh.backends.keypair_backend import (
|
||||
select_backend,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
|
||||
@@ -239,6 +239,7 @@ csr:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -256,9 +257,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
|
||||
CertificateSigningRequestBackend,
|
||||
)
|
||||
|
||||
|
||||
class CertificateSigningRequestModule(OpenSSLObject):
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
|
||||
) -> None:
|
||||
super(CertificateSigningRequestModule, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -269,11 +279,11 @@ class CertificateSigningRequestModule(OpenSSLObject):
|
||||
self.return_content = module.params["return_content"]
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup_file: str | None = None
|
||||
|
||||
self.module_backend.set_existing(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the certificate signing request."""
|
||||
if self.force or self.module_backend.needs_regeneration():
|
||||
if not self.check_mode:
|
||||
@@ -292,13 +302,13 @@ class CertificateSigningRequestModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
self.module_backend.set_existing(None)
|
||||
if self.backup and not self.check_mode:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(CertificateSigningRequestModule, self).remove(module)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = self.module_backend.dump(include_csr=self.return_content)
|
||||
result.update(
|
||||
@@ -312,7 +322,7 @@ class CertificateSigningRequestModule(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_csr_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -308,6 +308,7 @@ authority_cert_serial_number:
|
||||
sample: 12345
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -318,7 +319,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -335,11 +336,15 @@ def main():
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
if module.params["content"] is not None:
|
||||
data = module.params["content"].encode("utf-8")
|
||||
content: str | None = module.params["content"]
|
||||
path: str | None = module.params["path"]
|
||||
if content is not None:
|
||||
data = content.encode("utf-8")
|
||||
else:
|
||||
if path is None:
|
||||
module.fail_json(msg="One of content and path must be provided")
|
||||
try:
|
||||
with open(module.params["path"], "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(msg=f"Error while reading CSR file from disk: {e}")
|
||||
|
||||
@@ -127,6 +127,8 @@ csr:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -136,8 +138,17 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr import (
|
||||
CertificateSigningRequestBackend,
|
||||
)
|
||||
|
||||
|
||||
class CertificateSigningRequestModule:
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: CertificateSigningRequestBackend
|
||||
) -> None:
|
||||
self.check_mode = module.check_mode
|
||||
self.module = module
|
||||
self.module_backend = module_backend
|
||||
@@ -145,13 +156,13 @@ class CertificateSigningRequestModule:
|
||||
if module.params["content"] is not None:
|
||||
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the certificate signing request."""
|
||||
if self.module_backend.needs_regeneration():
|
||||
self.module_backend.generate_csr()
|
||||
self.changed = True
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = self.module_backend.dump(include_csr=True)
|
||||
result.update(
|
||||
@@ -162,7 +173,7 @@ class CertificateSigningRequestModule:
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_csr_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
dict(
|
||||
|
||||
@@ -132,6 +132,7 @@ import abc
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
@@ -173,23 +174,23 @@ class DHParameterError(Exception):
|
||||
|
||||
class DHParameterBase:
|
||||
|
||||
def __init__(self, module):
|
||||
self.state = module.params["state"]
|
||||
self.path = module.params["path"]
|
||||
self.size = module.params["size"]
|
||||
self.force = module.params["force"]
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
self.state: t.Literal["absent", "present"] = module.params["state"]
|
||||
self.path: str = module.params["path"]
|
||||
self.size: int = module.params["size"]
|
||||
self.force: bool = module.params["force"]
|
||||
self.changed = False
|
||||
self.return_content = module.params["return_content"]
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
pass
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate DH params."""
|
||||
changed = False
|
||||
|
||||
@@ -206,7 +207,7 @@ class DHParameterBase:
|
||||
|
||||
self.changed = changed
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
try:
|
||||
@@ -215,28 +216,27 @@ class DHParameterBase:
|
||||
except OSError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
def check(self, module):
|
||||
def check(self, module: AnsibleModule) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
if self.force:
|
||||
return False
|
||||
return self._check_params_valid(module) and self._check_fs_attributes(module)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
pass
|
||||
|
||||
def _check_fs_attributes(self, module):
|
||||
def _check_fs_attributes(self, module: AnsibleModule) -> bool:
|
||||
"""Checks (and changes if not in check mode!) fs attributes"""
|
||||
file_args = module.load_file_common_arguments(module.params)
|
||||
if module.check_file_absent_if_check_mode(file_args["path"]):
|
||||
return False
|
||||
return not module.set_fs_attributes_if_different(file_args, False)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"size": self.size,
|
||||
"filename": self.path,
|
||||
"changed": self.changed,
|
||||
@@ -252,25 +252,24 @@ class DHParameterBase:
|
||||
|
||||
class DHParameterAbsent(DHParameterBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(DHParameterAbsent, self).__init__(module)
|
||||
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
pass
|
||||
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
class DHParameterOpenSSL(DHParameterBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(DHParameterOpenSSL, self).__init__(module)
|
||||
self.openssl_bin = module.get_bin_path("openssl", True)
|
||||
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
# create a tempfile
|
||||
fd, tmpsrc = tempfile.mkstemp()
|
||||
@@ -288,7 +287,7 @@ class DHParameterOpenSSL(DHParameterBase):
|
||||
except Exception as e:
|
||||
module.fail_json(msg=f"Failed to write to file {self.path}: {str(e)}")
|
||||
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
command = [
|
||||
self.openssl_bin,
|
||||
@@ -321,10 +320,10 @@ class DHParameterOpenSSL(DHParameterBase):
|
||||
|
||||
class DHParameterCryptography(DHParameterBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(DHParameterCryptography, self).__init__(module)
|
||||
|
||||
def _do_generate(self, module):
|
||||
def _do_generate(self, module: AnsibleModule) -> None:
|
||||
"""Actually generate the DH params."""
|
||||
# Generate parameters
|
||||
params = cryptography.hazmat.primitives.asymmetric.dh.generate_parameters(
|
||||
@@ -341,7 +340,7 @@ class DHParameterCryptography(DHParameterBase):
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
write_file(module, result)
|
||||
|
||||
def _check_params_valid(self, module):
|
||||
def _check_params_valid(self, module: AnsibleModule) -> bool:
|
||||
"""Check if the params are in the correct state"""
|
||||
# Load parameters
|
||||
try:
|
||||
@@ -357,7 +356,7 @@ class DHParameterCryptography(DHParameterBase):
|
||||
return bits == self.size
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
"""Main function"""
|
||||
|
||||
module = AnsibleModule(
|
||||
@@ -383,6 +382,7 @@ def main():
|
||||
msg=f"The directory '{base_dir}' does not exist or the file is not a directory",
|
||||
)
|
||||
|
||||
dhparam: DHParameterOpenSSL | DHParameterCryptography | DHParameterAbsent
|
||||
if module.params["state"] == "present":
|
||||
backend = module.params["select_crypto_backend"]
|
||||
if backend == "auto":
|
||||
|
||||
@@ -280,6 +280,7 @@ import itertools
|
||||
import os
|
||||
import stat
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
@@ -296,7 +297,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
OpenSSLObject,
|
||||
load_certificate,
|
||||
load_privatekey,
|
||||
load_certificate_issuer_privatekey,
|
||||
)
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -320,6 +321,7 @@ except ImportError:
|
||||
|
||||
CRYPTOGRAPHY_COMPATIBILITY2022_ERR = None
|
||||
try:
|
||||
import cryptography.x509
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.serialization.pkcs12 import PBES
|
||||
|
||||
@@ -333,8 +335,22 @@ except Exception:
|
||||
else:
|
||||
CRYPTOGRAPHY_HAS_COMPATIBILITY2022 = True
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..module_utils.crypto.cryptography_support import (
|
||||
CertificateIssuerPrivateKeyTypes,
|
||||
)
|
||||
|
||||
def load_certificate_set(filename):
|
||||
PKCS12 = tuple[
|
||||
t.Union[CertificateIssuerPrivateKeyTypes, None],
|
||||
t.Union[cryptography.x509.Certificate, None],
|
||||
list[cryptography.x509.Certificate],
|
||||
t.Union[bytes, None],
|
||||
]
|
||||
|
||||
|
||||
def load_certificate_set(
|
||||
filename: str | os.PathLike,
|
||||
) -> list[cryptography.x509.Certificate]:
|
||||
"""
|
||||
Load list of concatenated PEM files, and return a list of parsed certificates.
|
||||
"""
|
||||
@@ -351,70 +367,80 @@ class PkcsError(OpenSSLObjectError):
|
||||
|
||||
|
||||
class Pkcs(OpenSSLObject):
|
||||
def __init__(self, module, iter_size_default=2048):
|
||||
path: str
|
||||
|
||||
def __init__(self, module: AnsibleModule, iter_size_default: int = 2048) -> None:
|
||||
super(Pkcs, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
module.params["force"],
|
||||
module.check_mode,
|
||||
)
|
||||
self.action = module.params["action"]
|
||||
self.other_certificates = module.params["other_certificates"]
|
||||
self.other_certificates_parse_all = module.params[
|
||||
self.action: t.Literal["export", "parse"] = module.params["action"]
|
||||
self.other_certificates: list[cryptography.x509.Certificate] = []
|
||||
self.other_certificates_str: list[str] | None = module.params[
|
||||
"other_certificates"
|
||||
]
|
||||
self.other_certificates_parse_all: bool = module.params[
|
||||
"other_certificates_parse_all"
|
||||
]
|
||||
self.other_certificates_content = module.params["other_certificates_content"]
|
||||
self.certificate_path = module.params["certificate_path"]
|
||||
self.certificate_content = module.params["certificate_content"]
|
||||
self.friendly_name = module.params["friendly_name"]
|
||||
self.iter_size = module.params["iter_size"] or iter_size_default
|
||||
self.maciter_size = module.params["maciter_size"] or 1
|
||||
self.encryption_level = module.params["encryption_level"]
|
||||
self.other_certificates_content: list[str] | None = module.params[
|
||||
"other_certificates_content"
|
||||
]
|
||||
self.certificate_path: str | None = module.params["certificate_path"]
|
||||
certificate_content: str | None = module.params["certificate_content"]
|
||||
self.friendly_name: str | None = module.params["friendly_name"]
|
||||
self.iter_size: int = module.params["iter_size"] or iter_size_default
|
||||
self.maciter_size: int = module.params["maciter_size"] or 1
|
||||
self.encryption_level: t.Literal["auto", "compatibility2022"] = module.params[
|
||||
"encryption_level"
|
||||
]
|
||||
self.passphrase = module.params["passphrase"]
|
||||
self.pkcs12 = None
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
self.pkcs12_bytes = None
|
||||
self.return_content = module.params["return_content"]
|
||||
self.src = module.params["src"]
|
||||
self.pkcs12: PKCS12 | None = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
self.pkcs12_bytes: bytes | None = None
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.src: str | None = module.params["src"]
|
||||
|
||||
if module.params["mode"] is None:
|
||||
module.params["mode"] = "0400"
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
self.certificate_content: bytes | None = None
|
||||
if self.certificate_path is not None:
|
||||
try:
|
||||
with open(self.certificate_path, "rb") as fh:
|
||||
self.certificate_content = fh.read()
|
||||
except (IOError, OSError) as exc:
|
||||
raise PkcsError(exc)
|
||||
elif self.certificate_content is not None:
|
||||
self.certificate_content = to_bytes(self.certificate_content)
|
||||
elif certificate_content is not None:
|
||||
self.certificate_content = to_bytes(certificate_content)
|
||||
|
||||
self.privatekey_content: bytes | None = None
|
||||
if self.privatekey_path is not None:
|
||||
try:
|
||||
with open(self.privatekey_path, "rb") as fh:
|
||||
self.privatekey_content = fh.read()
|
||||
except (IOError, OSError) as exc:
|
||||
raise PkcsError(exc)
|
||||
elif self.privatekey_content is not None:
|
||||
self.privatekey_content = to_bytes(self.privatekey_content)
|
||||
elif privatekey_content is not None:
|
||||
self.privatekey_content = to_bytes(privatekey_content)
|
||||
|
||||
if self.other_certificates:
|
||||
if self.other_certificates_str:
|
||||
if self.other_certificates_parse_all:
|
||||
filenames = list(self.other_certificates)
|
||||
self.other_certificates = []
|
||||
for other_cert_bundle in filenames:
|
||||
for other_cert_bundle in self.other_certificates_str:
|
||||
self.other_certificates.extend(
|
||||
load_certificate_set(other_cert_bundle)
|
||||
)
|
||||
else:
|
||||
self.other_certificates = [
|
||||
load_certificate(other_cert)
|
||||
for other_cert in self.other_certificates
|
||||
for other_cert in self.other_certificates_str
|
||||
]
|
||||
elif self.other_certificates_content:
|
||||
certs = self.other_certificates_content
|
||||
@@ -430,40 +456,42 @@ class Pkcs(OpenSSLObject):
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_bytes(self, module):
|
||||
def generate_bytes(self, module: AnsibleModule) -> bytes:
|
||||
"""Generate PKCS#12 file archive."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
|
||||
bytes | None,
|
||||
bytes | None,
|
||||
list[bytes],
|
||||
bytes | None,
|
||||
]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_bytes(self, pkcs12_content):
|
||||
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dump_privatekey(self, pkcs12):
|
||||
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dump_certificate(self, pkcs12):
|
||||
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dump_other_certificates(self, pkcs12):
|
||||
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_friendly_name(self, pkcs12):
|
||||
pass
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
state_and_perms = super(Pkcs, self).check(module, perms_required)
|
||||
|
||||
def _check_pkey_passphrase():
|
||||
def _check_pkey_passphrase() -> bool:
|
||||
if self.privatekey_passphrase:
|
||||
try:
|
||||
load_privatekey(
|
||||
None,
|
||||
load_certificate_issuer_privatekey(
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
)
|
||||
@@ -476,6 +504,7 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
if os.path.exists(self.path) and module.params["action"] == "export":
|
||||
self.generate_bytes(module) # ignore result
|
||||
assert self.pkcs12 is not None
|
||||
self.src = self.path
|
||||
try:
|
||||
(
|
||||
@@ -524,7 +553,7 @@ class Pkcs(OpenSSLObject):
|
||||
return False
|
||||
elif (
|
||||
module.params["action"] == "parse"
|
||||
and os.path.exists(self.src)
|
||||
and os.path.exists(self.src or "")
|
||||
and os.path.exists(self.path)
|
||||
):
|
||||
try:
|
||||
@@ -548,10 +577,10 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
return _check_pkey_passphrase()
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"filename": self.path,
|
||||
}
|
||||
if self.privatekey_path:
|
||||
@@ -567,13 +596,20 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
return result
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(Pkcs, self).remove(module)
|
||||
|
||||
def parse(self):
|
||||
def parse(self) -> tuple[
|
||||
bytes | None,
|
||||
bytes | None,
|
||||
list[bytes],
|
||||
bytes | None,
|
||||
]:
|
||||
"""Read PKCS#12 file."""
|
||||
if self.src is None:
|
||||
raise AssertionError("Contract violation: src is None")
|
||||
|
||||
try:
|
||||
with open(self.src, "rb") as pkcs12_fh:
|
||||
@@ -582,10 +618,13 @@ class Pkcs(OpenSSLObject):
|
||||
except IOError as exc:
|
||||
raise PkcsError(exc)
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
|
||||
def write(self, module, content, mode=None):
|
||||
def write(
|
||||
self, module: AnsibleModule, content: bytes, mode: int | str | None = None
|
||||
) -> None:
|
||||
"""Write the PKCS#12 file."""
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
@@ -595,7 +634,7 @@ class Pkcs(OpenSSLObject):
|
||||
|
||||
|
||||
class PkcsCryptography(Pkcs):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(PkcsCryptography, self).__init__(module, iter_size_default=50000)
|
||||
if (
|
||||
self.encryption_level == "compatibility2022"
|
||||
@@ -607,13 +646,12 @@ class PkcsCryptography(Pkcs):
|
||||
exception=CRYPTOGRAPHY_COMPATIBILITY2022_ERR,
|
||||
)
|
||||
|
||||
def generate_bytes(self, module):
|
||||
def generate_bytes(self, module: AnsibleModule) -> bytes:
|
||||
"""Generate PKCS#12 file archive."""
|
||||
pkey = None
|
||||
if self.privatekey_content:
|
||||
try:
|
||||
pkey = load_privatekey(
|
||||
None,
|
||||
pkey = load_certificate_issuer_privatekey(
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
)
|
||||
@@ -631,6 +669,7 @@ class PkcsCryptography(Pkcs):
|
||||
# Store fake object which can be used to retrieve the components back
|
||||
self.pkcs12 = (pkey, cert, self.other_certificates, friendly_name)
|
||||
|
||||
encryption: serialization.KeySerializationEncryption
|
||||
if not self.passphrase:
|
||||
encryption = serialization.NoEncryption()
|
||||
elif self.encryption_level == "compatibility2022":
|
||||
@@ -654,7 +693,12 @@ class PkcsCryptography(Pkcs):
|
||||
encryption,
|
||||
)
|
||||
|
||||
def parse_bytes(self, pkcs12_content):
|
||||
def parse_bytes(self, pkcs12_content: bytes) -> tuple[
|
||||
bytes | None,
|
||||
bytes | None,
|
||||
list[bytes],
|
||||
bytes | None,
|
||||
]:
|
||||
try:
|
||||
private_key, certificate, additional_certificates, friendly_name = (
|
||||
parse_pkcs12(pkcs12_content, self.passphrase)
|
||||
@@ -683,11 +727,7 @@ class PkcsCryptography(Pkcs):
|
||||
except ValueError as exc:
|
||||
raise PkcsError(exc)
|
||||
|
||||
# The following methods will get self.pkcs12 passed, which is computed as:
|
||||
#
|
||||
# self.pkcs12 = (pkey, cert, self.other_certificates, self.friendly_name)
|
||||
|
||||
def _dump_privatekey(self, pkcs12):
|
||||
def _dump_privatekey(self, pkcs12: PKCS12) -> bytes | None:
|
||||
return (
|
||||
pkcs12[0].private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
@@ -698,27 +738,27 @@ class PkcsCryptography(Pkcs):
|
||||
else None
|
||||
)
|
||||
|
||||
def _dump_certificate(self, pkcs12):
|
||||
def _dump_certificate(self, pkcs12: PKCS12) -> bytes | None:
|
||||
return pkcs12[1].public_bytes(serialization.Encoding.PEM) if pkcs12[1] else None
|
||||
|
||||
def _dump_other_certificates(self, pkcs12):
|
||||
def _dump_other_certificates(self, pkcs12: PKCS12) -> list[bytes]:
|
||||
return [
|
||||
other_cert.public_bytes(serialization.Encoding.PEM)
|
||||
for other_cert in pkcs12[2]
|
||||
]
|
||||
|
||||
def _get_friendly_name(self, pkcs12):
|
||||
def _get_friendly_name(self, pkcs12: PKCS12) -> bytes | None:
|
||||
return pkcs12[3]
|
||||
|
||||
|
||||
def select_backend(module):
|
||||
def select_backend(module: AnsibleModule) -> Pkcs:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
return PkcsCryptography(module)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = dict(
|
||||
action=dict(type="str", default="export", choices=["export", "parse"]),
|
||||
other_certificates=dict(
|
||||
|
||||
@@ -155,6 +155,7 @@ privatekey:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -172,9 +173,18 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey import (
|
||||
PrivateKeyBackend,
|
||||
)
|
||||
|
||||
|
||||
class PrivateKeyModule(OpenSSLObject):
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: PrivateKeyBackend
|
||||
) -> None:
|
||||
super(PrivateKeyModule, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -182,19 +192,19 @@ class PrivateKeyModule(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
self.module_backend = module_backend
|
||||
self.return_content = module.params["return_content"]
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
if self.force:
|
||||
module_backend.regenerate = "always"
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: str | None = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
if module.params["mode"] is None:
|
||||
module.params["mode"] = "0600"
|
||||
|
||||
module_backend.set_existing(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate a keypair."""
|
||||
|
||||
if self.module_backend.needs_regeneration():
|
||||
@@ -228,13 +238,13 @@ class PrivateKeyModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
self.module_backend.set_existing(None)
|
||||
if self.backup and not self.check_mode:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(PrivateKeyModule, self).remove(module)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = self.module_backend.dump(include_key=self.return_content)
|
||||
@@ -246,7 +256,7 @@ class PrivateKeyModule(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
argument_spec = get_privatekey_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
|
||||
@@ -60,6 +60,7 @@ backup_file:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -77,8 +78,17 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.privatekey_convert import (
|
||||
PrivateKeyConvertBackend,
|
||||
)
|
||||
|
||||
|
||||
class PrivateKeyConvertModule(OpenSSLObject):
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(
|
||||
self, module: AnsibleModule, module_backend: PrivateKeyConvertBackend
|
||||
) -> None:
|
||||
super(PrivateKeyConvertModule, self).__init__(
|
||||
module.params["dest_path"],
|
||||
"present",
|
||||
@@ -87,8 +97,8 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
)
|
||||
self.module_backend = module_backend
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
module.params["path"] = module.params["dest_path"]
|
||||
if module.params["mode"] is None:
|
||||
@@ -96,12 +106,14 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
|
||||
module_backend.set_existing_destination(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Do conversion."""
|
||||
|
||||
if self.module_backend.needs_conversion():
|
||||
# Convert
|
||||
privatekey_data = self.module_backend.get_private_key_data()
|
||||
if privatekey_data is None:
|
||||
raise AssertionError("Contract violation: privatekey_data is None")
|
||||
if not self.check_mode:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
@@ -116,7 +128,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = self.module_backend.dump()
|
||||
@@ -127,7 +139,7 @@ class PrivateKeyConvertModule(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
argument_spec = get_privatekey_argument_spec()
|
||||
argument_spec.argument_spec.update(
|
||||
|
||||
@@ -200,6 +200,7 @@ private_data:
|
||||
type: dict
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -212,7 +213,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -243,7 +244,7 @@ def main():
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(
|
||||
msg=f"Error while reading private key file from disk: {e}", **result
|
||||
msg=f"Error while reading private key file from disk: {e}", **result # type: ignore
|
||||
)
|
||||
|
||||
result["can_load_key"] = True
|
||||
@@ -261,10 +262,10 @@ def main():
|
||||
module.exit_json(**result)
|
||||
except PrivateKeyParseError as exc:
|
||||
result.update(exc.result)
|
||||
module.fail_json(msg=exc.error_message, **result)
|
||||
module.fail_json(msg=exc.error_message, **result) # type: ignore
|
||||
except PrivateKeyConsistencyError as exc:
|
||||
result.update(exc.result)
|
||||
module.fail_json(msg=exc.error_message, **result)
|
||||
module.fail_json(msg=exc.error_message, **result) # type: ignore
|
||||
except OpenSSLObjectError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ publickey:
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -218,6 +219,12 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from cryptography.hazmat.primitives.asymmetric.types import (
|
||||
PrivateKeyTypes,
|
||||
PublicKeyTypes,
|
||||
)
|
||||
|
||||
|
||||
class PublicKeyError(OpenSSLObjectError):
|
||||
pass
|
||||
@@ -225,7 +232,7 @@ class PublicKeyError(OpenSSLObjectError):
|
||||
|
||||
class PublicKey(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(PublicKey, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -233,27 +240,29 @@ class PublicKey(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
self.module = module
|
||||
self.format = module.params["format"]
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.privatekey = None
|
||||
self.publickey_bytes = None
|
||||
self.return_content = module.params["return_content"]
|
||||
self.fingerprint = {}
|
||||
self.format: t.Literal["OpenSSH", "PEM"] = module.params["format"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
self.privatekey: PrivateKeyTypes | None = None
|
||||
self.publickey_bytes: bytes | None = None
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.fingerprint: dict[str, str] = {}
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
self.diff_before = self._get_info(None)
|
||||
self.diff_after = self._get_info(None)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
result = dict(can_parse_key=False)
|
||||
return {}
|
||||
result = {"can_parse_key": False}
|
||||
try:
|
||||
result.update(
|
||||
get_publickey_info(
|
||||
@@ -267,7 +276,7 @@ class PublicKey(OpenSSLObject):
|
||||
pass
|
||||
return result
|
||||
|
||||
def _create_publickey(self, module):
|
||||
def _create_publickey(self, module: AnsibleModule) -> bytes:
|
||||
self.privatekey = load_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
@@ -284,10 +293,12 @@ class PublicKey(OpenSSLObject):
|
||||
crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Generate the public key."""
|
||||
|
||||
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
|
||||
if self.privatekey_path is not None and not os.path.exists(
|
||||
self.privatekey_path
|
||||
):
|
||||
raise PublicKeyError(
|
||||
f"The private key {self.privatekey_path} does not exist"
|
||||
)
|
||||
@@ -320,17 +331,18 @@ class PublicKey(OpenSSLObject):
|
||||
elif module.set_fs_attributes_if_different(file_args, False):
|
||||
self.changed = True
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
state_and_perms = super(PublicKey, self).check(module, perms_required)
|
||||
|
||||
def _check_privatekey():
|
||||
if self.privatekey_content is None and not os.path.exists(
|
||||
def _check_privatekey() -> bool:
|
||||
if self.privatekey_path is not None and not os.path.exists(
|
||||
self.privatekey_path
|
||||
):
|
||||
return False
|
||||
|
||||
current_publickey: PublicKeyTypes
|
||||
try:
|
||||
with open(self.path, "rb") as public_key_fh:
|
||||
publickey_content = public_key_fh.read()
|
||||
@@ -369,15 +381,15 @@ class PublicKey(OpenSSLObject):
|
||||
|
||||
return _check_privatekey()
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(PublicKey, self).remove(module)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
|
||||
result = {
|
||||
result: dict[str, t.Any] = {
|
||||
"privatekey": self.privatekey_path,
|
||||
"filename": self.path,
|
||||
"format": self.format,
|
||||
@@ -403,7 +415,7 @@ class PublicKey(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
|
||||
@@ -152,6 +152,7 @@ public_data:
|
||||
returned: When RV(type=DSA) or RV(type=ECC)
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -163,7 +164,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -191,7 +192,7 @@ def main():
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(
|
||||
msg=f"Error while reading public key file from disk: {e}", **result
|
||||
msg=f"Error while reading public key file from disk: {e}", **result # type: ignore
|
||||
)
|
||||
|
||||
module_backend = select_backend(module, data)
|
||||
@@ -201,7 +202,7 @@ def main():
|
||||
module.exit_json(**result)
|
||||
except PublicKeyParseError as exc:
|
||||
result.update(exc.result)
|
||||
module.fail_json(msg=exc.error_message, **result)
|
||||
module.fail_json(msg=exc.error_message, **result) # type: ignore
|
||||
except OpenSSLObjectError as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ signature:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -132,7 +133,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
|
||||
|
||||
class SignatureBase(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureBase, self).__init__(
|
||||
path=module.params["path"],
|
||||
state="present",
|
||||
@@ -140,32 +141,35 @@ class SignatureBase(OpenSSLObject):
|
||||
check_mode=module.check_mode,
|
||||
)
|
||||
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.module = module
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class SignatureCryptography(SignatureBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureCryptography, self).__init__(module)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> dict[str, t.Any]:
|
||||
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
|
||||
_hash = cryptography.hazmat.primitives.hashes.SHA256()
|
||||
|
||||
result = dict()
|
||||
result: dict[str, t.Any] = {}
|
||||
|
||||
try:
|
||||
with open(self.path, "rb") as f:
|
||||
@@ -223,7 +227,7 @@ class SignatureCryptography(SignatureBase):
|
||||
raise OpenSSLObjectError(e)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
privatekey_path=dict(type="path"),
|
||||
|
||||
@@ -88,6 +88,7 @@ valid:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.cryptography_dep import (
|
||||
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
|
||||
@@ -121,7 +122,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.support im
|
||||
|
||||
class SignatureInfoBase(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureInfoBase, self).__init__(
|
||||
path=module.params["path"],
|
||||
state="present",
|
||||
@@ -129,32 +130,35 @@ class SignatureInfoBase(OpenSSLObject):
|
||||
check_mode=module.check_mode,
|
||||
)
|
||||
|
||||
self.signature = module.params["signature"]
|
||||
self.certificate_path = module.params["certificate_path"]
|
||||
self.certificate_content = module.params["certificate_content"]
|
||||
if self.certificate_content is not None:
|
||||
self.certificate_content = self.certificate_content.encode("utf-8")
|
||||
self.module = module
|
||||
self.signature: str = module.params["signature"]
|
||||
self.certificate_path: str | None = module.params["certificate_path"]
|
||||
certificate_content: str | None = module.params["certificate_content"]
|
||||
if certificate_content is not None:
|
||||
self.certificate_content: bytes | None = certificate_content.encode("utf-8")
|
||||
else:
|
||||
self.certificate_content = None
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
# Empty method because OpenSSLObject wants this
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
# Implementation with using cryptography
|
||||
class SignatureInfoCryptography(SignatureInfoBase):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(SignatureInfoCryptography, self).__init__(module)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> dict[str, t.Any]:
|
||||
_padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
|
||||
_hash = cryptography.hazmat.primitives.hashes.SHA256()
|
||||
|
||||
result = dict()
|
||||
result: dict[str, t.Any] = {}
|
||||
|
||||
try:
|
||||
with open(self.path, "rb") as f:
|
||||
@@ -228,7 +232,7 @@ class SignatureInfoCryptography(SignatureInfoBase):
|
||||
raise OpenSSLObjectError(e)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
certificate_path=dict(type="path"),
|
||||
|
||||
@@ -224,6 +224,7 @@ certificate:
|
||||
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
@@ -257,8 +258,15 @@ from ansible_collections.community.crypto.plugins.module_utils.io import (
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
|
||||
CertificateBackend,
|
||||
)
|
||||
|
||||
|
||||
class CertificateAbsent(OpenSSLObject):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CertificateAbsent, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -266,19 +274,19 @@ class CertificateAbsent(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
self.module = module
|
||||
self.return_content = module.params["return_content"]
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
pass
|
||||
|
||||
def remove(self, module):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = module.backup_local(self.path)
|
||||
super(CertificateAbsent, self).remove(module)
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = {
|
||||
"changed": self.changed,
|
||||
"filename": self.path,
|
||||
@@ -296,7 +304,7 @@ class CertificateAbsent(OpenSSLObject):
|
||||
class GenericCertificate(OpenSSLObject):
|
||||
"""Retrieve a certificate using the given module backend."""
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
|
||||
super(GenericCertificate, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -311,7 +319,7 @@ class GenericCertificate(OpenSSLObject):
|
||||
self.module_backend = module_backend
|
||||
self.module_backend.set_existing(load_file_if_exists(self.path, module))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
if self.module_backend.needs_regeneration():
|
||||
if not self.check_mode:
|
||||
self.module_backend.generate_certificate()
|
||||
@@ -329,14 +337,14 @@ class GenericCertificate(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def check(self, module, perms_required=True):
|
||||
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
return (
|
||||
super(GenericCertificate, self).check(module, perms_required)
|
||||
and not self.module_backend.needs_regeneration()
|
||||
)
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = self.module_backend.dump(include_certificate=self.return_content)
|
||||
result.update(
|
||||
{
|
||||
@@ -349,7 +357,7 @@ class GenericCertificate(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_certificate_argument_spec()
|
||||
add_acme_provider_to_argument_spec(argument_spec)
|
||||
add_entrust_provider_to_argument_spec(argument_spec)
|
||||
@@ -363,13 +371,14 @@ def main():
|
||||
return_content=dict(type="bool", default=False),
|
||||
)
|
||||
)
|
||||
argument_spec.required_if.append(["state", "present", ["provider"]])
|
||||
argument_spec.required_if.append(("state", "present", ["provider"]))
|
||||
module = argument_spec.create_ansible_module(
|
||||
add_file_common_args=True,
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
try:
|
||||
certificate: GenericCertificate | CertificateAbsent
|
||||
if module.params["state"] == "absent":
|
||||
certificate = CertificateAbsent(module)
|
||||
|
||||
@@ -389,7 +398,13 @@ def main():
|
||||
)
|
||||
|
||||
provider = module.params["provider"]
|
||||
provider_map = {
|
||||
provider_map: dict[
|
||||
str,
|
||||
type[AcmeCertificateProvider]
|
||||
| type[EntrustCertificateProvider]
|
||||
| type[OwnCACertificateProvider]
|
||||
| type[SelfSignedCertificateProvider],
|
||||
] = {
|
||||
"acme": AcmeCertificateProvider,
|
||||
"entrust": EntrustCertificateProvider,
|
||||
"ownca": OwnCACertificateProvider,
|
||||
|
||||
@@ -106,6 +106,7 @@ backup_file:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
@@ -142,8 +143,12 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def parse_certificate(input, strict=False):
|
||||
input_format = "pem" if identify_pem_format(input) else "der"
|
||||
def parse_certificate(
|
||||
input: bytes, strict: bool = False
|
||||
) -> tuple[bytes, t.Literal["pem", "der"], str | None]:
|
||||
input_format: t.Literal["pem", "der"] = (
|
||||
"pem" if identify_pem_format(input) else "der"
|
||||
)
|
||||
if input_format == "pem":
|
||||
pems = split_pem_list(to_text(input))
|
||||
if len(pems) > 1 and strict:
|
||||
@@ -162,7 +167,7 @@ def parse_certificate(input, strict=False):
|
||||
|
||||
|
||||
class X509CertificateConvertModule(OpenSSLObject):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(X509CertificateConvertModule, self).__init__(
|
||||
module.params["dest_path"],
|
||||
"present",
|
||||
@@ -170,9 +175,9 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
|
||||
self.src_path = module.params["src_path"]
|
||||
self.src_content = module.params["src_content"]
|
||||
self.src_content_base64 = module.params["src_content_base64"]
|
||||
self.src_path: str | None = module.params["src_path"]
|
||||
self.src_content: str | None = module.params["src_content"]
|
||||
self.src_content_base64: bool = module.params["src_content_base64"]
|
||||
if self.src_content is not None:
|
||||
self.input = to_bytes(self.src_content)
|
||||
if self.src_content_base64:
|
||||
@@ -181,6 +186,8 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
except Exception as exc:
|
||||
module.fail_json(msg=f"Cannot Base64 decode src_content: {exc}")
|
||||
else:
|
||||
if self.src_path is None:
|
||||
module.fail_json(msg="One of src_path and src_content must be provided")
|
||||
try:
|
||||
with open(self.src_path, "rb") as f:
|
||||
self.input = f.read()
|
||||
@@ -189,8 +196,8 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
msg=f"Failure while reading file {self.src_path}: {exc}"
|
||||
)
|
||||
|
||||
self.format = module.params["format"]
|
||||
self.strict = module.params["strict"]
|
||||
self.format: t.Literal["pem", "der"] = module.params["format"]
|
||||
self.strict: bool = module.params["strict"]
|
||||
self.wanted_pem_type = "CERTIFICATE"
|
||||
|
||||
try:
|
||||
@@ -203,8 +210,8 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
if module.params["verify_cert_parsable"]:
|
||||
self.verify_cert_parsable(module)
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
module.params["path"] = self.path
|
||||
|
||||
@@ -221,7 +228,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def verify_cert_parsable(self, module):
|
||||
def verify_cert_parsable(self, module: AnsibleModule) -> None:
|
||||
assert_required_cryptography_version(
|
||||
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
|
||||
)
|
||||
@@ -230,7 +237,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
except Exception as exc:
|
||||
module.fail_json(msg=f"Error while parsing certificate: {exc}")
|
||||
|
||||
def needs_conversion(self):
|
||||
def needs_conversion(self) -> bool:
|
||||
if self.dest_content is None or self.dest_content_format is None:
|
||||
return True
|
||||
if self.dest_content_format != self.format:
|
||||
@@ -241,7 +248,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_dest_certificate(self):
|
||||
def get_dest_certificate(self) -> bytes:
|
||||
if self.format == "der":
|
||||
return self.input
|
||||
data = to_bytes(base64.b64encode(self.input))
|
||||
@@ -250,7 +257,7 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
lines.append(to_bytes(f"{PEM_END_START}{self.wanted_pem_type}{PEM_END}\n"))
|
||||
return b"\n".join(lines)
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
"""Do conversion."""
|
||||
if self.needs_conversion():
|
||||
# Convert
|
||||
@@ -269,18 +276,18 @@ class X509CertificateConvertModule(OpenSSLObject):
|
||||
file_args, self.changed
|
||||
)
|
||||
|
||||
def dump(self):
|
||||
def dump(self) -> dict[str, t.Any]:
|
||||
"""Serialize the object into a dictionary."""
|
||||
result = dict(
|
||||
changed=self.changed,
|
||||
)
|
||||
result: dict[str, t.Any] = {
|
||||
"changed": self.changed,
|
||||
}
|
||||
if self.backup_file:
|
||||
result["backup_file"] = self.backup_file
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = dict(
|
||||
src_path=dict(type="path"),
|
||||
src_content=dict(type="str"),
|
||||
|
||||
@@ -390,8 +390,10 @@ issuer_uri:
|
||||
version_added: 2.9.0
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -406,7 +408,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -424,18 +426,22 @@ def main():
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
if module.params["content"] is not None:
|
||||
data = module.params["content"].encode("utf-8")
|
||||
content: str | None = module.params["content"]
|
||||
path: str | None = module.params["path"]
|
||||
if content is not None:
|
||||
data = content.encode("utf-8")
|
||||
else:
|
||||
if path is None:
|
||||
module.fail_json(msg="One of path and content must be provided")
|
||||
try:
|
||||
with open(module.params["path"], "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(msg=f"Error while reading certificate file from disk: {e}")
|
||||
|
||||
module_backend = select_backend(module, data)
|
||||
|
||||
valid_at = module.params["valid_at"]
|
||||
valid_at: dict[str, t.Any] = module.params["valid_at"]
|
||||
if valid_at:
|
||||
for k, v in valid_at.items():
|
||||
if not isinstance(v, (str, bytes)):
|
||||
@@ -443,13 +449,11 @@ def main():
|
||||
msg=f"The value for valid_at.{k} must be of type string (got {type(v)})"
|
||||
)
|
||||
valid_at[k] = get_relative_time_option(
|
||||
v, f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
to_text(v), f"valid_at.{k}", with_timezone=CRYPTOGRAPHY_TIMEZONE
|
||||
)
|
||||
|
||||
try:
|
||||
result = module_backend.get_info(
|
||||
der_support_enabled=module.params["content"] is None
|
||||
)
|
||||
result = module_backend.get_info(der_support_enabled=content is None)
|
||||
|
||||
not_before = module_backend.get_not_before()
|
||||
not_after = module_backend.get_not_after()
|
||||
|
||||
@@ -118,6 +118,8 @@ certificate:
|
||||
type: str
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
OpenSSLObjectError,
|
||||
)
|
||||
@@ -139,23 +141,31 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
|
||||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.certificate import (
|
||||
CertificateBackend,
|
||||
)
|
||||
|
||||
|
||||
class GenericCertificate:
|
||||
"""Retrieve a certificate using the given module backend."""
|
||||
|
||||
def __init__(self, module, module_backend):
|
||||
def __init__(self, module: AnsibleModule, module_backend: CertificateBackend):
|
||||
self.check_mode = module.check_mode
|
||||
self.module = module
|
||||
self.module_backend = module_backend
|
||||
self.changed = False
|
||||
if module.params["content"] is not None:
|
||||
self.module_backend.set_existing(module.params["content"].encode("utf-8"))
|
||||
content: str | None = module.params["content"]
|
||||
if content is not None:
|
||||
self.module_backend.set_existing(content.encode("utf-8"))
|
||||
|
||||
def generate(self, module):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
if self.module_backend.needs_regeneration():
|
||||
self.module_backend.generate_certificate()
|
||||
self.changed = True
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = self.module_backend.dump(include_certificate=True)
|
||||
result.update(
|
||||
{
|
||||
@@ -165,7 +175,7 @@ class GenericCertificate:
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
argument_spec = get_certificate_argument_spec()
|
||||
argument_spec.argument_spec["provider"]["required"] = True
|
||||
add_entrust_provider_to_argument_spec(argument_spec)
|
||||
@@ -182,7 +192,12 @@ def main():
|
||||
|
||||
try:
|
||||
provider = module.params["provider"]
|
||||
provider_map = {
|
||||
provider_map: dict[
|
||||
str,
|
||||
type[EntrustCertificateProvider]
|
||||
| type[OwnCACertificateProvider]
|
||||
| type[SelfSignedCertificateProvider],
|
||||
] = {
|
||||
"entrust": EntrustCertificateProvider,
|
||||
"ownca": OwnCACertificateProvider,
|
||||
"selfsigned": SelfSignedCertificateProvider,
|
||||
|
||||
@@ -425,6 +425,7 @@ crl:
|
||||
|
||||
import base64
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_text
|
||||
@@ -463,7 +464,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
|
||||
OpenSSLObject,
|
||||
load_certificate,
|
||||
load_privatekey,
|
||||
load_certificate_issuer_privatekey,
|
||||
parse_name_field,
|
||||
parse_ordered_name_field,
|
||||
select_message_digest,
|
||||
@@ -495,6 +496,9 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
|
||||
class CRLError(OpenSSLObjectError):
|
||||
pass
|
||||
@@ -502,7 +506,7 @@ class CRLError(OpenSSLObjectError):
|
||||
|
||||
class CRL(OpenSSLObject):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module: AnsibleModule) -> None:
|
||||
super(CRL, self).__init__(
|
||||
module.params["path"],
|
||||
module.params["state"],
|
||||
@@ -510,53 +514,69 @@ class CRL(OpenSSLObject):
|
||||
module.check_mode,
|
||||
)
|
||||
|
||||
self.format = module.params["format"]
|
||||
self.format: t.Literal["pem", "der"] = module.params["format"]
|
||||
|
||||
self.update = module.params["crl_mode"] == "update"
|
||||
self.ignore_timestamps = module.params["ignore_timestamps"]
|
||||
self.return_content = module.params["return_content"]
|
||||
self.name_encoding = module.params["name_encoding"]
|
||||
self.serial_numbers_format = module.params["serial_numbers"]
|
||||
self.crl_content = None
|
||||
self.update: bool = module.params["crl_mode"] == "update"
|
||||
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
|
||||
self.return_content: bool = module.params["return_content"]
|
||||
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params[
|
||||
"name_encoding"
|
||||
]
|
||||
self.serial_numbers_format: t.Literal["integer", "hex-octets"] = module.params[
|
||||
"serial_numbers"
|
||||
]
|
||||
self.crl_content: bytes | None = None
|
||||
|
||||
self.privatekey_path = module.params["privatekey_path"]
|
||||
self.privatekey_content = module.params["privatekey_content"]
|
||||
if self.privatekey_content is not None:
|
||||
self.privatekey_content = self.privatekey_content.encode("utf-8")
|
||||
self.privatekey_passphrase = module.params["privatekey_passphrase"]
|
||||
self.privatekey_path: str | None = module.params["privatekey_path"]
|
||||
privatekey_content: str | None = module.params["privatekey_content"]
|
||||
if privatekey_content is not None:
|
||||
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
|
||||
else:
|
||||
self.privatekey_content = None
|
||||
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
|
||||
|
||||
try:
|
||||
if module.params["issuer_ordered"]:
|
||||
issuer_ordered: list[dict[str, t.Any]] | None = module.params[
|
||||
"issuer_ordered"
|
||||
]
|
||||
issuer: dict[str, list[str | bytes] | bytes | str] | None = module.params[
|
||||
"issuer"
|
||||
]
|
||||
if issuer_ordered:
|
||||
self.issuer_ordered = True
|
||||
self.issuer = parse_ordered_name_field(
|
||||
module.params["issuer_ordered"], "issuer_ordered"
|
||||
)
|
||||
self.issuer = parse_ordered_name_field(issuer_ordered, "issuer_ordered")
|
||||
else:
|
||||
self.issuer_ordered = False
|
||||
self.issuer = parse_name_field(module.params["issuer"], "issuer")
|
||||
self.issuer = (
|
||||
parse_name_field(issuer, "issuer") if issuer is not None else []
|
||||
)
|
||||
except (TypeError, ValueError) as exc:
|
||||
module.fail_json(msg=str(exc))
|
||||
|
||||
self.last_update = get_relative_time_option(
|
||||
self.last_update: datetime.datetime = get_relative_time_option(
|
||||
module.params["last_update"],
|
||||
"last_update",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
self.next_update = get_relative_time_option(
|
||||
self.next_update: datetime.datetime | None = get_relative_time_option(
|
||||
module.params["next_update"],
|
||||
"next_update",
|
||||
with_timezone=CRYPTOGRAPHY_TIMEZONE,
|
||||
)
|
||||
|
||||
self.digest = select_message_digest(module.params["digest"])
|
||||
if self.digest is None:
|
||||
digest = select_message_digest(module.params["digest"])
|
||||
if digest is None:
|
||||
raise CRLError(f'The digest "{module.params["digest"]}" is not supported')
|
||||
self.digest = digest
|
||||
|
||||
self.module = module
|
||||
|
||||
self.revoked_certificates = []
|
||||
for i, rc in enumerate(module.params["revoked_certificates"]):
|
||||
result = {
|
||||
revoked_certificates: list[dict[str, t.Any]] = module.params[
|
||||
"revoked_certificates"
|
||||
]
|
||||
for i, rc in enumerate(revoked_certificates):
|
||||
result: dict[str, t.Any] = {
|
||||
"serial_number": None,
|
||||
"revocation_date": None,
|
||||
"issuer": None,
|
||||
@@ -567,21 +587,25 @@ class CRL(OpenSSLObject):
|
||||
"invalidity_date_critical": False,
|
||||
}
|
||||
path_prefix = f"revoked_certificates[{i}]."
|
||||
if rc["path"] is not None or rc["content"] is not None:
|
||||
path: str | None = rc["path"]
|
||||
content_str: str | None = rc["content"]
|
||||
if path is not None or content_str is not None:
|
||||
# Load certificate from file or content
|
||||
try:
|
||||
if rc["content"] is not None:
|
||||
rc["content"] = rc["content"].encode("utf-8")
|
||||
cert = load_certificate(rc["path"], content=rc["content"])
|
||||
content: bytes | None = None
|
||||
if content_str is not None:
|
||||
content = content_str.encode("utf-8")
|
||||
rc["content"] = content
|
||||
cert = load_certificate(path, content=content)
|
||||
result["serial_number"] = cert.serial_number
|
||||
except OpenSSLObjectError as e:
|
||||
if rc["content"] is not None:
|
||||
if content_str is not None:
|
||||
module.fail_json(
|
||||
msg=f"Cannot parse certificate from {path_prefix}content: {e}"
|
||||
)
|
||||
else:
|
||||
module.fail_json(
|
||||
msg=f'Cannot read certificate "{rc["path"]}" from {path_prefix}path: {e}'
|
||||
msg=f'Cannot read certificate "{path}" from {path_prefix}path: {e}'
|
||||
)
|
||||
else:
|
||||
# Specify serial_number (and potentially issuer) directly
|
||||
@@ -611,11 +635,11 @@ class CRL(OpenSSLObject):
|
||||
result["invalidity_date_critical"] = rc["invalidity_date_critical"]
|
||||
self.revoked_certificates.append(result)
|
||||
|
||||
self.backup = module.params["backup"]
|
||||
self.backup_file = None
|
||||
self.backup: bool = module.params["backup"]
|
||||
self.backup_file: str | None = None
|
||||
|
||||
try:
|
||||
self.privatekey = load_privatekey(
|
||||
self.privatekey = load_certificate_issuer_privatekey(
|
||||
path=self.privatekey_path,
|
||||
content=self.privatekey_content,
|
||||
passphrase=self.privatekey_passphrase,
|
||||
@@ -643,7 +667,7 @@ class CRL(OpenSSLObject):
|
||||
|
||||
self.diff_after = self.diff_before = self._get_info(data)
|
||||
|
||||
def _parse_serial_number(self, value, index):
|
||||
def _parse_serial_number(self, value: t.Any, index: int) -> int:
|
||||
if self.serial_numbers_format == "integer":
|
||||
try:
|
||||
return check_type_int(value)
|
||||
@@ -662,22 +686,42 @@ class CRL(OpenSSLObject):
|
||||
f"Unexpected value {self.serial_numbers_format} of serial_numbers"
|
||||
)
|
||||
|
||||
def _get_info(self, data):
|
||||
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
|
||||
if data is None:
|
||||
return dict()
|
||||
return {}
|
||||
try:
|
||||
result = get_crl_info(self.module, data)
|
||||
result["can_parse_crl"] = True
|
||||
return result
|
||||
except Exception:
|
||||
return dict(can_parse_crl=False)
|
||||
return {"can_parse_crl": False}
|
||||
|
||||
def remove(self):
|
||||
def remove(self, module: AnsibleModule) -> None:
|
||||
if self.backup:
|
||||
self.backup_file = self.module.backup_local(self.path)
|
||||
super(CRL, self).remove(self.module)
|
||||
|
||||
def _compress_entry(self, entry):
|
||||
def _compress_entry(self, entry: dict[str, t.Any]) -> (
|
||||
tuple[
|
||||
int | None,
|
||||
tuple[str, ...] | None,
|
||||
bool,
|
||||
int | None,
|
||||
bool,
|
||||
datetime.datetime | None,
|
||||
bool,
|
||||
]
|
||||
| tuple[
|
||||
int | None,
|
||||
datetime.datetime | None,
|
||||
tuple[str, ...] | None,
|
||||
bool,
|
||||
int | None,
|
||||
bool,
|
||||
datetime.datetime | None,
|
||||
bool,
|
||||
]
|
||||
):
|
||||
issuer = None
|
||||
if entry["issuer"] is not None:
|
||||
# Normalize to IDNA. If this is used-provided, it was already converted to
|
||||
@@ -713,7 +757,12 @@ class CRL(OpenSSLObject):
|
||||
entry["invalidity_date_critical"],
|
||||
)
|
||||
|
||||
def check(self, module, perms_required=True, ignore_conversion=True):
|
||||
def check(
|
||||
self,
|
||||
module: AnsibleModule,
|
||||
perms_required: bool = True,
|
||||
ignore_conversion: bool = True,
|
||||
) -> bool:
|
||||
"""Ensure the resource is in its desired state."""
|
||||
|
||||
state_and_perms = super(CRL, self).check(self.module, perms_required)
|
||||
@@ -743,10 +792,13 @@ class CRL(OpenSSLObject):
|
||||
]
|
||||
is_issuer = [(sub.oid, sub.value) for sub in self.crl.issuer]
|
||||
if not self.issuer_ordered:
|
||||
want_issuer = set(want_issuer)
|
||||
is_issuer = set(is_issuer)
|
||||
if want_issuer != is_issuer:
|
||||
return False
|
||||
want_issuer_set = set(want_issuer)
|
||||
is_issuer_set = set(is_issuer)
|
||||
if want_issuer_set != is_issuer_set:
|
||||
return False
|
||||
else:
|
||||
if want_issuer != is_issuer:
|
||||
return False
|
||||
|
||||
old_entries = [
|
||||
self._compress_entry(cryptography_decode_revoked_certificate(cert))
|
||||
@@ -769,7 +821,7 @@ class CRL(OpenSSLObject):
|
||||
|
||||
return True
|
||||
|
||||
def _generate_crl(self):
|
||||
def _generate_crl(self) -> bytes:
|
||||
crl = CertificateRevocationListBuilder()
|
||||
|
||||
try:
|
||||
@@ -787,7 +839,8 @@ class CRL(OpenSSLObject):
|
||||
raise CRLError(e)
|
||||
|
||||
crl = set_last_update(crl, self.last_update)
|
||||
crl = set_next_update(crl, self.next_update)
|
||||
if self.next_update is not None:
|
||||
crl = set_next_update(crl, self.next_update)
|
||||
|
||||
if self.update and self.crl:
|
||||
new_entries = set(
|
||||
@@ -799,22 +852,26 @@ class CRL(OpenSSLObject):
|
||||
)
|
||||
if decoded_entry not in new_entries:
|
||||
crl = crl.add_revoked_certificate(entry)
|
||||
for entry in self.revoked_certificates:
|
||||
for revoked_entry in self.revoked_certificates:
|
||||
revoked_cert = RevokedCertificateBuilder()
|
||||
revoked_cert = revoked_cert.serial_number(entry["serial_number"])
|
||||
revoked_cert = set_revocation_date(revoked_cert, entry["revocation_date"])
|
||||
if entry["issuer"] is not None:
|
||||
revoked_cert = revoked_cert.serial_number(revoked_entry["serial_number"])
|
||||
revoked_cert = set_revocation_date(
|
||||
revoked_cert, revoked_entry["revocation_date"]
|
||||
)
|
||||
if revoked_entry["issuer"] is not None:
|
||||
revoked_cert = revoked_cert.add_extension(
|
||||
x509.CertificateIssuer(entry["issuer"]), entry["issuer_critical"]
|
||||
x509.CertificateIssuer(revoked_entry["issuer"]),
|
||||
revoked_entry["issuer_critical"],
|
||||
)
|
||||
if entry["reason"] is not None:
|
||||
if revoked_entry["reason"] is not None:
|
||||
revoked_cert = revoked_cert.add_extension(
|
||||
x509.CRLReason(entry["reason"]), entry["reason_critical"]
|
||||
x509.CRLReason(revoked_entry["reason"]),
|
||||
revoked_entry["reason_critical"],
|
||||
)
|
||||
if entry["invalidity_date"] is not None:
|
||||
if revoked_entry["invalidity_date"] is not None:
|
||||
revoked_cert = revoked_cert.add_extension(
|
||||
x509.InvalidityDate(entry["invalidity_date"]),
|
||||
entry["invalidity_date_critical"],
|
||||
x509.InvalidityDate(revoked_entry["invalidity_date"]),
|
||||
revoked_entry["invalidity_date_critical"],
|
||||
)
|
||||
crl = crl.add_revoked_certificate(revoked_cert.build())
|
||||
|
||||
@@ -827,7 +884,7 @@ class CRL(OpenSSLObject):
|
||||
else:
|
||||
return self.crl.public_bytes(Encoding.DER)
|
||||
|
||||
def generate(self):
|
||||
def generate(self, module: AnsibleModule) -> None:
|
||||
result = None
|
||||
if (
|
||||
not self.check(self.module, perms_required=False, ignore_conversion=True)
|
||||
@@ -861,7 +918,7 @@ class CRL(OpenSSLObject):
|
||||
elif self.module.set_fs_attributes_if_different(file_args, False):
|
||||
self.changed = True
|
||||
|
||||
def dump(self, check_mode=False):
|
||||
def dump(self, check_mode: bool = False) -> dict[str, t.Any]:
|
||||
result = {
|
||||
"changed": self.changed,
|
||||
"filename": self.path,
|
||||
@@ -879,37 +936,53 @@ class CRL(OpenSSLObject):
|
||||
|
||||
if check_mode:
|
||||
result["last_update"] = self.last_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = self.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = (
|
||||
self.next_update.strftime(TIMESTAMP_FORMAT)
|
||||
if self.next_update is not None
|
||||
else None
|
||||
)
|
||||
# result['digest'] = cryptography_oid_to_name(self.crl.signature_algorithm_oid)
|
||||
result["digest"] = self.module.params["digest"]
|
||||
result["issuer_ordered"] = self.issuer
|
||||
result["issuer"] = {}
|
||||
issuer: dict[str, str | bytes] = {}
|
||||
result["issuer"] = issuer
|
||||
for k, v in self.issuer:
|
||||
result["issuer"][k] = v
|
||||
result["revoked_certificates"] = []
|
||||
issuer[k] = v
|
||||
revoked_certificates: list[dict[str, t.Any]] = []
|
||||
result["revoked_certificates"] = revoked_certificates
|
||||
for entry in self.revoked_certificates:
|
||||
result["revoked_certificates"].append(
|
||||
revoked_certificates.append(
|
||||
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
|
||||
)
|
||||
elif self.crl:
|
||||
result["last_update"] = get_last_update(self.crl).strftime(TIMESTAMP_FORMAT)
|
||||
result["next_update"] = get_next_update(self.crl).strftime(TIMESTAMP_FORMAT)
|
||||
next_update = get_next_update(self.crl)
|
||||
result["next_update"] = (
|
||||
next_update.strftime(TIMESTAMP_FORMAT)
|
||||
if next_update is not None
|
||||
else None
|
||||
)
|
||||
result["digest"] = cryptography_oid_to_name(
|
||||
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
|
||||
)
|
||||
issuer = []
|
||||
issuer_list: list[list[str]] = []
|
||||
for attribute in self.crl.issuer:
|
||||
issuer.append(
|
||||
[cryptography_oid_to_name(attribute.oid), attribute.value]
|
||||
issuer_list.append(
|
||||
[
|
||||
cryptography_oid_to_name(attribute.oid),
|
||||
to_text(attribute.value),
|
||||
]
|
||||
)
|
||||
result["issuer_ordered"] = issuer
|
||||
result["issuer"] = {}
|
||||
for k, v in issuer:
|
||||
result["issuer"][k] = v
|
||||
result["revoked_certificates"] = []
|
||||
result["issuer_ordered"] = issuer_list
|
||||
issuer = {}
|
||||
result["issuer"] = issuer
|
||||
for k, v in issuer_list:
|
||||
issuer[k] = v
|
||||
revoked_certificates = []
|
||||
result["revoked_certificates"] = revoked_certificates
|
||||
for cert in self.crl:
|
||||
entry = cryptography_decode_revoked_certificate(cert)
|
||||
result["revoked_certificates"].append(
|
||||
revoked_certificates.append(
|
||||
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
|
||||
)
|
||||
|
||||
@@ -923,7 +996,7 @@ class CRL(OpenSSLObject):
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
state=dict(type="str", default="present", choices=["present", "absent"]),
|
||||
@@ -1015,14 +1088,14 @@ def main():
|
||||
)
|
||||
module.exit_json(**result)
|
||||
|
||||
crl.generate()
|
||||
crl.generate(module)
|
||||
else:
|
||||
if module.check_mode:
|
||||
result = crl.dump(check_mode=True)
|
||||
result["changed"] = os.path.exists(module.params["path"])
|
||||
module.exit_json(**result)
|
||||
|
||||
crl.remove()
|
||||
crl.remove(module)
|
||||
|
||||
result = crl.dump()
|
||||
module.exit_json(**result)
|
||||
|
||||
@@ -100,7 +100,9 @@ last_update:
|
||||
type: str
|
||||
sample: '20190413202428Z'
|
||||
next_update:
|
||||
description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
|
||||
description:
|
||||
- The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME.
|
||||
- Will be C(none) if no such timestamp is present.
|
||||
returned: success
|
||||
type: str
|
||||
sample: '20190413202428Z'
|
||||
@@ -172,6 +174,7 @@ revoked_certificates:
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
|
||||
@@ -185,7 +188,7 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> t.NoReturn:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path"),
|
||||
@@ -200,25 +203,30 @@ def main():
|
||||
supports_check_mode=True,
|
||||
)
|
||||
|
||||
if module.params["content"] is None:
|
||||
content: str | None = module.params["content"]
|
||||
path: str | None = module.params["path"]
|
||||
if content is None:
|
||||
if path is None:
|
||||
module.fail_json(msg="One of content and path must be provided")
|
||||
try:
|
||||
with open(module.params["path"], "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
except (IOError, OSError) as e:
|
||||
module.fail_json(msg=f"Error while reading CRL file from disk: {e}")
|
||||
else:
|
||||
data = module.params["content"].encode("utf-8")
|
||||
data = content.encode("utf-8")
|
||||
if not identify_pem_format(data):
|
||||
try:
|
||||
data = base64.b64decode(module.params["content"])
|
||||
data = base64.b64decode(content)
|
||||
except (binascii.Error, TypeError) as e:
|
||||
module.fail_json(msg=f"Error while Base64 decoding content: {e}")
|
||||
|
||||
list_revoked_certificates: bool = module.params["list_revoked_certificates"]
|
||||
try:
|
||||
result = get_crl_info(
|
||||
module,
|
||||
data,
|
||||
list_revoked_certificates=module.params["list_revoked_certificates"],
|
||||
list_revoked_certificates=list_revoked_certificates,
|
||||
)
|
||||
module.exit_json(**result)
|
||||
except OpenSSLObjectError as e:
|
||||
|
||||
@@ -15,20 +15,24 @@ from __future__ import annotations
|
||||
import abc
|
||||
import copy
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleError
|
||||
from ansible.module_utils.basic import SEQUENCETYPE, remove_values
|
||||
from ansible.module_utils.common._collections_compat import Mapping
|
||||
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
|
||||
from ansible.module_utils.common.validation import (
|
||||
safe_eval,
|
||||
)
|
||||
from ansible.module_utils.errors import UnsupportedError
|
||||
from ansible.plugins.action import ActionBase
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.crypto.plugins.module_utils.argspec import (
|
||||
ArgumentSpec,
|
||||
)
|
||||
|
||||
|
||||
class _ModuleExitException(Exception):
|
||||
def __init__(self, result):
|
||||
def __init__(self, result: dict[str, t.Any]) -> None:
|
||||
super(_ModuleExitException, self).__init__()
|
||||
self.result = result
|
||||
|
||||
@@ -36,20 +40,21 @@ class _ModuleExitException(Exception):
|
||||
class AnsibleActionModule:
|
||||
def __init__(
|
||||
self,
|
||||
action_plugin,
|
||||
argument_spec,
|
||||
bypass_checks=False,
|
||||
mutually_exclusive=None,
|
||||
required_together=None,
|
||||
required_one_of=None,
|
||||
supports_check_mode=False,
|
||||
required_if=None,
|
||||
required_by=None,
|
||||
):
|
||||
action_plugin: ActionModuleBase,
|
||||
argument_spec: dict[str, t.Any],
|
||||
*,
|
||||
bypass_checks: bool = False,
|
||||
supports_check_mode: bool = False,
|
||||
mutually_exclusive: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_together: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_one_of: list[list[str] | tuple[str, ...]] | None = None,
|
||||
required_if: list[tuple[str, t.Any, list[str] | tuple[str, ...]]] | None = None,
|
||||
required_by: dict[str, tuple[str, ...] | list[str]] | None = None,
|
||||
) -> None:
|
||||
# Internal data
|
||||
self.__action_plugin = action_plugin
|
||||
self.__warnings = []
|
||||
self.__deprecations = []
|
||||
self.__warnings: list[str] = []
|
||||
self.__deprecations: list[dict[str, str | None]] = []
|
||||
|
||||
# AnsibleModule data
|
||||
self._name = self.__action_plugin._task.action
|
||||
@@ -67,10 +72,6 @@ class AnsibleActionModule:
|
||||
self._diff = self.__action_plugin._play_context.diff
|
||||
self._verbosity = self.__action_plugin._display.verbosity
|
||||
|
||||
self.aliases = {}
|
||||
self._legal_inputs = []
|
||||
self._options_context = list()
|
||||
|
||||
self.params = copy.deepcopy(self.__action_plugin._task.args)
|
||||
self.no_log_values = set()
|
||||
self._validator = ArgumentSpecValidator(
|
||||
@@ -122,38 +123,41 @@ class AnsibleActionModule:
|
||||
|
||||
self.fail_json(msg=msg)
|
||||
|
||||
def safe_eval(self, value, locals=None, include_exceptions=False):
|
||||
return safe_eval(value, locals, include_exceptions)
|
||||
|
||||
def warn(self, warning):
|
||||
def warn(self, warning: str) -> None:
|
||||
# Copied from ansible.module_utils.common.warnings:
|
||||
if isinstance(warning, (str, bytes)):
|
||||
if isinstance(warning, str):
|
||||
self.__warnings.append(warning)
|
||||
else:
|
||||
raise TypeError(f"warn requires a string not a {type(warning)}")
|
||||
|
||||
def deprecate(self, msg, version=None, date=None, collection_name=None):
|
||||
def deprecate(
|
||||
self,
|
||||
msg: str,
|
||||
version: str | None = None,
|
||||
date: str | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
if version is not None and date is not None:
|
||||
raise AssertionError(
|
||||
"implementation error -- version and date must not both be set"
|
||||
)
|
||||
|
||||
# Copied from ansible.module_utils.common.warnings:
|
||||
if isinstance(msg, (str, bytes)):
|
||||
# For compatibility, we accept that neither version nor date is set,
|
||||
# and treat that the same as if version would haven been set
|
||||
if date is not None:
|
||||
self.__deprecations.append(
|
||||
{"msg": msg, "date": date, "collection_name": collection_name}
|
||||
)
|
||||
else:
|
||||
self.__deprecations.append(
|
||||
{"msg": msg, "version": version, "collection_name": collection_name}
|
||||
)
|
||||
else:
|
||||
if not isinstance(msg, str):
|
||||
raise TypeError(f"deprecate requires a string not a {type(msg)}")
|
||||
|
||||
def _return_formatted(self, kwargs):
|
||||
# For compatibility, we accept that neither version nor date is set,
|
||||
# and treat that the same as if version would haven been set
|
||||
if date is not None:
|
||||
self.__deprecations.append(
|
||||
{"msg": msg, "date": date, "collection_name": collection_name}
|
||||
)
|
||||
else:
|
||||
self.__deprecations.append(
|
||||
{"msg": msg, "version": version, "collection_name": collection_name}
|
||||
)
|
||||
|
||||
def _return_formatted(self, kwargs: dict[str, t.Any]) -> t.NoReturn:
|
||||
if "invocation" not in kwargs:
|
||||
kwargs["invocation"] = {"module_args": self.params}
|
||||
|
||||
@@ -194,13 +198,13 @@ class AnsibleActionModule:
|
||||
kwargs = remove_values(kwargs, self.no_log_values)
|
||||
raise _ModuleExitException(kwargs)
|
||||
|
||||
def exit_json(self, **kwargs):
|
||||
def exit_json(self, **kwargs) -> t.NoReturn:
|
||||
result = dict(kwargs)
|
||||
if "failed" not in result:
|
||||
result["failed"] = False
|
||||
self._return_formatted(result)
|
||||
|
||||
def fail_json(self, msg, **kwargs):
|
||||
def fail_json(self, msg: str, **kwargs) -> t.NoReturn:
|
||||
result = dict(kwargs)
|
||||
result["failed"] = True
|
||||
result["msg"] = msg
|
||||
@@ -209,16 +213,15 @@ class AnsibleActionModule:
|
||||
|
||||
class ActionModuleBase(ActionBase, metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def setup_module(self):
|
||||
def setup_module(self) -> tuple[ArgumentSpec, dict[str, t.Any]]:
|
||||
"""Return pair (ArgumentSpec, kwargs)."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def run_module(self, module):
|
||||
def run_module(self, module: AnsibleActionModule) -> None:
|
||||
"""Run module code"""
|
||||
module.fail_json(msg="Not implemented.")
|
||||
|
||||
def run(self, tmp=None, task_vars=None):
|
||||
def run(self, tmp=None, task_vars=None) -> dict[str, t.Any]:
|
||||
if task_vars is None:
|
||||
task_vars = dict()
|
||||
|
||||
|
||||
@@ -6,14 +6,23 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.utils.display import Display
|
||||
|
||||
|
||||
_display = Display()
|
||||
|
||||
|
||||
class FilterModuleMock:
|
||||
def __init__(self, params):
|
||||
def __init__(self, params: dict[str, t.Any]) -> None:
|
||||
self.check_mode = True
|
||||
self.params = params
|
||||
self._diff = False
|
||||
|
||||
def fail_json(self, msg, **kwargs):
|
||||
def fail_json(self, msg: str, **kwargs) -> t.NoReturn:
|
||||
raise AnsibleFilterError(msg)
|
||||
|
||||
def warn(self, warning: str) -> None:
|
||||
_display.warning(warning)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from subprocess import PIPE, Popen
|
||||
|
||||
from ansible.module_utils.common.process import get_bin_path
|
||||
@@ -15,7 +16,7 @@ from ansible_collections.community.crypto.plugins.module_utils.gnupg.cli import
|
||||
|
||||
|
||||
class PluginGPGRunner(GPGRunner):
|
||||
def __init__(self, executable=None, cwd=None):
|
||||
def __init__(self, executable: str | None = None, cwd: str | None = None) -> None:
|
||||
if executable is None:
|
||||
try:
|
||||
executable = get_bin_path("gpg")
|
||||
@@ -24,7 +25,9 @@ class PluginGPGRunner(GPGRunner):
|
||||
self.executable = executable
|
||||
self.cwd = cwd
|
||||
|
||||
def run_command(self, command, check_rc=True, data=None):
|
||||
def run_command(
|
||||
self, command: list[str], check_rc: bool = True, data: bytes | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
"""
|
||||
Run ``[gpg] + command`` and return ``(rc, stdout, stderr)``.
|
||||
|
||||
@@ -41,12 +44,10 @@ class PluginGPGRunner(GPGRunner):
|
||||
command, shell=False, cwd=self.cwd, stdin=PIPE, stdout=PIPE, stderr=PIPE
|
||||
)
|
||||
stdout, stderr = p.communicate(input=data)
|
||||
stdout = to_native(stdout, errors="surrogate_or_replace")
|
||||
stderr = to_native(stderr, errors="surrogate_or_replace")
|
||||
stdout_n = to_native(stdout, errors="surrogate_or_replace")
|
||||
stderr_n = to_native(stderr, errors="surrogate_or_replace")
|
||||
if check_rc and p.returncode != 0:
|
||||
stdout_n = (to_native(stdout, errors="surrogate_or_replace"),)
|
||||
stderr_n = (to_native(stderr, errors="surrogate_or_replace"),)
|
||||
raise GPGError(
|
||||
f'Running {" ".join(command)} yielded return code {p.returncode} with stdout: "{stdout_n}" and stderr: "{stderr_n}")'
|
||||
)
|
||||
return p.returncode, stdout, stderr
|
||||
return t.cast(int, p.returncode), stdout_n, stderr_n
|
||||
|
||||
Reference in New Issue
Block a user