Ensure that *everything* is typed in community.crypto (#917)

* Ensure that *everything* is typed in community.crypto.

* Fix comment.

* Ignore type definitions/imports and AssertionErrors for code coverage.
This commit is contained in:
Felix Fontein
2025-06-09 10:10:19 +02:00
committed by GitHub
parent ec063d8515
commit d83a923325
73 changed files with 494 additions and 317 deletions

View File

@@ -19,23 +19,26 @@ from ansible_collections.community.crypto.plugins.module_utils._openssh.utils im
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._openssh.certificate import (
from ansible.module_utils.basic import AnsibleModule # pragma: no cover
from ansible_collections.community.crypto.plugins.module_utils._openssh.certificate import ( # pragma: no cover
OpensshCertificateTimeParameters,
)
from cryptography.hazmat.primitives.asymmetric.types import (
from cryptography.hazmat.primitives.asymmetric.types import ( # pragma: no cover
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
Param = t.ParamSpec("Param")
Param = t.ParamSpec("Param") # pragma: no cover
def restore_on_failure(
f: t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None],
) -> t.Callable[t.Concatenate[AnsibleModule, str | os.PathLike, Param], None]:
def backup_and_restore(
module: AnsibleModule, path: str | os.PathLike, *args, **kwargs
module: AnsibleModule,
path: str | os.PathLike,
*args: Param.args,
**kwargs: Param.kwargs,
) -> None:
backup_file = module.backup_local(path) if os.path.exists(path) else None
@@ -74,8 +77,8 @@ def _restore_all_on_failure(
def backup_and_restore(
self: OpensshModule,
sources_and_destinations: list[tuple[str | os.PathLike, str | os.PathLike]],
*args,
**kwargs,
*args: Param.args,
**kwargs: Param.kwargs,
) -> None:
backups = [
(d, self.module.backup_local(d))
@@ -97,6 +100,9 @@ def _restore_all_on_failure(
return backup_and_restore
_OpensshModule = t.TypeVar("_OpensshModule", bound="OpensshModule")
class OpensshModule(metaclass=abc.ABCMeta):
def __init__(self, *, module: AnsibleModule) -> None:
self.module = module
@@ -141,16 +147,24 @@ class OpensshModule(metaclass=abc.ABCMeta):
pass
@staticmethod
def skip_if_check_mode(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs) -> None:
def skip_if_check_mode(
f: t.Callable[t.Concatenate[_OpensshModule, Param], None],
) -> t.Callable[t.Concatenate[_OpensshModule, Param], None]:
def wrapper(
self: _OpensshModule, *args: Param.args, **kwargs: Param.kwargs
) -> None:
if not self.check_mode:
f(self, *args, **kwargs)
return wrapper # type: ignore
@staticmethod
def trigger_change(f: t.Callable[Param, None]) -> t.Callable[Param, None]:
def wrapper(self, *args, **kwargs) -> None:
def trigger_change(
f: t.Callable[t.Concatenate[_OpensshModule, Param], None],
) -> t.Callable[t.Concatenate[_OpensshModule, Param], None]:
def wrapper(
self: _OpensshModule, *args: Param.args, **kwargs: Param.kwargs
) -> None:
f(self, *args, **kwargs)
self.changed = True
@@ -202,6 +216,13 @@ class OpensshModule(metaclass=abc.ABCMeta):
self.changed = True
if t.TYPE_CHECKING:
class _RunCommandKwarg(t.TypedDict):
check_rc: t.NotRequired[bool]
environ_update: t.NotRequired[dict[str, str] | None]
class KeygenCommand:
def __init__(self, module: AnsibleModule) -> None:
self._bin_path = module.get_bin_path("ssh-keygen", True)
@@ -221,7 +242,7 @@ class KeygenCommand:
cert_type: t.Literal["host", "user"] | None,
time_parameters: OpensshCertificateTimeParameters,
use_agent: bool,
**kwargs,
**kwargs: t.Unpack[_RunCommandKwarg],
) -> tuple[int, str, str]:
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
@@ -253,7 +274,7 @@ class KeygenCommand:
size: int,
key_type: str,
comment: str | None,
**kwargs,
**kwargs: t.Unpack[_RunCommandKwarg],
) -> tuple[int, str, str]:
args = [
self._bin_path,
@@ -276,21 +297,21 @@ class KeygenCommand:
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(
self, *, certificate_path: str, **kwargs
self, *, certificate_path: str, **kwargs: t.Unpack[_RunCommandKwarg]
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(
self, *, private_key_path: str, **kwargs
self, *, private_key_path: str, **kwargs: t.Unpack[_RunCommandKwarg]
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(
self, *, private_key_path: str, **kwargs
self, *, private_key_path: str, **kwargs: t.Unpack[_RunCommandKwarg]
) -> tuple[int, str, str]:
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
@@ -302,7 +323,7 @@ class KeygenCommand:
private_key_path: str,
comment: str,
force_new_format: bool = True,
**kwargs,
**kwargs: t.Unpack[_RunCommandKwarg],
) -> tuple[int, str, str]:
if os.path.exists(private_key_path) and not os.access(
private_key_path, os.W_OK

View File

@@ -44,8 +44,8 @@ from ansible_collections.community.crypto.plugins.module_utils._version import (
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
from ansible.module_utils.basic import AnsibleModule # pragma: no cover
from cryptography.hazmat.primitives.asymmetric.types import ( # pragma: no cover
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)

View File

@@ -31,13 +31,13 @@ from ansible_collections.community.crypto.plugins.module_utils._time import (
if t.TYPE_CHECKING:
from ansible_collections.community.crypto.plugins.module_utils._openssh.cryptography import (
from ansible_collections.community.crypto.plugins.module_utils._openssh.cryptography import ( # pragma: no cover
KeyType,
)
DateFormat = t.Literal["human_readable", "openssh", "timestamp"]
DateFormatStr = t.Literal["human_readable", "openssh"]
DateFormatInt = t.Literal["timestamp"]
DateFormat = t.Literal["human_readable", "openssh", "timestamp"] # pragma: no cover
DateFormatStr = t.Literal["human_readable", "openssh"] # pragma: no cover
DateFormatInt = t.Literal["timestamp"] # pragma: no cover
else:
KeyType = None # pylint: disable=invalid-name
@@ -338,6 +338,22 @@ class OpensshCertificateOption:
)
if t.TYPE_CHECKING:
class _OpensshCertificateInfoKwarg(t.TypedDict):
nonce: t.NotRequired[bytes | None]
serial: t.NotRequired[int | None]
cert_type: t.NotRequired[int | None]
key_id: t.NotRequired[bytes | None]
principals: t.NotRequired[list[bytes] | None]
valid_after: t.NotRequired[int | None]
valid_before: t.NotRequired[int | None]
critical_options: t.NotRequired[list[tuple[bytes, bytes]] | None]
extensions: t.NotRequired[list[tuple[bytes, bytes]] | None]
reserved: t.NotRequired[bytes | None]
signing_key: t.NotRequired[bytes | None]
class OpensshCertificateInfo(metaclass=abc.ABCMeta):
"""Encapsulates all certificate information which is signed by a CA key"""
@@ -402,7 +418,13 @@ class OpensshCertificateInfo(metaclass=abc.ABCMeta):
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, *, e: int | None = None, n: int | None = None, **kwargs) -> None:
def __init__(
self,
*,
e: int | None = None,
n: int | None = None,
**kwargs: t.Unpack[_OpensshCertificateInfoKwarg],
) -> None:
super().__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
@@ -433,7 +455,7 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
q: int | None = None,
g: int | None = None,
y: int | None = None,
**kwargs,
**kwargs: t.Unpack[_OpensshCertificateInfoKwarg],
) -> None:
super().__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
@@ -465,7 +487,11 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def __init__(
self, *, curve: bytes | None = None, public_key: bytes | None = None, **kwargs
self,
*,
curve: bytes | None = None,
public_key: bytes | None = None,
**kwargs: t.Unpack[_OpensshCertificateInfoKwarg],
):
super().__init__(**kwargs)
self._curve: bytes | None = None
@@ -509,7 +535,12 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, *, pk: bytes | None = None, **kwargs) -> None:
def __init__(
self,
*,
pk: bytes | None = None,
**kwargs: t.Unpack[_OpensshCertificateInfoKwarg],
) -> None:
super().__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk

View File

@@ -75,22 +75,22 @@ from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptogra
if t.TYPE_CHECKING:
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"]
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"]
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"]
KeyFormat = t.Literal["SSH", "PKCS8", "PKCS1"] # pragma: no cover
KeySerializationFormat = t.Literal["PEM", "DER", "SSH"] # pragma: no cover
KeyType = t.Literal["rsa", "dsa", "ed25519", "ecdsa"] # pragma: no cover
PrivateKeyTypes = t.Union[
rsa.RSAPrivateKey,
dsa.DSAPrivateKey,
ec.EllipticCurvePrivateKey,
Ed25519PrivateKey,
]
] # pragma: no cover
PublicKeyTypes = t.Union[
rsa.RSAPublicKey, dsa.DSAPublicKey, ec.EllipticCurvePublicKey, Ed25519PublicKey
]
] # pragma: no cover
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes as AllPublicKeyTypes,
PublicKeyTypes as AllPublicKeyTypes, # pragma: no cover
)