Add basic typing for Entrust code. (#894)

This commit is contained in:
Felix Fontein
2025-05-17 17:43:50 +02:00
committed by GitHub
parent 990b40df3e
commit a3a5284f97
7 changed files with 134 additions and 95 deletions

View File

@@ -590,7 +590,7 @@ def validate_cert_expiry(cert_expiry: str) -> bool:
def calculate_cert_days(expires_after: str | None) -> int:
cert_days = 0
if expires_after:
if expires_after is not None:
expires_after_datetime = datetime.datetime.strptime(
expires_after, "%Y-%m-%dT%H:%M:%SZ"
)
@@ -618,32 +618,33 @@ class EcsCertificate:
Entrust Certificate Services certificate class.
"""
def __init__(self, module):
self.path = module.params["path"]
self.full_chain_path = module.params["full_chain_path"]
self.force = module.params["force"]
self.backup = module.params["backup"]
self.request_type = module.params["request_type"]
self.csr = module.params["csr"]
def __init__(self, module: AnsibleModule) -> None:
self.path: str = module.params["path"]
self.full_chain_path: str | None = module.params["full_chain_path"]
self.force: bool = module.params["force"]
self.backup: bool = module.params["backup"]
self.request_type: t.Literal["new", "renew", "reissue", "validate_only"] = (
module.params["request_type"]
)
self.csr: str | None = module.params["csr"]
# All return values
self.changed = False
self.filename = None
self.tracking_id = None
self.cert_status = None
self.serial_number = None
self.cert_days = None
self.cert_details = None
self.backup_file = None
self.backup_full_chain_file = None
self.filename: str | None = None
self.tracking_id: int | None = None
self.cert_status: str | None = None
self.serial_number: int | None = None
self.cert_days: int | None = None
self.cert_details: dict[str, t.Any] | None = None
self.backup_file: str | None = None
self.backup_full_chain_file: str | None = None
self.cert = None
self.ecs_client = None
if self.path and os.path.exists(self.path):
try:
self.cert = load_certificate(self.path)
self.cert = load_certificate(path=self.path)
except Exception:
self.cert = None
pass
# Instantiate the ECS client and then try a no-op connection to verify credentials are valid
try:
self.ecs_client = ECSClient(
@@ -658,14 +659,14 @@ class EcsCertificate:
except SessionConfigurationException as e:
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
try:
self.ecs_client.GetAppVersion() # pylint: disable=no-member
self.ecs_client.GetAppVersion() # type: ignore[attr-defined] # pylint: disable=no-member
except RestOperationException as e:
module.fail_json(
msg=f"Please verify credential information. Received exception when testing ECS connection: {e.message}"
)
# Conversion of the fields that go into the 'tracking' parameter of the request object
def convert_tracking_params(self, module):
def convert_tracking_params(self, module: AnsibleModule) -> dict[str, t.Any]:
body = {}
tracking = {}
if module.params["requester_name"]:
@@ -689,7 +690,7 @@ class EcsCertificate:
body["tracking"] = tracking
return body
def convert_cert_subject_params(self, module):
def convert_cert_subject_params(self, module: AnsibleModule) -> dict[str, t.Any]:
body = {}
if module.params["subject_alt_name"]:
body["subjectAltName"] = module.params["subject_alt_name"]
@@ -699,7 +700,7 @@ class EcsCertificate:
body["ou"] = module.params["ou"]
return body
def convert_general_params(self, module):
def convert_general_params(self, module: AnsibleModule) -> dict[str, t.Any]:
body = {}
if module.params["eku"]:
body["eku"] = module.params["eku"]
@@ -714,7 +715,7 @@ class EcsCertificate:
)
return body
def convert_expiry_params(self, module):
def convert_expiry_params(self, module: AnsibleModule) -> dict[str, t.Any]:
body = {}
if module.params["cert_lifetime"]:
body["certLifetime"] = module.params["cert_lifetime"]
@@ -727,27 +728,29 @@ class EcsCertificate:
body["certExpiryDate"] = expiry.strftime("%Y-%m-%dT%H:%M:%S.00Z")
return body
def set_tracking_id_by_serial_number(self, module):
def set_tracking_id_by_serial_number(self, module: AnsibleModule) -> None:
assert self.cert is not None
try:
# Use serial_number to identify if certificate is an Entrust Certificate
# with an associated tracking ID
serial_number = f"{self.cert.serial_number:X}"
cert_results = self.ecs_client.GetCertificates( # pylint: disable=no-member
cert_results = self.ecs_client.GetCertificates( # type: ignore[attr-defined] # pylint: disable=no-member
serialNumber=serial_number
).get("certificates", {})
).get(
"certificates", {}
)
if len(cert_results) == 1:
self.tracking_id = cert_results[0].get("trackingId")
except RestOperationException:
# If we fail to find a cert by serial number, that's fine, we just do not set self.tracking_id
pass
def set_cert_details(self, module):
def set_cert_details(self, module: AnsibleModule) -> None:
try:
self.cert_details = (
self.ecs_client.GetCertificate( # pylint: disable=no-member
trackingId=self.tracking_id
)
self.cert_details = self.ecs_client.GetCertificate( # type: ignore[attr-defined] # pylint: disable=no-member
trackingId=self.tracking_id
)
assert self.cert_details is not None
self.cert_status = self.cert_details.get("status")
self.serial_number = self.cert_details.get("serialNumber")
self.cert_days = calculate_cert_days(self.cert_details.get("expiresAfter"))
@@ -756,7 +759,7 @@ class EcsCertificate:
msg=f'Failed to get details of certificate with tracking_id="{self.tracking_id}", Error: {e.message}'
)
def check(self, module):
def check(self, module: AnsibleModule) -> bool:
if self.cert:
# We will only treat a certificate as valid if it is found as a managed entrust cert.
# We will only set updated tracking ID based on certificate in "path" if it is managed by entrust.
@@ -781,6 +784,7 @@ class EcsCertificate:
return False
self.set_cert_details(module)
assert self.cert_details is not None
if (
self.cert_status == "EXPIRED"
@@ -793,7 +797,7 @@ class EcsCertificate:
return True
def request_cert(self, module):
def request_cert(self, module: AnsibleModule) -> None:
if not self.check(module) or self.force:
body = {}
@@ -830,27 +834,24 @@ class EcsCertificate:
try:
if self.request_type == "validate_only":
body["validateOnly"] = "true"
result = (
self.ecs_client.NewCertRequest( # pylint: disable=no-member
Body=body
)
result = self.ecs_client.NewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member
Body=body
)
if self.request_type == "new":
result = (
self.ecs_client.NewCertRequest( # pylint: disable=no-member
Body=body
)
result = self.ecs_client.NewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member
Body=body
)
elif self.request_type == "renew":
result = self.ecs_client.RenewCertRequest( # pylint: disable=no-member
result = self.ecs_client.RenewCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member
trackingId=self.tracking_id, Body=body
)
elif self.request_type == "reissue":
result = self.ecs_client.ReissueCertRequest( # pylint: disable=no-member
result = self.ecs_client.ReissueCertRequest( # type: ignore[attr-defined] # pylint: disable=no-member
trackingId=self.tracking_id, Body=body
)
self.tracking_id = result.get("trackingId")
self.set_cert_details(module)
assert self.cert_details is not None
except RestOperationException as e:
module.fail_json(
msg=f"Failed to request new certificate from Entrust (ECS) {e.message}"
@@ -863,14 +864,14 @@ class EcsCertificate:
module=module,
content=to_bytes(self.cert_details.get("endEntityCert")),
)
if self.full_chain_path and self.cert_details.get("chainCerts"):
chain_certs = self.cert_details.get("chainCerts")
if self.full_chain_path and chain_certs:
assert isinstance(chain_certs, list)
if self.backup:
self.backup_full_chain_file = module.backup_local(
self.full_chain_path
)
chain_string = (
"\n".join(self.cert_details.get("chainCerts")) + "\n"
)
chain_string = "\n".join(chain_certs) + "\n"
write_file(
module=module,
content=to_bytes(chain_string),
@@ -880,12 +881,15 @@ class EcsCertificate:
# If there is no certificate present in path but a tracking ID was specified, save it to disk
elif not os.path.exists(self.path) and self.tracking_id:
if not module.check_mode:
assert self.cert_details is not None
write_file(
module=module,
content=to_bytes(self.cert_details.get("endEntityCert")),
)
if self.full_chain_path and self.cert_details.get("chainCerts"):
chain_string = "\n".join(self.cert_details.get("chainCerts")) + "\n"
chain_certs = self.cert_details.get("chainCerts")
if self.full_chain_path and chain_certs:
assert isinstance(chain_certs, list)
chain_string = "\n".join(chain_certs) + "\n"
write_file(
module=module,
content=to_bytes(chain_string),
@@ -893,7 +897,7 @@ class EcsCertificate:
)
self.changed = True
def dump(self):
def dump(self) -> dict[str, t.Any]:
result = {
"changed": self.changed,
"filename": self.path,

View File

@@ -229,9 +229,17 @@ from ansible_collections.community.crypto.plugins.module_utils._ecs.api import (
)
@t.overload
def calculate_days_remaining(expiry_date: str) -> int: ...
@t.overload
def calculate_days_remaining(expiry_date: str | None) -> int | None: ...
def calculate_days_remaining(expiry_date: str | None) -> int | None:
days_remaining = None
if expiry_date:
if expiry_date is not None:
expiry_datetime = datetime.datetime.strptime(expiry_date, "%Y-%m-%dT%H:%M:%SZ")
days_remaining = (expiry_datetime - datetime.datetime.now()).days
return days_remaining
@@ -242,7 +250,7 @@ class EcsDomain:
Entrust Certificate Services domain class.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.changed = False
self.domain_status = None
self.verification_method = None
@@ -252,16 +260,15 @@ class EcsDomain:
self.dns_contents = None
self.dns_resource_type = None
self.emails = None
self.ov_eligible = None
self.ov_days_remaining = None
self.ev_eligble = None
self.ev_days_remaining = None
self.ov_eligible: bool | None = None
self.ov_days_remaining: int | None = None
self.ev_eligible: bool | None = None
self.ev_days_remaining: int | None = None
# Note that verification_method is the 'current' verification
# method of the domain, we'll use module.params when requesting a new
# one, in case the verification method has changed.
self.verification_method = None
self.ecs_client = None
# Instantiate the ECS client and then try a no-op connection to verify credentials are valid
try:
self.ecs_client = ECSClient(
@@ -276,13 +283,13 @@ class EcsDomain:
except SessionConfigurationException as e:
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
try:
self.ecs_client.GetAppVersion() # pylint: disable=no-member
self.ecs_client.GetAppVersion() # type: ignore[attr-defined] # pylint: disable=no-member
except RestOperationException as e:
module.fail_json(
msg=f"Please verify credential information. Received exception when testing ECS connection: {e.message}"
)
def set_domain_details(self, domain_details):
def set_domain_details(self, domain_details: dict[str, t.Any]) -> None:
if domain_details.get("verificationMethod"):
self.verification_method = domain_details["verificationMethod"].lower()
self.domain_status = domain_details["verificationStatus"]
@@ -308,9 +315,9 @@ class EcsDomain:
elif self.verification_method == "email" and domain_details.get("emailMethod"):
self.emails = domain_details["emailMethod"]
def check(self, module):
def check(self, module: AnsibleModule) -> bool:
try:
domain_details = self.ecs_client.GetDomain( # pylint: disable=no-member
domain_details = self.ecs_client.GetDomain( # type: ignore[attr-defined] # pylint: disable=no-member
clientId=module.params["client_id"], domain=module.params["domain_name"]
)
self.set_domain_details(domain_details)
@@ -337,7 +344,7 @@ class EcsDomain:
except RestOperationException:
return False
def request_domain(self, module):
def request_domain(self, module: AnsibleModule) -> None:
if not self.check(module):
body = {}
@@ -355,18 +362,18 @@ class EcsDomain:
body["domainName"] = module.params["domain_name"]
try:
if not self.domain_status:
self.ecs_client.AddDomain( # pylint: disable=no-member
self.ecs_client.AddDomain( # type: ignore[attr-defined] # pylint: disable=no-member
clientId=module.params["client_id"], Body=body
)
else:
self.ecs_client.ReverifyDomain( # pylint: disable=no-member
self.ecs_client.ReverifyDomain( # type: ignore[attr-defined] # pylint: disable=no-member
clientId=module.params["client_id"],
domain=module.params["domain_name"],
Body=body,
)
time.sleep(5)
result = self.ecs_client.GetDomain( # pylint: disable=no-member
result = self.ecs_client.GetDomain( # type: ignore[attr-defined] # pylint: disable=no-member
clientId=module.params["client_id"],
domain=module.params["domain_name"],
)
@@ -393,7 +400,7 @@ class EcsDomain:
):
break
time.sleep(10)
result = self.ecs_client.GetDomain( # pylint: disable=no-member
result = self.ecs_client.GetDomain( # type: ignore[attr-defined] # pylint: disable=no-member
clientId=module.params["client_id"],
domain=module.params["domain_name"],
)