Improve typing (#911)

* Make type checking more strict.

* mypy: warn about unreachable code.

* Enable warn_redundant_casts.

* Enable strict_bytes.

* Look at some warn_return_any warnings.
This commit is contained in:
Felix Fontein
2025-05-31 10:25:55 +02:00
committed by GitHub
parent 6d273bc5b7
commit 82522fc07f
20 changed files with 88 additions and 65 deletions

View File

@@ -176,6 +176,7 @@ class ACMEAccount:
# check whether that failed with a malformed request error
if (
info["status"] >= 400
and isinstance(result, Mapping)
and result.get("type") == "urn:ietf:params:acme:error:malformed"
):
# retry as a regular POST (with no changed data) for pre-draft-15 ACME servers
@@ -183,7 +184,7 @@ class ACMEAccount:
result, info = self.client.send_signed_request(
self.client.account_uri, data, fail_on_error=False
)
if not isinstance(result, Mapping):
if not isinstance(result, dict):
raise ACMEProtocolException(
module=self.client.module,
msg="Invalid account data retrieved from ACME server",
@@ -328,16 +329,17 @@ class ACMEAccount:
account_data = dict(account_data)
account_data.update(update_request)
else:
account_data, info = self.client.send_signed_request(
raw_account_data, info = self.client.send_signed_request(
self.client.account_uri, update_request
)
if not isinstance(account_data, Mapping):
if not isinstance(raw_account_data, Mapping):
raise ACMEProtocolException(
module=self.client.module,
msg="Invalid account updating reply from ACME server",
info=info,
content_json=account_data,
)
account_data = raw_account_data
return True, account_data

View File

@@ -343,7 +343,7 @@ class ACMEClient:
fail_on_error: bool = True,
error_msg: str | None = None,
expected_status_codes: t.Iterable[int] | None = None,
) -> tuple[dict[str, t.Any], dict[str, t.Any]]: ...
) -> tuple[dict[str, t.Any] | bytes, dict[str, t.Any]]: ...
@t.overload
def send_signed_request(

View File

@@ -102,14 +102,9 @@ class CryptographyChainMatcher(ChainMatcher):
try:
return binascii.unhexlify(key_identifier.replace(":", ""))
except Exception:
if criterium_idx is None:
module.warn(
f"Criterium has invalid {name} value. Ignoring criterium."
)
else:
module.warn(
f"Criterium {criterium_idx} in select_chain has invalid {name} value. Ignoring criterium."
)
module.warn(
f"Criterium {criterium_idx} in select_chain has invalid {name} value. Ignoring criterium."
)
return None
def __init__(self, *, criterium: Criterium, module: AnsibleModule) -> None:

View File

@@ -236,6 +236,12 @@ class Authorization:
error_msg="Failed to request challenges",
expected_status_codes=[200, 201],
)
if not isinstance(result, dict):
raise ACMEProtocolException(
module=client.module,
msg="Unexpected authorization creation result",
content_json=result,
)
return cls.from_json(client=client, data=result, url=info["location"])
@property
@@ -358,7 +364,11 @@ class Authorization:
result, info = client.send_signed_request(
self.url, authz_deactivate, fail_on_error=False
)
if 200 <= info["status"] < 300 and result.get("status") == "deactivated":
if (
200 <= info["status"] < 300
and isinstance(result, dict)
and result.get("status") == "deactivated"
):
self.status = "deactivated"
return True
return False
@@ -377,6 +387,12 @@ class Authorization:
result, _info = client.send_signed_request(
url, authz_deactivate, fail_on_error=True
)
if not isinstance(result, dict):
raise ACMEProtocolException(
module=client.module,
msg="Unexpected challenge deactivation result",
content_json=result,
)
authz._setup(client=client, data=result)
return authz

View File

@@ -72,7 +72,7 @@ class ACMEProtocolException(ModuleFailException):
info: dict[str, t.Any] | None = None,
response=None,
content: bytes | None = None,
content_json: dict[str, t.Any] | None = None,
content_json: dict[str, t.Any] | bytes | None = None,
extras: dict[str, t.Any] | None = None,
):
# Try to get hold of content, if response is given and content is not provided
@@ -88,15 +88,18 @@ class ACMEProtocolException(ModuleFailException):
content = info.pop("body", None)
# Make sure that content_json is None or a dictionary
content_json_json: dict[str, t.Any] | None = None
if content_json is not None and not isinstance(content_json, dict):
if content is None and isinstance(content_json, bytes):
content = content_json
content_json = None
elif content_json is not None:
content_json_json = content_json.copy()
# Try to get hold of JSON decoded content, when content is given and JSON not provided
if content_json is None and content is not None and module is not None:
if content_json_json is None and content is not None and module is not None:
try:
content_json = module.from_json(to_text(content))
content_json_json = module.from_json(to_text(content))
except Exception:
pass
@@ -117,19 +120,22 @@ class ACMEProtocolException(ModuleFailException):
if (
code is not None
and code >= 400
and content_json is not None
and "type" in content_json
and content_json_json is not None
and "type" in content_json_json
):
error_type = content_json["type"]
if "status" in content_json and content_json["status"] != code:
code_msg = f"status {content_json['status']} (HTTP status: {format_http_status(code)})"
error_type = content_json_json["type"]
if (
"status" in content_json_json
and content_json_json["status"] != code
):
code_msg = f"status {content_json_json['status']} (HTTP status: {format_http_status(code)})"
else:
code_msg = f"status {format_http_status(code)}"
if code == -1 and info.get("msg"):
code_msg = f"error: {info['msg']}"
subproblems = content_json.pop("subproblems", None)
add_msg = f" {format_error_problem(content_json)}."
extras["problem"] = content_json
subproblems = content_json_json.pop("subproblems", None)
add_msg = f" {format_error_problem(content_json_json)}."
extras["problem"] = content_json_json
extras["subproblems"] = subproblems or []
if subproblems is not None:
add_msg = f"{add_msg} Subproblems:"
@@ -142,13 +148,13 @@ class ACMEProtocolException(ModuleFailException):
code_msg = f"HTTP status {format_http_status(code)}"
if code == -1 and info.get("msg"):
code_msg = f"error: {info['msg']}"
if content_json is not None:
add_msg = f" The JSON error result: {content_json}"
if content_json_json is not None:
add_msg = f" The JSON error result: {content_json_json}"
elif content is not None:
add_msg = f" The raw error result: {to_text(content)}"
msg = f"{msg} for {url} with {code_msg}"
elif content_json is not None:
add_msg = f" The JSON result: {content_json}"
elif content_json_json is not None:
add_msg = f" The JSON result: {content_json_json}"
elif content is not None:
add_msg = f" The raw result: {to_text(content)}"

