From a3a5284f97dc76be142b302c38aad27c64a20a03 Mon Sep 17 00:00:00 2001 From: Felix Fontein Date: Sat, 17 May 2025 17:43:50 +0200 Subject: [PATCH] Add basic typing for Entrust code. (#894) --- .../module_backends/certificate_entrust.py | 10 +- plugins/module_utils/_ecs/api.py | 69 ++++++++---- plugins/modules/ecs_certificate.py | 106 +++++++++--------- plugins/modules/ecs_domain.py | 39 ++++--- tests/sanity/ignore-2.17.txt | 2 + tests/sanity/ignore-2.18.txt | 2 + tests/sanity/ignore-2.19.txt | 1 + 7 files changed, 134 insertions(+), 95 deletions(-) diff --git a/plugins/module_utils/_crypto/module_backends/certificate_entrust.py b/plugins/module_utils/_crypto/module_backends/certificate_entrust.py index d57f5246..11285a85 100644 --- a/plugins/module_utils/_crypto/module_backends/certificate_entrust.py +++ b/plugins/module_utils/_crypto/module_backends/certificate_entrust.py @@ -140,7 +140,7 @@ class EntrustCertificateBackend(CertificateBackend): } try: - result = self.ecs_client.NewCertRequest( # pylint: disable=no-member + result = self.ecs_client.NewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member Body=body ) self.trackingId = result.get("trackingId") @@ -206,10 +206,10 @@ class EntrustCertificateBackend(CertificateBackend): # If a trackingId is not already defined (from the result of a generate) # use the serial number to identify the tracking Id if self.trackingId is None and serial_number is not None: - cert_results = ( - self.ecs_client.GetCertificates( # pylint: disable=no-member - serialNumber=serial_number - ).get("certificates", {}) + cert_results = self.ecs_client.GetCertificates( # type: ignore[attr-defined] # pylint: disable=no-member + serialNumber=serial_number + ).get( + "certificates", {} ) # Finding 0 or more than 1 result is a very unlikely use case, it simply means we cannot perform additional checks diff --git a/plugins/module_utils/_ecs/api.py b/plugins/module_utils/_ecs/api.py index 754c66d9..5b789767 100644 --- a/plugins/module_utils/_ecs/api.py +++ b/plugins/module_utils/_ecs/api.py @@ -26,6 +26,10 @@ from ansible.module_utils.common.text.converters import to_native, to_text from ansible.module_utils.urls import Request +if t.TYPE_CHECKING: + _P = t.ParamSpec("_P") + + YAML_IMP_ERR = None try: import yaml @@ -81,13 +85,21 @@ def generate_docstring(operation_spec: dict[str, t.Any]) -> str: return docs -def bind(instance, method, operation_spec): - def binding_scope_fn(*args, **kwargs): +_T = t.TypeVar("_T") +_R = t.TypeVar("_R") + + +def bind( + instance: _T, + method: t.Callable[t.Concatenate[_T, _P], _R], + operation_spec: dict[str, str], +) -> t.Callable[_P, _R]: + def binding_scope_fn(*args, **kwargs) -> _R: return method(instance, *args, **kwargs) # Make sure we do not confuse users; add the proper name and documentation to the function. # Users can use !help() to get help on the function from interactive python or pdb - operation_name = operation_spec.get("operationId").split("Using")[0] + operation_name = operation_spec["operationId"].split("Using")[0] binding_scope_fn.__name__ = str(operation_name) binding_scope_fn.__doc__ = generate_docstring(operation_spec) @@ -95,7 +107,13 @@ def bind(instance, method, operation_spec): class RestOperation: - def __init__(self, session, uri, method, parameters=None): + def __init__( + self, + session: "ECSSession", + uri: str, + method: str, + parameters: dict | None = None, + ) -> None: self.session = session self.method = method if parameters is None: @@ -106,10 +124,11 @@ class RestOperation: f"https://{session._spec.get('host')}{session._spec.get('basePath')}{uri}" ) - def restmethod(self, *args, **kwargs): + def restmethod(self, *args, **kwargs) -> t.Any: """Do the hard work of making the request here""" # gather named path parameters and do substitution on the URL + body_parameters: dict[str, t.Any] | None if self.parameters: path_parameters = {} body_parameters = {} @@ -175,9 +194,9 @@ class RestOperation: class Resource: """Implement basic CRUD operations against a path.""" - def __init__(self, session): + def __init__(self, session: "ECSSession") -> None: self.session = session - self.parameters = {} + self.parameters: dict[str, t.Any] = {} for url in session._spec.get("paths").keys(): methods = session._spec.get("paths").get(url) @@ -220,18 +239,18 @@ class Resource: # Session to encapsulate the connection parameters of the module_utils Request object, the api spec, etc class ECSSession: - def __init__(self, name, **kwargs): + def __init__(self, name: str, **kwargs) -> None: """ Initialize our session """ self._set_config(name, **kwargs) - def client(self): + def client(self) -> Resource: resource = Resource(self) return resource - def _set_config(self, name, **kwargs): + def _set_config(self, name: str, **kwargs) -> None: headers = { "Content-Type": "application/json", "Connection": "keep-alive", @@ -247,8 +266,8 @@ class ECSSession: raise SessionConfigurationException("No Configuration Found.") # set up auth if passed - entrust_api_user = self.get_config("entrust_api_user") - entrust_api_key = self.get_config("entrust_api_key") + entrust_api_user: str | None = self.get_config("entrust_api_user") + entrust_api_key: str | None = self.get_config("entrust_api_key") if entrust_api_user and entrust_api_key: self.request.url_username = entrust_api_user self.request.url_password = entrust_api_key @@ -256,8 +275,8 @@ class ECSSession: raise SessionConfigurationException("User and key must be provided.") # set up client certificate if passed (support all-in one or cert + key) - entrust_api_cert = self.get_config("entrust_api_cert") - entrust_api_cert_key = self.get_config("entrust_api_cert_key") + entrust_api_cert: str | None = self.get_config("entrust_api_cert") + entrust_api_cert_key: str | None = self.get_config("entrust_api_cert_key") if entrust_api_cert: self.request.client_cert = entrust_api_cert if entrust_api_cert_key: @@ -271,6 +290,10 @@ class ECSSession: entrust_api_specification_path = self.get_config( "entrust_api_specification_path" ) + if not isinstance(entrust_api_specification_path, str): + raise SessionConfigurationException( + "entrust_api_specification_path must be a string." + ) if not entrust_api_specification_path.startswith("http") and not os.path.isfile( entrust_api_specification_path @@ -311,10 +334,10 @@ class ECSSession: ): self._spec = yaml.safe_load(f) - def get_config(self, item): + def get_config(self, item: str) -> t.Any | None: return self._config.get(item, None) - def _read_config_vars(self, name, **kwargs): + def _read_config_vars(self, name: str, **kwargs) -> dict[str, t.Any]: """Read configuration from variables passed to the module.""" config = {} @@ -353,17 +376,17 @@ class ECSSession: def ECSClient( - entrust_api_user=None, - entrust_api_key=None, - entrust_api_cert=None, - entrust_api_cert_key=None, - entrust_api_specification_path=None, -): + entrust_api_user: str | None = None, + entrust_api_key: str | None = None, + entrust_api_cert: str | None = None, + entrust_api_cert_key: str | None = None, + entrust_api_specification_path: str | None = None, +) -> Resource: """Create an ECS client""" if not YAML_FOUND: raise SessionConfigurationException( - missing_required_lib("PyYAML"), exception=YAML_IMP_ERR + missing_required_lib("PyYAML") # TODO: pass `exception=YAML_IMP_ERR` ) if entrust_api_specification_path is None: diff --git a/plugins/modules/ecs_certificate.py b/plugins/modules/ecs_certificate.py index 2028530d..e02bced7 100644 --- a/plugins/modules/ecs_certificate.py +++ b/plugins/modules/ecs_certificate.py @@ -590,7 +590,7 @@ def validate_cert_expiry(cert_expiry: str) -> bool: def calculate_cert_days(expires_after: str | None) -> int: cert_days = 0 - if expires_after: + if expires_after is not None: expires_after_datetime = datetime.datetime.strptime( expires_after, "%Y-%m-%dT%H:%M:%SZ" ) @@ -618,32 +618,33 @@ class EcsCertificate: Entrust Certificate Services certificate class. """ - def __init__(self, module): - self.path = module.params["path"] - self.full_chain_path = module.params["full_chain_path"] - self.force = module.params["force"] - self.backup = module.params["backup"] - self.request_type = module.params["request_type"] - self.csr = module.params["csr"] + def __init__(self, module: AnsibleModule) -> None: + self.path: str = module.params["path"] + self.full_chain_path: str | None = module.params["full_chain_path"] + self.force: bool = module.params["force"] + self.backup: bool = module.params["backup"] + self.request_type: t.Literal["new", "renew", "reissue", "validate_only"] = ( + module.params["request_type"] + ) + self.csr: str | None = module.params["csr"] # All return values self.changed = False - self.filename = None - self.tracking_id = None - self.cert_status = None - self.serial_number = None - self.cert_days = None - self.cert_details = None - self.backup_file = None - self.backup_full_chain_file = None + self.filename: str | None = None + self.tracking_id: int | None = None + self.cert_status: str | None = None + self.serial_number: int | None = None + self.cert_days: int | None = None + self.cert_details: dict[str, t.Any] | None = None + self.backup_file: str | None = None + self.backup_full_chain_file: str | None = None self.cert = None - self.ecs_client = None if self.path and os.path.exists(self.path): try: - self.cert = load_certificate(self.path) + self.cert = load_certificate(path=self.path) except Exception: - self.cert = None + pass # Instantiate the ECS client and then try a no-op connection to verify credentials are valid try: self.ecs_client = ECSClient( @@ -658,14 +659,14 @@ class EcsCertificate: except SessionConfigurationException as e: module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}") try: - self.ecs_client.GetAppVersion() # pylint: disable=no-member + self.ecs_client.GetAppVersion() # type: ignore[attr-defined] # pylint: disable=no-member except RestOperationException as e: module.fail_json( msg=f"Please verify credential information. Received exception when testing ECS connection: {e.message}" ) # Conversion of the fields that go into the 'tracking' parameter of the request object - def convert_tracking_params(self, module): + def convert_tracking_params(self, module: AnsibleModule) -> dict[str, t.Any]: body = {} tracking = {} if module.params["requester_name"]: @@ -689,7 +690,7 @@ class EcsCertificate: body["tracking"] = tracking return body - def convert_cert_subject_params(self, module): + def convert_cert_subject_params(self, module: AnsibleModule) -> dict[str, t.Any]: body = {} if module.params["subject_alt_name"]: body["subjectAltName"] = module.params["subject_alt_name"] @@ -699,7 +700,7 @@ class EcsCertificate: body["ou"] = module.params["ou"] return body - def convert_general_params(self, module): + def convert_general_params(self, module: AnsibleModule) -> dict[str, t.Any]: body = {} if module.params["eku"]: body["eku"] = module.params["eku"] @@ -714,7 +715,7 @@ class EcsCertificate: ) return body - def convert_expiry_params(self, module): + def convert_expiry_params(self, module: AnsibleModule) -> dict[str, t.Any]: body = {} if module.params["cert_lifetime"]: body["certLifetime"] = module.params["cert_lifetime"] @@ -727,27 +728,29 @@ class EcsCertificate: body["certExpiryDate"] = expiry.strftime("%Y-%m-%dT%H:%M:%S.00Z") return body - def set_tracking_id_by_serial_number(self, module): + def set_tracking_id_by_serial_number(self, module: AnsibleModule) -> None: + assert self.cert is not None try: # Use serial_number to identify if certificate is an Entrust Certificate # with an associated tracking ID serial_number = f"{self.cert.serial_number:X}" - cert_results = self.ecs_client.GetCertificates( # pylint: disable=no-member + cert_results = self.ecs_client.GetCertificates( # type: ignore[attr-defined] # pylint: disable=no-member serialNumber=serial_number - ).get("certificates", {}) + ).get( + "certificates", {} + ) if len(cert_results) == 1: self.tracking_id = cert_results[0].get("trackingId") except RestOperationException: # If we fail to find a cert by serial number, that's fine, we just do not set self.tracking_id pass - def set_cert_details(self, module): + def set_cert_details(self, module: AnsibleModule) -> None: try: - self.cert_details = ( - self.ecs_client.GetCertificate( # pylint: disable=no-member - trackingId=self.tracking_id - ) + self.cert_details = self.ecs_client.GetCertificate( # type: ignore[attr-defined] # pylint: disable=no-member + trackingId=self.tracking_id ) + assert self.cert_details is not None self.cert_status = self.cert_details.get("status") self.serial_number = self.cert_details.get("serialNumber") self.cert_days = calculate_cert_days(self.cert_details.get("expiresAfter")) @@ -756,7 +759,7 @@ class EcsCertificate: msg=f'Failed to get details of certificate with tracking_id="{self.tracking_id}", Error: {e.message}' ) - def check(self, module): + def check(self, module: AnsibleModule) -> bool: if self.cert: # We will only treat a certificate as valid if it is found as a managed entrust cert. # We will only set updated tracking ID based on certificate in "path" if it is managed by entrust. @@ -781,6 +784,7 @@ class EcsCertificate: return False self.set_cert_details(module) + assert self.cert_details is not None if ( self.cert_status == "EXPIRED" @@ -793,7 +797,7 @@ class EcsCertificate: return True - def request_cert(self, module): + def request_cert(self, module: AnsibleModule) -> None: if not self.check(module) or self.force: body = {} @@ -830,27 +834,24 @@ class EcsCertificate: try: if self.request_type == "validate_only": body["validateOnly"] = "true" - result = ( - self.ecs_client.NewCertRequest( # pylint: disable=no-member - Body=body - ) + result = self.ecs_client.NewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member + Body=body ) if self.request_type == "new": - result = ( - self.ecs_client.NewCertRequest( # pylint: disable=no-member - Body=body - ) + result = self.ecs_client.NewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member + Body=body ) elif self.request_type == "renew": - result = self.ecs_client.RenewCertRequest( # pylint: disable=no-member + result = self.ecs_client.RenewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member trackingId=self.tracking_id, Body=body ) elif self.request_type == "reissue": - result = self.ecs_client.ReissueCertRequest( # pylint: disable=no-member + result = self.ecs_client.ReissueCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member trackingId=self.tracking_id, Body=body ) self.tracking_id = result.get("trackingId") self.set_cert_details(module) + assert self.cert_details is not None except RestOperationException as e: module.fail_json( msg=f"Failed to request new certificate from Entrust (ECS) {e.message}" @@ -863,14 +864,14 @@ class EcsCertificate: module=module, content=to_bytes(self.cert_details.get("endEntityCert")), ) - if self.full_chain_path and self.cert_details.get("chainCerts"): + chain_certs = self.cert_details.get("chainCerts") + if self.full_chain_path and chain_certs: + assert isinstance(chain_certs, list) if self.backup: self.backup_full_chain_file = module.backup_local( self.full_chain_path ) - chain_string = ( - "\n".join(self.cert_details.get("chainCerts")) + "\n" - ) + chain_string = "\n".join(chain_certs) + "\n" write_file( module=module, content=to_bytes(chain_string), @@ -880,12 +881,15 @@ class EcsCertificate: # If there is no certificate present in path but a tracking ID was specified, save it to disk elif not os.path.exists(self.path) and self.tracking_id: if not module.check_mode: + assert self.cert_details is not None write_file( module=module, content=to_bytes(self.cert_details.get("endEntityCert")), ) - if self.full_chain_path and self.cert_details.get("chainCerts"): - chain_string = "\n".join(self.cert_details.get("chainCerts")) + "\n" + chain_certs = self.cert_details.get("chainCerts") + if self.full_chain_path and chain_certs: + assert isinstance(chain_certs, list) + chain_string = "\n".join(chain_certs) + "\n" write_file( module=module, content=to_bytes(chain_string), @@ -893,7 +897,7 @@ class EcsCertificate: ) self.changed = True - def dump(self): + def dump(self) -> dict[str, t.Any]: result = { "changed": self.changed, "filename": self.path, diff --git a/plugins/modules/ecs_domain.py b/plugins/modules/ecs_domain.py index 7f8fdeac..0a8aa175 100644 --- a/plugins/modules/ecs_domain.py +++ b/plugins/modules/ecs_domain.py @@ -229,9 +229,17 @@ from ansible_collections.community.crypto.plugins.module_utils._ecs.api import ( ) +@t.overload +def calculate_days_remaining(expiry_date: str) -> int: ... + + +@t.overload +def calculate_days_remaining(expiry_date: str | None) -> int | None: ... + + def calculate_days_remaining(expiry_date: str | None) -> int | None: days_remaining = None - if expiry_date: + if expiry_date is not None: expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ") days_remaining = (expiry_datetime - datetime.datetime.now()).days return days_remaining @@ -242,7 +250,7 @@ class EcsDomain: Entrust Certificate Services domain class. """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.changed = False self.domain_status = None self.verification_method = None @@ -252,16 +260,15 @@ class EcsDomain: self.dns_contents = None self.dns_resource_type = None self.emails = None - self.ov_eligible = None - self.ov_days_remaining = None - self.ev_eligble = None - self.ev_days_remaining = None + self.ov_eligible: bool | None = None + self.ov_days_remaining: int | None = None + self.ev_eligible: bool | None = None + self.ev_days_remaining: int | None = None # Note that verification_method is the 'current' verification # method of the domain, we'll use module.params when requesting a new # one, in case the verification method has changed. self.verification_method = None - self.ecs_client = None # Instantiate the ECS client and then try a no-op connection to verify credentials are valid try: self.ecs_client = ECSClient( @@ -276,13 +283,13 @@ class EcsDomain: except SessionConfigurationException as e: module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}") try: - self.ecs_client.GetAppVersion() # pylint: disable=no-member + self.ecs_client.GetAppVersion() # type: ignore[attr-defined] # pylint: disable=no-member except RestOperationException as e: module.fail_json( msg=f"Please verify credential information. Received exception when testing ECS connection: {e.message}" ) - def set_domain_details(self, domain_details): + def set_domain_details(self, domain_details: dict[str, t.Any]) -> None: if domain_details.get("verificationMethod"): self.verification_method = domain_details["verificationMethod"].lower() self.domain_status = domain_details["verificationStatus"] @@ -308,9 +315,9 @@ class EcsDomain: elif self.verification_method == "email" and domain_details.get("emailMethod"): self.emails = domain_details["emailMethod"] - def check(self, module): + def check(self, module: AnsibleModule) -> bool: try: - domain_details = self.ecs_client.GetDomain( # pylint: disable=no-member + domain_details = self.ecs_client.GetDomain( # type: ignore[attr-defined] # pylint: disable=no-member clientId=module.params["client_id"], domain=module.params["domain_name"] ) self.set_domain_details(domain_details) @@ -337,7 +344,7 @@ class EcsDomain: except RestOperationException: return False - def request_domain(self, module): + def request_domain(self, module: AnsibleModule) -> None: if not self.check(module): body = {} @@ -355,18 +362,18 @@ class EcsDomain: body["domainName"] = module.params["domain_name"] try: if not self.domain_status: - self.ecs_client.AddDomain( # pylint: disable=no-member + self.ecs_client.AddDomain( # type: ignore[attr-defined] # pylint: disable=no-member clientId=module.params["client_id"], Body=body ) else: - self.ecs_client.ReverifyDomain( # pylint: disable=no-member + self.ecs_client.ReverifyDomain( # type: ignore[attr-defined] # pylint: disable=no-member clientId=module.params["client_id"], domain=module.params["domain_name"], Body=body, ) time.sleep(5) - result = self.ecs_client.GetDomain( # pylint: disable=no-member + result = self.ecs_client.GetDomain( # type: ignore[attr-defined] # pylint: disable=no-member clientId=module.params["client_id"], domain=module.params["domain_name"], ) @@ -393,7 +400,7 @@ class EcsDomain: ): break time.sleep(10) - result = self.ecs_client.GetDomain( # pylint: disable=no-member + result = self.ecs_client.GetDomain( # type: ignore[attr-defined] # pylint: disable=no-member clientId=module.params["client_id"], domain=module.params["domain_name"], ) diff --git a/tests/sanity/ignore-2.17.txt b/tests/sanity/ignore-2.17.txt index 92f3a165..7027e87d 100644 --- a/tests/sanity/ignore-2.17.txt +++ b/tests/sanity/ignore-2.17.txt @@ -21,6 +21,8 @@ plugins/modules/acme_certificate_deactivate_authz.py pylint:unpacking-non-sequen plugins/modules/acme_certificate_order_finalize.py pylint:unpacking-non-sequence plugins/modules/acme_certificate_revoke.py pylint:unpacking-non-sequence plugins/modules/acme_inspect.py pylint:unpacking-non-sequence +plugins/modules/ecs_certificate.py no-assert +plugins/modules/ecs_domain.py pep8:E704 plugins/modules/luks_device.py no-assert plugins/modules/openssl_pkcs12.py no-assert tests/ee/roles/smoke/library/smoke_ipaddress.py shebang diff --git a/tests/sanity/ignore-2.18.txt b/tests/sanity/ignore-2.18.txt index 45ae0ec8..47306ca1 100644 --- a/tests/sanity/ignore-2.18.txt +++ b/tests/sanity/ignore-2.18.txt @@ -13,6 +13,8 @@ plugins/module_utils/_crypto/support.py pep8:E704 plugins/module_utils/_openssh/backends/keypair_backend.py no-assert plugins/module_utils/_openssh/certificate.py pep8:E704 plugins/modules/acme_certificate.py no-assert +plugins/modules/ecs_certificate.py no-assert +plugins/modules/ecs_domain.py pep8:E704 plugins/modules/luks_device.py no-assert plugins/modules/openssl_pkcs12.py no-assert tests/ee/roles/smoke/library/smoke_ipaddress.py shebang diff --git a/tests/sanity/ignore-2.19.txt b/tests/sanity/ignore-2.19.txt index 3f35bdce..28295a47 100644 --- a/tests/sanity/ignore-2.19.txt +++ b/tests/sanity/ignore-2.19.txt @@ -6,6 +6,7 @@ plugins/module_utils/_crypto/module_backends/csr.py no-assert plugins/module_utils/_crypto/module_backends/privatekey_convert.py no-assert plugins/module_utils/_openssh/backends/keypair_backend.py no-assert plugins/modules/acme_certificate.py no-assert +plugins/modules/ecs_certificate.py no-assert plugins/modules/luks_device.py no-assert plugins/modules/openssl_pkcs12.py no-assert tests/ee/roles/smoke/library/smoke_ipaddress.py shebang