Code refactoring (#889)

* Add __all__ to all module and plugin utils.

* Convert quite a few positional args to keyword args.

* Avoid Python 3.8+ syntax.
This commit is contained in:
Felix Fontein
2025-05-16 06:55:57 +02:00
committed by GitHub
parent a5a4e022ba
commit 44bcc8cebc
101 changed files with 1510 additions and 748 deletions

View File

@@ -60,7 +60,7 @@ SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
key: PrivateKeyTypes, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data: dict[str, t.Any] = {}
@@ -84,7 +84,7 @@ def _get_cryptography_private_key_info(
def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
*, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters
p: int | None = key_public_data.get("p")
@@ -112,10 +112,10 @@ def _check_dsa_consistency(
if (p - 1) % q != 0:
return False
# Check that g**q mod p == 1
if binary_exp_mod(g, q, p) != 1:
if binary_exp_mod(g, q, m=p) != 1:
return False
# Check whether g**x mod p == y
if binary_exp_mod(g, x, p) != y:
if binary_exp_mod(g, x, m=p) != y:
return False
# Check (quickly) whether p or q are not primes
if quick_is_not_prime(q) or quick_is_not_prime(p):
@@ -125,6 +125,7 @@ def _check_dsa_consistency(
def _is_cryptography_key_consistent(
key: PrivateKeyTypes,
*,
key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
@@ -135,7 +136,9 @@ def _is_cryptography_key_consistent(
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data)
result = _check_dsa_consistency(
key_public_data=key_public_data, key_private_data=key_private_data
)
if result is not None:
return result
signature = key.sign(
@@ -191,14 +194,14 @@ def _is_cryptography_key_consistent(
class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
def __init__(self, msg: str, *, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
@@ -207,6 +210,7 @@ class PrivateKeyParseError(OpenSSLObjectError):
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -220,22 +224,22 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
@abc.abstractmethod
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
def get_info(self, *, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"can_parse_key": False,
"key_is_consistent": None,
@@ -253,7 +257,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
)
result["can_parse_key"] = True
except OpenSSLObjectError as exc:
raise PrivateKeyParseError(str(exc), result)
raise PrivateKeyParseError(str(exc), result=result)
result["public_key"] = to_native(self._get_public_key(binary=False))
pk = self._get_public_key(binary=True)
@@ -273,7 +277,7 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
if self.check_consistency:
result["key_is_consistent"] = self._is_key_consistent(
key_public_data, key_private_data
key_public_data=key_public_data, key_private_data=key_private_data
)
if result["key_is_consistent"] is False:
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
@@ -281,40 +285,46 @@ class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
"Private key is not consistent! (See "
"https://blog.hboeck.de/archives/888-How-I-tricked-Symantec-with-a-Fake-Private-Key.html)"
)
raise PrivateKeyConsistencyError(msg, result)
raise PrivateKeyConsistencyError(msg, result=result)
return result
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
def __init__(
self, *, module: GeneralAnsibleModule, content: bytes, **kwargs
) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs
module=module, content=content, **kwargs
)
def _get_public_key(self, binary: bool) -> bytes:
def _get_public_key(self, *, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(
self, need_private_key_data: bool = False
self, *, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
self, *, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
self.key,
key_public_data=key_public_data,
key_private_data=key_private_data,
warn_func=self.module.warn,
)
def get_privatekey_info(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -322,8 +332,8 @@ def get_privatekey_info(
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
)
@@ -331,6 +341,7 @@ def get_privatekey_info(
def select_backend(
*,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
@@ -341,9 +352,18 @@ def select_backend(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyInfoRetrievalCryptography(
module,
content,
module=module,
content=content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
check_consistency=check_consistency,
)
__all__ = (
"PrivateKeyConsistencyError",
"PrivateKeyParseError",
"PrivateKeyInfoRetrieval",
"get_privatekey_info",
"select_backend",
)