View File

@@ -39,11 +39,11 @@ class Order:
self.data: dict[str, t.Any] | None = None
self.status = None
self.status: str | None = None
self.identifiers: list[tuple[str, str]] = []
self.replaces_cert_id = None
self.finalize_uri = None
self.certificate_uri = None
self.replaces_cert_id: str | None = None
self.finalize_uri: str | None = None
self.certificate_uri: str | None = None
self.authorization_uris: list[str] = []
self.authorizations: dict[str, Authorization] = {}
@@ -106,6 +106,12 @@ class Order:
error_msg="Failed to start new order",
expected_status_codes=[201],
)
if not isinstance(result, dict):
raise ACMEProtocolException(
module=client.module,
msg="Unexpected new order response",
content_json=result,
)
return cls.from_json(client=client, data=result, url=info["location"])
@classmethod

View File

@@ -57,7 +57,8 @@ def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
buf_len = res + 1
buf = openssl_ffi.new("char[]", buf_len)
res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
return openssl_ffi.buffer(buf, res)[:].decode()
bytes_str: bytes = openssl_ffi.buffer(buf, res)[:]
return bytes_str.decode()
__all__ = ("obj2txt",)

View File

@@ -287,8 +287,6 @@ def cryptography_oid_to_name(
def _get_hex(bytesstr: bytes) -> str:
if bytesstr is None:
return bytesstr
data = binascii.hexlify(bytesstr)
return to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
@@ -863,7 +861,7 @@ def parse_pkcs12(
if _load_pkcs12 is not None:
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"): # type: ignore[unreachable]
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=passphrase_bytes)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase=passphrase_bytes)

View File

@@ -281,7 +281,7 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
return 1
if self.cert.version == x509.Version.v3:
return 3
return "unknown"
return "unknown" # type: ignore[unreachable]
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try:

View File

@@ -133,7 +133,7 @@ class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
self.authority_cert_issuer: list[str] | None = module.params[
"authority_cert_issuer"
]
self.authority_cert_serial_number: int = module.params[
self.authority_cert_serial_number: int | None = module.params[
"authority_cert_serial_number"
]
self.crl_distribution_points: (
@@ -361,10 +361,6 @@ def parse_crl_distribution_points(
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, *, module: AnsibleModule) -> None:
super().__init__(module=module)
if self.version != 1:
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
)
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
"crl_distribution_points"

View File

@@ -287,7 +287,9 @@ class _Curve:
def _get_ec_class(
self, *, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
ecclass: (
type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve] | None
) = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype)
if ecclass is None:
module.fail_json(
msg=f"Your cryptography version does not support {self.ectype}"

View File

@@ -285,7 +285,7 @@ class KeypairBackend(OpensshModule, metaclass=abc.ABCMeta):
self._update_comment()
def _create_temp_public_key(self, content: str | bytes) -> str:
temp_public_key = os.path.join(
temp_public_key: str = os.path.join(
self.module.tmpdir, os.path.basename(self.public_key_path)
)

View File

@@ -101,21 +101,21 @@ class OpensshParser:
def boolean(self) -> bool:
next_pos = self._check_position(self.BOOLEAN_OFFSET)
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
value: bool = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint32(self) -> int:
next_pos = self._check_position(self.UINT32_OFFSET)
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
value: int = _UINT32.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint64(self) -> int:
next_pos = self._check_position(self.UINT64_OFFSET)
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
value: int = _UINT64.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value

View File

@@ -634,7 +634,7 @@ class ACMECertificateClient:
self.cert_days = -1
self.order: Order | None = None
self.order_uri = self.data.get("order_uri") if self.data else None
self.all_chains = None
self.all_chains: list[dict[str, t.Any]] | None = None
self.select_chain_matcher = []
self.include_renewal_cert_id = module.params["include_renewal_cert_id"]
self.profile = module.params["profile"]

View File

@@ -217,12 +217,16 @@ def main() -> t.NoReturn:
if info["status"] != 200:
already_revoked = False
# Standardized error from draft 14 on (https://tools.ietf.org/html/rfc8555#section-7.6)
if result.get("type") == "urn:ietf:params:acme:error:alreadyRevoked":
if (
isinstance(result, dict)
and result.get("type") == "urn:ietf:params:acme:error:alreadyRevoked"
):
already_revoked = True
else:
# Hack for Boulder errors
if (
result.get("type") == "urn:ietf:params:acme:error:malformed"
isinstance(result, dict)
and result.get("type") == "urn:ietf:params:acme:error:malformed"
and result.get("detail") == "Certificate already revoked"
):
# Fallback: boulder returns this in case the certificate was already revoked.

View File

@@ -440,7 +440,7 @@ def main() -> t.NoReturn:
module.fail_json(
msg=f"tls_ctx_options must be a string or integer, got {tls_ctx_option!r}"
)
tls_ctx_option_int = (
tls_ctx_option_int = ( # type: ignore[unreachable]
0 # make pylint happy; this code is actually unreachable
)
@@ -558,7 +558,7 @@ def main() -> t.NoReturn:
elif x509.version == cryptography.x509.Version.v3:
result["version"] = 3 - 1
else:
result["version"] = "unknown"
result["version"] = "unknown" # type: ignore[unreachable]
if verified_chain is not None:
result["verified_chain"] = verified_chain

View File

@@ -510,9 +510,6 @@ class Handler:
def get_device_by_label(self, label: str) -> str | None:
"""Returns the device that holds label passed by user"""
blkid_bin = self._module.get_bin_path("blkid", True)
label = self._module.params["label"]
if label is None:
return None
rc, stdout, dummy = self._run_command([blkid_bin, "--label", label])
if rc != 0:
return None

View File

@@ -529,13 +529,8 @@ class Pkcs(OpenSSLObject):
elif bool(pkcs12_certificate) != bool(self.certificate_content):
return False
if (pkcs12_other_certificates is not None) and (
self.other_certificates is not None
):
expected_other_certs = self._dump_other_certificates(self.pkcs12)
if set(pkcs12_other_certificates) != set(expected_other_certs):
return False
elif bool(pkcs12_other_certificates) != bool(self.other_certificates):
expected_other_certs = self._dump_other_certificates(self.pkcs12)
if set(pkcs12_other_certificates) != set(expected_other_certs):
return False
if pkcs12_privatekey:

View File

@@ -293,7 +293,7 @@ class GenericCertificate(OpenSSLObject):
self.module = module
self.return_content = module.params["return_content"]
self.backup = module.params["backup"]
self.backup_file = None
self.backup_file: str | None = None
self.module_backend = module_backend
self.module_backend.set_existing(