Reformat everything with black.

I had to undo the u string prefix removals to not drop Python 2 compatibility.
That's why black isn't enabled in antsibull-nox.toml yet.
This commit is contained in:
Felix Fontein
2025-04-28 09:51:33 +02:00
parent 04a0d38e3b
commit aec1826c34
118 changed files with 11780 additions and 7565 deletions

View File

@@ -19,10 +19,10 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
class ACMEAccount(object):
'''
"""
ACME account object. Allows to create new accounts, check for existence of accounts,
retrieve account data.
'''
"""
def __init__(self, client):
# Set to true to enable logging of all signed requests
@@ -30,9 +30,15 @@ class ACMEAccount(object):
self.client = client
def _new_reg(self, contact=None, agreement=None, terms_agreed=False, allow_creation=True,
external_account_binding=None):
'''
def _new_reg(
self,
contact=None,
agreement=None,
terms_agreed=False,
allow_creation=True,
external_account_binding=None,
):
"""
Registers a new ACME account. Returns a pair ``(created, data)``.
Here, ``created`` is ``True`` if the account was created and
``False`` if it already existed (e.g. it was not newly created),
@@ -44,23 +50,25 @@ class ACMEAccount(object):
(https://tools.ietf.org/html/rfc8555#section-7.3.4).
https://tools.ietf.org/html/rfc8555#section-7.3
'''
"""
contact = contact or []
if self.client.version == 1:
new_reg = {
'resource': 'new-reg',
'contact': contact
}
new_reg = {"resource": "new-reg", "contact": contact}
if agreement:
new_reg['agreement'] = agreement
new_reg["agreement"] = agreement
else:
new_reg['agreement'] = self.client.directory['meta']['terms-of-service']
new_reg["agreement"] = self.client.directory["meta"]["terms-of-service"]
if external_account_binding is not None:
raise ModuleFailException('External account binding is not supported for ACME v1')
url = self.client.directory['new-reg']
raise ModuleFailException(
"External account binding is not supported for ACME v1"
)
url = self.client.directory["new-reg"]
else:
if (external_account_binding is not None or self.client.directory['meta'].get('externalAccountRequired')) and allow_creation:
if (
external_account_binding is not None
or self.client.directory["meta"].get("externalAccountRequired")
) and allow_creation:
# Some ACME servers such as ZeroSSL do not like it when you try to register an existing account
# and provide external_account_binding credentials. Thus we first send a request with allow_creation=False
# to see whether the account already exists.
@@ -73,44 +81,53 @@ class ACMEAccount(object):
return created, data
# An account does not yet exist. Try to create one next.
new_reg = {
'contact': contact
}
new_reg = {"contact": contact}
if not allow_creation:
# https://tools.ietf.org/html/rfc8555#section-7.3.1
new_reg['onlyReturnExisting'] = True
new_reg["onlyReturnExisting"] = True
if terms_agreed:
new_reg['termsOfServiceAgreed'] = True
url = self.client.directory['newAccount']
new_reg["termsOfServiceAgreed"] = True
url = self.client.directory["newAccount"]
if external_account_binding is not None:
new_reg['externalAccountBinding'] = self.client.sign_request(
new_reg["externalAccountBinding"] = self.client.sign_request(
{
'alg': external_account_binding['alg'],
'kid': external_account_binding['kid'],
'url': url,
"alg": external_account_binding["alg"],
"kid": external_account_binding["kid"],
"url": url,
},
self.client.account_jwk,
self.client.backend.create_mac_key(external_account_binding['alg'], external_account_binding['key'])
self.client.backend.create_mac_key(
external_account_binding["alg"], external_account_binding["key"]
),
)
elif self.client.directory['meta'].get('externalAccountRequired') and allow_creation:
elif (
self.client.directory["meta"].get("externalAccountRequired")
and allow_creation
):
raise ModuleFailException(
'To create an account, an external account binding must be specified. '
'Use the acme_account module with the external_account_binding option.'
"To create an account, an external account binding must be specified. "
"Use the acme_account module with the external_account_binding option."
)
result, info = self.client.send_signed_request(url, new_reg, fail_on_error=False)
result, info = self.client.send_signed_request(
url, new_reg, fail_on_error=False
)
if not isinstance(result, Mapping):
raise ACMEProtocolException(
self.client.module, msg='Invalid account creation reply from ACME server', info=info, content=result)
self.client.module,
msg="Invalid account creation reply from ACME server",
info=info,
content=result,
)
if info['status'] in ([200, 201] if self.client.version == 1 else [201]):
if info["status"] in ([200, 201] if self.client.version == 1 else [201]):
# Account did not exist
if 'location' in info:
self.client.set_account_uri(info['location'])
if "location" in info:
self.client.set_account_uri(info["location"])
return True, result
elif info['status'] == (409 if self.client.version == 1 else 200):
elif info["status"] == (409 if self.client.version == 1 else 200):
# Account did exist
if result.get('status') == 'deactivated':
if result.get("status") == "deactivated":
# A bug in Pebble (https://github.com/letsencrypt/pebble/issues/179) and
# Boulder (https://github.com/letsencrypt/boulder/issues/3971): this should
# not return a valid account object according to
@@ -121,15 +138,23 @@ class ACMEAccount(object):
return False, None
else:
raise ModuleFailException("Account is deactivated")
if 'location' in info:
self.client.set_account_uri(info['location'])
if "location" in info:
self.client.set_account_uri(info["location"])
return False, result
elif info['status'] in (400, 404) and result['type'] == 'urn:ietf:params:acme:error:accountDoesNotExist' and not allow_creation:
elif (
info["status"] in (400, 404)
and result["type"] == "urn:ietf:params:acme:error:accountDoesNotExist"
and not allow_creation
):
# Account does not exist (and we did not try to create it)
# (According to RFC 8555, Section 7.3.1, the HTTP status code MUST be 400.
# Unfortunately Digicert does not care and sends 404 instead.)
return False, None
elif info['status'] == 403 and result['type'] == 'urn:ietf:params:acme:error:unauthorized' and 'deactivated' in (result.get('detail') or ''):
elif (
info["status"] == 403
and result["type"] == "urn:ietf:params:acme:error:unauthorized"
and "deactivated" in (result.get("detail") or "")
):
# Account has been deactivated; currently works for Pebble; has not been
# implemented for Boulder (https://github.com/letsencrypt/boulder/issues/3971),
# might need adjustment in error detection.
@@ -139,47 +164,80 @@ class ACMEAccount(object):
raise ModuleFailException("Account is deactivated")
else:
raise ACMEProtocolException(
self.client.module, msg='Registering ACME account failed', info=info, content_json=result)
self.client.module,
msg="Registering ACME account failed",
info=info,
content_json=result,
)
def get_account_data(self):
'''
"""
Retrieve account information. Can only be called when the account
URI is already known (such as after calling setup_account).
Return None if the account was deactivated, or a dict otherwise.
'''
"""
if self.client.account_uri is None:
raise ModuleFailException("Account URI unknown")
if self.client.version == 1:
data = {}
data['resource'] = 'reg'
result, info = self.client.send_signed_request(self.client.account_uri, data, fail_on_error=False)
data["resource"] = "reg"
result, info = self.client.send_signed_request(
self.client.account_uri, data, fail_on_error=False
)
else:
# try POST-as-GET first (draft-15 or newer)
data = None
result, info = self.client.send_signed_request(self.client.account_uri, data, fail_on_error=False)
result, info = self.client.send_signed_request(
self.client.account_uri, data, fail_on_error=False
)
# check whether that failed with a malformed request error
if info['status'] >= 400 and result.get('type') == 'urn:ietf:params:acme:error:malformed':
if (
info["status"] >= 400
and result.get("type") == "urn:ietf:params:acme:error:malformed"
):
# retry as a regular POST (with no changed data) for pre-draft-15 ACME servers
data = {}
result, info = self.client.send_signed_request(self.client.account_uri, data, fail_on_error=False)
result, info = self.client.send_signed_request(
self.client.account_uri, data, fail_on_error=False
)
if not isinstance(result, Mapping):
raise ACMEProtocolException(
self.client.module, msg='Invalid account data retrieved from ACME server', info=info, content=result)
if info['status'] in (400, 403) and result.get('type') == 'urn:ietf:params:acme:error:unauthorized':
self.client.module,
msg="Invalid account data retrieved from ACME server",
info=info,
content=result,
)
if (
info["status"] in (400, 403)
and result.get("type") == "urn:ietf:params:acme:error:unauthorized"
):
# Returned when account is deactivated
return None
if info['status'] in (400, 404) and result.get('type') == 'urn:ietf:params:acme:error:accountDoesNotExist':
if (
info["status"] in (400, 404)
and result.get("type") == "urn:ietf:params:acme:error:accountDoesNotExist"
):
# Returned when account does not exist
return None
if info['status'] < 200 or info['status'] >= 300:
if info["status"] < 200 or info["status"] >= 300:
raise ACMEProtocolException(
self.client.module, msg='Error retrieving account data', info=info, content_json=result)
self.client.module,
msg="Error retrieving account data",
info=info,
content_json=result,
)
return result
def setup_account(self, contact=None, agreement=None, terms_agreed=False,
allow_creation=True, remove_account_uri_if_not_exists=False,
external_account_binding=None):
'''
def setup_account(
self,
contact=None,
agreement=None,
terms_agreed=False,
allow_creation=True,
remove_account_uri_if_not_exists=False,
external_account_binding=None,
):
"""
Detect or create an account on the ACME server. For ACME v1,
as the only way (without knowing an account URI) to test if an
account exists is to try and create one with the provided account
@@ -203,7 +261,7 @@ class ACMEAccount(object):
(https://tools.ietf.org/html/rfc8555#section-7.3.4).
https://tools.ietf.org/html/rfc8555#section-7.3
'''
"""
if self.client.account_uri is not None:
created = False
@@ -214,7 +272,9 @@ class ACMEAccount(object):
if remove_account_uri_if_not_exists and not allow_creation:
self.client.account_uri = None
else:
raise ModuleFailException("Account is deactivated or does not exist!")
raise ModuleFailException(
"Account is deactivated or does not exist!"
)
else:
created, account_data = self._new_reg(
contact,
@@ -223,15 +283,17 @@ class ACMEAccount(object):
allow_creation=allow_creation and not self.client.module.check_mode,
external_account_binding=external_account_binding,
)
if self.client.module.check_mode and self.client.account_uri is None and allow_creation:
if (
self.client.module.check_mode
and self.client.account_uri is None
and allow_creation
):
created = True
account_data = {
'contact': contact or []
}
account_data = {"contact": contact or []}
return created, account_data
def update_account(self, account_data, contact=None):
'''
"""
Update an account on the ACME server. Check mode is fully respected.
The current account data must be provided as ``account_data``.
@@ -242,11 +304,11 @@ class ACMEAccount(object):
account data.
https://tools.ietf.org/html/rfc8555#section-7.3.2
'''
"""
# Create request
update_request = {}
if contact is not None and account_data.get('contact', []) != contact:
update_request['contact'] = list(contact)
if contact is not None and account_data.get("contact", []) != contact:
update_request["contact"] = list(contact)
# No change?
if not update_request:
@@ -258,10 +320,16 @@ class ACMEAccount(object):
account_data.update(update_request)
else:
if self.client.version == 1:
update_request['resource'] = 'reg'
account_data, info = self.client.send_signed_request(self.client.account_uri, update_request)
update_request["resource"] = "reg"
account_data, info = self.client.send_signed_request(
self.client.account_uri, update_request
)
if not isinstance(account_data, Mapping):
raise ACMEProtocolException(
self.client.module, msg='Invalid account updating reply from ACME server', info=info, content=account_data)
self.client.module,
msg="Invalid account updating reply from ACME server",
info=info,
content=account_data,
)
return True, account_data

View File

@@ -66,72 +66,97 @@ RETRY_COUNT = 10
def _decode_retry(module, response, info, retry_count):
if info['status'] not in RETRY_STATUS_CODES:
if info["status"] not in RETRY_STATUS_CODES:
return False
if retry_count >= RETRY_COUNT:
raise ACMEProtocolException(
module, msg='Giving up after {retry} retries'.format(retry=RETRY_COUNT), info=info, response=response)
module,
msg="Giving up after {retry} retries".format(retry=RETRY_COUNT),
info=info,
response=response,
)
# 429 and 503 should have a Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
try:
retry_after = min(max(1, int(info.get('retry-after'))), 60)
retry_after = min(max(1, int(info.get("retry-after"))), 60)
except (TypeError, ValueError):
retry_after = 10
module.log('Retrieved a %s HTTP status on %s, retrying in %s seconds' % (format_http_status(info['status']), info['url'], retry_after))
module.log(
"Retrieved a %s HTTP status on %s, retrying in %s seconds"
% (format_http_status(info["status"]), info["url"], retry_after)
)
time.sleep(retry_after)
return True
def _assert_fetch_url_success(module, response, info, allow_redirect=False, allow_client_error=True, allow_server_error=True):
if info['status'] < 0:
raise NetworkException(msg="Failure downloading %s, %s" % (info['url'], info['msg']))
def _assert_fetch_url_success(
module,
response,
info,
allow_redirect=False,
allow_client_error=True,
allow_server_error=True,
):
if info["status"] < 0:
raise NetworkException(
msg="Failure downloading %s, %s" % (info["url"], info["msg"])
)
if (300 <= info['status'] < 400 and not allow_redirect) or \
(400 <= info['status'] < 500 and not allow_client_error) or \
(info['status'] >= 500 and not allow_server_error):
if (
(300 <= info["status"] < 400 and not allow_redirect)
or (400 <= info["status"] < 500 and not allow_client_error)
or (info["status"] >= 500 and not allow_server_error)
):
raise ACMEProtocolException(module, info=info, response=response)
def _is_failed(info, expected_status_codes=None):
if info['status'] < 200 or info['status'] >= 400:
if info["status"] < 200 or info["status"] >= 400:
return True
if expected_status_codes is not None and info['status'] not in expected_status_codes:
if (
expected_status_codes is not None
and info["status"] not in expected_status_codes
):
return True
return False
class ACMEDirectory(object):
'''
"""
The ACME server directory. Gives access to the available resources,
and allows to obtain a Replay-Nonce. The acme_directory URL
needs to support unauthenticated GET requests; ACME endpoints
requiring authentication are not supported.
https://tools.ietf.org/html/rfc8555#section-7.1.1
'''
"""
def __init__(self, module, account):
self.module = module
self.directory_root = module.params['acme_directory']
self.version = module.params['acme_version']
self.directory_root = module.params["acme_directory"]
self.version = module.params["acme_version"]
self.directory, dummy = account.get_request(self.directory_root, get_only=True)
self.request_timeout = module.params['request_timeout']
self.request_timeout = module.params["request_timeout"]
# Check whether self.version matches what we expect
if self.version == 1:
for key in ('new-reg', 'new-authz', 'new-cert'):
for key in ("new-reg", "new-authz", "new-cert"):
if key not in self.directory:
raise ModuleFailException("ACME directory does not seem to follow protocol ACME v1")
raise ModuleFailException(
"ACME directory does not seem to follow protocol ACME v1"
)
if self.version == 2:
for key in ('newNonce', 'newAccount', 'newOrder'):
for key in ("newNonce", "newAccount", "newOrder"):
if key not in self.directory:
raise ModuleFailException("ACME directory does not seem to follow protocol ACME v2")
raise ModuleFailException(
"ACME directory does not seem to follow protocol ACME v2"
)
# Make sure that 'meta' is always available
if 'meta' not in self.directory:
self.directory['meta'] = {}
if "meta" not in self.directory:
self.directory["meta"] = {}
def __getitem__(self, key):
return self.directory[key]
@@ -143,35 +168,48 @@ class ACMEDirectory(object):
return self.directory.get(key, default_value)
def get_nonce(self, resource=None):
url = self.directory_root if self.version == 1 else self.directory['newNonce']
url = self.directory_root if self.version == 1 else self.directory["newNonce"]
if resource is not None:
url = resource
retry_count = 0
while True:
response, info = fetch_url(self.module, url, method='HEAD', timeout=self.request_timeout)
response, info = fetch_url(
self.module, url, method="HEAD", timeout=self.request_timeout
)
if _decode_retry(self.module, response, info, retry_count):
retry_count += 1
continue
if info['status'] not in (200, 204):
raise NetworkException("Failed to get replay-nonce, got status {0}".format(format_http_status(info['status'])))
if 'replay-nonce' in info:
return info['replay-nonce']
if info["status"] not in (200, 204):
raise NetworkException(
"Failed to get replay-nonce, got status {0}".format(
format_http_status(info["status"])
)
)
if "replay-nonce" in info:
return info["replay-nonce"]
self.module.log(
'HEAD to {0} did return status {1}, but no replay-nonce header!'.format(url, format_http_status(info['status'])))
"HEAD to {0} did return status {1}, but no replay-nonce header!".format(
url, format_http_status(info["status"])
)
)
if retry_count >= 5:
raise ACMEProtocolException(
self.module, msg='Was not able to obtain nonce, giving up after 5 retries', info=info, response=response)
self.module,
msg="Was not able to obtain nonce, giving up after 5 retries",
info=info,
response=response,
)
retry_count += 1
def has_renewal_info_endpoint(self):
return 'renewalInfo' in self.directory
return "renewalInfo" in self.directory
class ACMEClient(object):
'''
"""
ACME client object. Handles the authorized communication with the
ACME server.
'''
"""
def __init__(self, module, backend):
# Set to true to enable logging of all signed requests
@@ -179,17 +217,17 @@ class ACMEClient(object):
self.module = module
self.backend = backend
self.version = module.params['acme_version']
self.version = module.params["acme_version"]
# account_key path and content are mutually exclusive
self.account_key_file = module.params.get('account_key_src')
self.account_key_content = module.params.get('account_key_content')
self.account_key_passphrase = module.params.get('account_key_passphrase')
self.account_key_file = module.params.get("account_key_src")
self.account_key_content = module.params.get("account_key_content")
self.account_key_passphrase = module.params.get("account_key_passphrase")
# Grab account URI from module parameters.
# Make sure empty string is treated as None.
self.account_uri = module.params.get('account_uri') or None
self.account_uri = module.params.get("account_uri") or None
self.request_timeout = module.params['request_timeout']
self.request_timeout = module.params["request_timeout"]
self.account_key_data = None
self.account_jwk = None
@@ -199,12 +237,15 @@ class ACMEClient(object):
self.account_key_data = self.parse_key(
key_file=self.account_key_file,
key_content=self.account_key_content,
passphrase=self.account_key_passphrase)
passphrase=self.account_key_passphrase,
)
except KeyParsingError as e:
raise ModuleFailException("Error while parsing account key: {msg}".format(msg=e.msg))
self.account_jwk = self.account_key_data['jwk']
raise ModuleFailException(
"Error while parsing account key: {msg}".format(msg=e.msg)
)
self.account_jwk = self.account_key_data["jwk"]
self.account_jws_header = {
"alg": self.account_key_data['alg'],
"alg": self.account_key_data["alg"],
"jwk": self.account_jwk,
}
if self.account_uri:
@@ -214,56 +255,76 @@ class ACMEClient(object):
self.directory = ACMEDirectory(module, self)
def set_account_uri(self, uri):
'''
"""
Set account URI. For ACME v2, it needs to be used to sending signed
requests.
'''
"""
self.account_uri = uri
if self.version != 1:
self.account_jws_header.pop('jwk')
self.account_jws_header['kid'] = self.account_uri
self.account_jws_header.pop("jwk")
self.account_jws_header["kid"] = self.account_uri
def parse_key(self, key_file=None, key_content=None, passphrase=None):
'''
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
In case of an error, raises KeyParsingError.
'''
"""
if key_file is None and key_content is None:
raise AssertionError('One of key_file and key_content must be specified!')
raise AssertionError("One of key_file and key_content must be specified!")
return self.backend.parse_key(key_file, key_content, passphrase=passphrase)
def sign_request(self, protected, payload, key_data, encode_payload=True):
'''
"""
Signs an ACME request.
'''
"""
try:
if payload is None:
# POST-as-GET
payload64 = ''
payload64 = ""
else:
# POST
if encode_payload:
payload = self.module.jsonify(payload).encode('utf8')
payload = self.module.jsonify(payload).encode("utf8")
payload64 = nopad_b64(to_bytes(payload))
protected64 = nopad_b64(self.module.jsonify(protected).encode('utf8'))
protected64 = nopad_b64(self.module.jsonify(protected).encode("utf8"))
except Exception as e:
raise ModuleFailException("Failed to encode payload / headers as JSON: {0}".format(e))
raise ModuleFailException(
"Failed to encode payload / headers as JSON: {0}".format(e)
)
return self.backend.sign(payload64, protected64, key_data)
def _log(self, msg, data=None):
'''
"""
Write arguments to acme.log when logging is enabled.
'''
"""
if self._debug:
with open('acme.log', 'ab') as f:
f.write('[{0}] {1}\n'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%s'), msg).encode('utf-8'))
with open("acme.log", "ab") as f:
f.write(
"[{0}] {1}\n".format(
datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%s"), msg
).encode("utf-8")
)
if data is not None:
f.write('{0}\n\n'.format(json.dumps(data, indent=2, sort_keys=True)).encode('utf-8'))
f.write(
"{0}\n\n".format(
json.dumps(data, indent=2, sort_keys=True)
).encode("utf-8")
)
def send_signed_request(self, url, payload, key_data=None, jws_header=None, parse_json_result=True,
encode_payload=True, fail_on_error=True, error_msg=None, expected_status_codes=None):
'''
def send_signed_request(
self,
url,
payload,
key_data=None,
jws_header=None,
parse_json_result=True,
encode_payload=True,
fail_on_error=True,
error_msg=None,
expected_status_codes=None,
):
"""
Sends a JWS signed HTTP POST request to the ACME server and returns
the response as dictionary (if parse_json_result is True) or in raw form
(if parse_json_result is False).
@@ -271,7 +332,7 @@ class ACMEClient(object):
If payload is None, a POST-as-GET is performed.
(https://tools.ietf.org/html/rfc8555#section-6.3)
'''
"""
key_data = key_data or self.account_key_data
jws_header = jws_header or self.account_jws_header
failed_tries = 0
@@ -281,21 +342,30 @@ class ACMEClient(object):
if self.version != 1:
protected["url"] = url
self._log('URL', url)
self._log('protected', protected)
self._log('payload', payload)
data = self.sign_request(protected, payload, key_data, encode_payload=encode_payload)
self._log("URL", url)
self._log("protected", protected)
self._log("payload", payload)
data = self.sign_request(
protected, payload, key_data, encode_payload=encode_payload
)
if self.version == 1:
data["header"] = jws_header.copy()
for k, v in protected.items():
data["header"].pop(k, None)
self._log('signed request', data)
self._log("signed request", data)
data = self.module.jsonify(data)
headers = {
'Content-Type': 'application/jose+json',
"Content-Type": "application/jose+json",
}
resp, info = fetch_url(self.module, url, data=data, headers=headers, method='POST', timeout=self.request_timeout)
resp, info = fetch_url(
self.module,
url,
data=data,
headers=headers,
method="POST",
timeout=self.request_timeout,
)
if _decode_retry(self.module, resp, info, failed_tries):
failed_tries += 1
continue
@@ -309,20 +379,26 @@ class ACMEClient(object):
raise TypeError
content = resp.read()
except (AttributeError, TypeError):
content = info.pop('body', None)
content = info.pop("body", None)
if content or not parse_json_result:
if (parse_json_result and info['content-type'].startswith('application/json')) or 400 <= info['status'] < 600:
if (
parse_json_result
and info["content-type"].startswith("application/json")
) or 400 <= info["status"] < 600:
try:
decoded_result = self.module.from_json(content.decode('utf8'))
self._log('parsed result', decoded_result)
decoded_result = self.module.from_json(content.decode("utf8"))
self._log("parsed result", decoded_result)
# In case of badNonce error, try again (up to 5 times)
# (https://tools.ietf.org/html/rfc8555#section-6.7)
if all((
400 <= info['status'] < 600,
decoded_result.get('type') == 'urn:ietf:params:acme:error:badNonce',
failed_tries <= 5,
)):
if all(
(
400 <= info["status"] < 600,
decoded_result.get("type")
== "urn:ietf:params:acme:error:badNonce",
failed_tries <= 5,
)
):
failed_tries += 1
continue
if parse_json_result:
@@ -330,25 +406,46 @@ class ACMEClient(object):
else:
result = content
except ValueError:
raise NetworkException("Failed to parse the ACME response: {0} {1}".format(url, content))
raise NetworkException(
"Failed to parse the ACME response: {0} {1}".format(
url, content
)
)
else:
result = content
if fail_on_error and _is_failed(info, expected_status_codes=expected_status_codes):
if fail_on_error and _is_failed(
info, expected_status_codes=expected_status_codes
):
raise ACMEProtocolException(
self.module, msg=error_msg, info=info, content=content, content_json=result if parse_json_result else None)
self.module,
msg=error_msg,
info=info,
content=content,
content_json=result if parse_json_result else None,
)
return result, info
def get_request(self, uri, parse_json_result=True, headers=None, get_only=False,
fail_on_error=True, error_msg=None, expected_status_codes=None):
'''
def get_request(
self,
uri,
parse_json_result=True,
headers=None,
get_only=False,
fail_on_error=True,
error_msg=None,
expected_status_codes=None,
):
"""
Perform a GET-like request. Will try POST-as-GET for ACMEv2, with fallback
to GET if server replies with a status code of 405.
'''
"""
if not get_only and self.version != 1:
# Try POST-as-GET
content, info = self.send_signed_request(uri, None, parse_json_result=False, fail_on_error=False)
if info['status'] == 405:
content, info = self.send_signed_request(
uri, None, parse_json_result=False, fail_on_error=False
)
if info["status"] == 405:
# Instead, do unauthenticated GET
get_only = True
else:
@@ -359,7 +456,13 @@ class ACMEClient(object):
# Perform unauthenticated GET
retry_count = 0
while True:
resp, info = fetch_url(self.module, uri, method='GET', headers=headers, timeout=self.request_timeout)
resp, info = fetch_url(
self.module,
uri,
method="GET",
headers=headers,
timeout=self.request_timeout,
)
if not _decode_retry(self.module, resp, info, retry_count):
break
retry_count += 1
@@ -373,27 +476,38 @@ class ACMEClient(object):
raise TypeError
content = resp.read()
except (AttributeError, TypeError):
content = info.pop('body', None)
content = info.pop("body", None)
# Process result
parsed_json_result = False
if parse_json_result:
result = {}
if content:
if info['content-type'].startswith('application/json'):
if info["content-type"].startswith("application/json"):
try:
result = self.module.from_json(content.decode('utf8'))
result = self.module.from_json(content.decode("utf8"))
parsed_json_result = True
except ValueError:
raise NetworkException("Failed to parse the ACME response: {0} {1}".format(uri, content))
raise NetworkException(
"Failed to parse the ACME response: {0} {1}".format(
uri, content
)
)
else:
result = content
else:
result = content
if fail_on_error and _is_failed(info, expected_status_codes=expected_status_codes):
if fail_on_error and _is_failed(
info, expected_status_codes=expected_status_codes
):
raise ACMEProtocolException(
self.module, msg=error_msg, info=info, content=content, content_json=result if parsed_json_result else None)
self.module,
msg=error_msg,
info=info,
content=content,
content_json=result if parsed_json_result else None,
)
return result, info
def get_renewal_info(
@@ -406,19 +520,30 @@ class ACMEClient(object):
retry_after_relative_with_timezone=True,
):
if not self.directory.has_renewal_info_endpoint():
raise ModuleFailException('The ACME endpoint does not support ACME Renewal Information retrieval')
raise ModuleFailException(
"The ACME endpoint does not support ACME Renewal Information retrieval"
)
if cert_id is None:
cert_id = compute_cert_id(self.backend, cert_info=cert_info, cert_filename=cert_filename, cert_content=cert_content)
url = '{base}/{cert_id}'.format(base=self.directory.directory['renewalInfo'].rstrip('/'), cert_id=cert_id)
cert_id = compute_cert_id(
self.backend,
cert_info=cert_info,
cert_filename=cert_filename,
cert_content=cert_content,
)
url = "{base}/{cert_id}".format(
base=self.directory.directory["renewalInfo"].rstrip("/"), cert_id=cert_id
)
data, info = self.get_request(url, parse_json_result=True, fail_on_error=True, get_only=True)
data, info = self.get_request(
url, parse_json_result=True, fail_on_error=True, get_only=True
)
# Include Retry-After header if asked for
if include_retry_after and 'retry-after' in info:
if include_retry_after and "retry-after" in info:
try:
data['retryAfter'] = parse_retry_after(
info['retry-after'],
data["retryAfter"] = parse_retry_after(
info["retry-after"],
relative_with_timezone=retry_after_relative_with_timezone,
)
except ValueError:
@@ -427,21 +552,23 @@ class ACMEClient(object):
def get_default_argspec():
'''
"""
Provides default argument spec for the options documented in the acme doc fragment.
DEPRECATED: will be removed in community.crypto 3.0.0
'''
"""
return dict(
acme_directory=dict(type='str', required=True),
acme_version=dict(type='int', required=True, choices=[1, 2]),
validate_certs=dict(type='bool', default=True),
select_crypto_backend=dict(type='str', default='auto', choices=['auto', 'openssl', 'cryptography']),
request_timeout=dict(type='int', default=10),
account_key_src=dict(type='path', aliases=['account_key']),
account_key_content=dict(type='str', no_log=True),
account_key_passphrase=dict(type='str', no_log=True),
account_uri=dict(type='str'),
acme_directory=dict(type="str", required=True),
acme_version=dict(type="int", required=True, choices=[1, 2]),
validate_certs=dict(type="bool", default=True),
select_crypto_backend=dict(
type="str", default="auto", choices=["auto", "openssl", "cryptography"]
),
request_timeout=dict(type="int", default=10),
account_key_src=dict(type="path", aliases=["account_key"]),
account_key_content=dict(type="str", no_log=True),
account_key_passphrase=dict(type="str", no_log=True),
account_uri=dict(type="str"),
)
@@ -450,90 +577,109 @@ def create_default_argspec(
require_account_key=True,
with_certificate=False,
):
'''
"""
Provides default argument spec for the options documented in the acme doc fragment.
'''
"""
result = ArgumentSpec(
argument_spec=dict(
acme_directory=dict(type='str', required=True),
acme_version=dict(type='int', required=True, choices=[1, 2]),
validate_certs=dict(type='bool', default=True),
select_crypto_backend=dict(type='str', default='auto', choices=['auto', 'openssl', 'cryptography']),
request_timeout=dict(type='int', default=10),
acme_directory=dict(type="str", required=True),
acme_version=dict(type="int", required=True, choices=[1, 2]),
validate_certs=dict(type="bool", default=True),
select_crypto_backend=dict(
type="str", default="auto", choices=["auto", "openssl", "cryptography"]
),
request_timeout=dict(type="int", default=10),
),
)
if with_account:
result.update_argspec(
account_key_src=dict(type='path', aliases=['account_key']),
account_key_content=dict(type='str', no_log=True),
account_key_passphrase=dict(type='str', no_log=True),
account_uri=dict(type='str'),
account_key_src=dict(type="path", aliases=["account_key"]),
account_key_content=dict(type="str", no_log=True),
account_key_passphrase=dict(type="str", no_log=True),
account_uri=dict(type="str"),
)
if require_account_key:
result.update(required_one_of=[['account_key_src', 'account_key_content']])
result.update(mutually_exclusive=[['account_key_src', 'account_key_content']])
result.update(required_one_of=[["account_key_src", "account_key_content"]])
result.update(mutually_exclusive=[["account_key_src", "account_key_content"]])
if with_certificate:
result.update_argspec(
csr=dict(type='path'),
csr_content=dict(type='str'),
csr=dict(type="path"),
csr_content=dict(type="str"),
)
result.update(
required_one_of=[['csr', 'csr_content']],
mutually_exclusive=[['csr', 'csr_content']],
required_one_of=[["csr", "csr_content"]],
mutually_exclusive=[["csr", "csr_content"]],
)
return result
def create_backend(module, needs_acme_v2):
if not HAS_IPADDRESS:
module.fail_json(msg=missing_required_lib('ipaddress'), exception=IPADDRESS_IMPORT_ERROR)
module.fail_json(
msg=missing_required_lib("ipaddress"), exception=IPADDRESS_IMPORT_ERROR
)
backend = module.params['select_crypto_backend']
backend = module.params["select_crypto_backend"]
# Backend autodetect
if backend == 'auto':
backend = 'cryptography' if HAS_CURRENT_CRYPTOGRAPHY else 'openssl'
if backend == "auto":
backend = "cryptography" if HAS_CURRENT_CRYPTOGRAPHY else "openssl"
# Create backend object
if backend == 'cryptography':
if backend == "cryptography":
if CRYPTOGRAPHY_ERROR is not None:
# Either we could not import cryptography at all, or there was an unexpected error
if CRYPTOGRAPHY_VERSION is None:
msg = missing_required_lib('cryptography')
msg = missing_required_lib("cryptography")
else:
msg = 'Unexpected error while preparing cryptography: {0}'.format(CRYPTOGRAPHY_ERROR.splitlines()[-1])
msg = "Unexpected error while preparing cryptography: {0}".format(
CRYPTOGRAPHY_ERROR.splitlines()[-1]
)
module.fail_json(msg=msg, exception=CRYPTOGRAPHY_ERROR)
if not HAS_CURRENT_CRYPTOGRAPHY:
# We succeeded importing cryptography, but its version is too old.
module.fail_json(
msg='Found cryptography, but only version {0}. {1}'.format(
msg="Found cryptography, but only version {0}. {1}".format(
CRYPTOGRAPHY_VERSION,
missing_required_lib('cryptography >= {0}'.format(CRYPTOGRAPHY_MINIMAL_VERSION))))
module.debug('Using cryptography backend (library version {0})'.format(CRYPTOGRAPHY_VERSION))
missing_required_lib(
"cryptography >= {0}".format(CRYPTOGRAPHY_MINIMAL_VERSION)
),
)
)
module.debug(
"Using cryptography backend (library version {0})".format(
CRYPTOGRAPHY_VERSION
)
)
module_backend = CryptographyBackend(module)
elif backend == 'openssl':
module.debug('Using OpenSSL binary backend')
elif backend == "openssl":
module.debug("Using OpenSSL binary backend")
module_backend = OpenSSLCLIBackend(module)
else:
module.fail_json(msg='Unknown crypto backend "{0}"!'.format(backend))
# Check common module parameters
if not module.params['validate_certs']:
if not module.params["validate_certs"]:
module.warn(
'Disabling certificate validation for communications with ACME endpoint. '
'This should only be done for testing against a local ACME server for '
'development purposes, but *never* for production purposes.'
"Disabling certificate validation for communications with ACME endpoint. "
"This should only be done for testing against a local ACME server for "
"development purposes, but *never* for production purposes."
)
if needs_acme_v2 and module.params['acme_version'] < 2:
module.fail_json(msg='The {0} module requires the ACME v2 protocol!'.format(module._name))
if needs_acme_v2 and module.params["acme_version"] < 2:
module.fail_json(
msg="The {0} module requires the ACME v2 protocol!".format(module._name)
)
if module.params['acme_version'] == 1:
module.deprecate("The value 1 for 'acme_version' is deprecated. Please switch to ACME v2",
version='3.0.0', collection_name='community.crypto')
if module.params["acme_version"] == 1:
module.deprecate(
"The value 1 for 'acme_version' is deprecated. Please switch to ACME v2",
version="3.0.0",
collection_name="community.crypto",
)
# AnsibleModule() changes the locale, so change it back to C because we rely
# on datetime.datetime.strptime() when parsing certificate dates.
locale.setlocale(locale.LC_ALL, 'C')
locale.setlocale(locale.LC_ALL, "C")
return module_backend

View File

@@ -57,7 +57,7 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
CRYPTOGRAPHY_MINIMAL_VERSION = '1.5'
CRYPTOGRAPHY_MINIMAL_VERSION = "1.5"
CRYPTOGRAPHY_ERROR = None
try:
@@ -78,7 +78,9 @@ except ImportError:
CRYPTOGRAPHY_ERROR = traceback.format_exc()
else:
CRYPTOGRAPHY_VERSION = cryptography.__version__
HAS_CURRENT_CRYPTOGRAPHY = (LooseVersion(CRYPTOGRAPHY_VERSION) >= LooseVersion(CRYPTOGRAPHY_MINIMAL_VERSION))
HAS_CURRENT_CRYPTOGRAPHY = LooseVersion(CRYPTOGRAPHY_VERSION) >= LooseVersion(
CRYPTOGRAPHY_MINIMAL_VERSION
)
try:
if HAS_CURRENT_CRYPTOGRAPHY:
_cryptography_backend = cryptography.hazmat.backends.default_backend()
@@ -91,13 +93,19 @@ class CryptographyChainMatcher(ChainMatcher):
def _parse_key_identifier(key_identifier, name, criterium_idx, module):
if key_identifier:
try:
return binascii.unhexlify(key_identifier.replace(':', ''))
return binascii.unhexlify(key_identifier.replace(":", ""))
except Exception:
if criterium_idx is None:
module.warn('Criterium has invalid {0} value. Ignoring criterium.'.format(name))
module.warn(
"Criterium has invalid {0} value. Ignoring criterium.".format(
name
)
)
else:
module.warn('Criterium {0} in select_chain has invalid {1} value. '
'Ignoring criterium.'.format(criterium_idx, name))
module.warn(
"Criterium {0} in select_chain has invalid {1} value. "
"Ignoring criterium.".format(criterium_idx, name)
)
return None
def __init__(self, criterium, module):
@@ -107,16 +115,26 @@ class CryptographyChainMatcher(ChainMatcher):
self.issuer = []
if criterium.subject:
self.subject = [
(cryptography_name_to_oid(k), to_native(v)) for k, v in parse_name_field(criterium.subject, 'subject')
(cryptography_name_to_oid(k), to_native(v))
for k, v in parse_name_field(criterium.subject, "subject")
]
if criterium.issuer:
self.issuer = [
(cryptography_name_to_oid(k), to_native(v)) for k, v in parse_name_field(criterium.issuer, 'issuer')
(cryptography_name_to_oid(k), to_native(v))
for k, v in parse_name_field(criterium.issuer, "issuer")
]
self.subject_key_identifier = CryptographyChainMatcher._parse_key_identifier(
criterium.subject_key_identifier, 'subject_key_identifier', criterium.index, module)
criterium.subject_key_identifier,
"subject_key_identifier",
criterium.index,
module,
)
self.authority_key_identifier = CryptographyChainMatcher._parse_key_identifier(
criterium.authority_key_identifier, 'authority_key_identifier', criterium.index, module)
criterium.authority_key_identifier,
"authority_key_identifier",
criterium.index,
module,
)
def _match_subject(self, x509_subject, match_subject):
for oid, value in match_subject:
@@ -130,17 +148,19 @@ class CryptographyChainMatcher(ChainMatcher):
return True
def match(self, certificate):
'''
"""
Check whether an alternate chain matches the specified criterium.
'''
"""
chain = certificate.chain
if self.test_certificates == 'last':
if self.test_certificates == "last":
chain = chain[-1:]
elif self.test_certificates == 'first':
elif self.test_certificates == "first":
chain = chain[:1]
for cert in chain:
try:
x509 = cryptography.x509.load_pem_x509_certificate(to_bytes(cert), cryptography.hazmat.backends.default_backend())
x509 = cryptography.x509.load_pem_x509_certificate(
to_bytes(cert), cryptography.hazmat.backends.default_backend()
)
matches = True
if not self._match_subject(x509.subject, self.subject):
matches = False
@@ -148,14 +168,18 @@ class CryptographyChainMatcher(ChainMatcher):
matches = False
if self.subject_key_identifier:
try:
ext = x509.extensions.get_extension_for_class(cryptography.x509.SubjectKeyIdentifier)
ext = x509.extensions.get_extension_for_class(
cryptography.x509.SubjectKeyIdentifier
)
if self.subject_key_identifier != ext.value.digest:
matches = False
except cryptography.x509.ExtensionNotFound:
matches = False
if self.authority_key_identifier:
try:
ext = x509.extensions.get_extension_for_class(cryptography.x509.AuthorityKeyIdentifier)
ext = x509.extensions.get_extension_for_class(
cryptography.x509.AuthorityKeyIdentifier
)
if self.authority_key_identifier != ext.value.key_identifier:
matches = False
except cryptography.x509.ExtensionNotFound:
@@ -163,19 +187,23 @@ class CryptographyChainMatcher(ChainMatcher):
if matches:
return True
except Exception as e:
self.module.warn('Error while loading certificate {0}: {1}'.format(cert, e))
self.module.warn(
"Error while loading certificate {0}: {1}".format(cert, e)
)
return False
class CryptographyBackend(CryptoBackend):
def __init__(self, module):
super(CryptographyBackend, self).__init__(module, with_timezone=CRYPTOGRAPHY_TIMEZONE)
super(CryptographyBackend, self).__init__(
module, with_timezone=CRYPTOGRAPHY_TIMEZONE
)
def parse_key(self, key_file=None, key_content=None, passphrase=None):
'''
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors.
'''
"""
# If key_content is not given, read key_file
if key_content is None:
key_content = read_file(key_file)
@@ -186,84 +214,97 @@ class CryptographyBackend(CryptoBackend):
key = cryptography.hazmat.primitives.serialization.load_pem_private_key(
key_content,
password=to_bytes(passphrase) if passphrase is not None else None,
backend=_cryptography_backend)
backend=_cryptography_backend,
)
except Exception as e:
raise KeyParsingError('error while loading key: {0}'.format(e))
raise KeyParsingError("error while loading key: {0}".format(e))
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
pk = key.public_key().public_numbers()
return {
'key_obj': key,
'type': 'rsa',
'alg': 'RS256',
'jwk': {
"key_obj": key,
"type": "rsa",
"alg": "RS256",
"jwk": {
"kty": "RSA",
"e": nopad_b64(convert_int_to_bytes(pk.e)),
"n": nopad_b64(convert_int_to_bytes(pk.n)),
},
'hash': 'sha256',
"hash": "sha256",
}
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
pk = key.public_key().public_numbers()
if pk.curve.name == 'secp256r1':
if pk.curve.name == "secp256r1":
bits = 256
alg = 'ES256'
hashalg = 'sha256'
alg = "ES256"
hashalg = "sha256"
point_size = 32
curve = 'P-256'
elif pk.curve.name == 'secp384r1':
curve = "P-256"
elif pk.curve.name == "secp384r1":
bits = 384
alg = 'ES384'
hashalg = 'sha384'
alg = "ES384"
hashalg = "sha384"
point_size = 48
curve = 'P-384'
elif pk.curve.name == 'secp521r1':
curve = "P-384"
elif pk.curve.name == "secp521r1":
# Not yet supported on Let's Encrypt side, see
# https://github.com/letsencrypt/boulder/issues/2217
bits = 521
alg = 'ES512'
hashalg = 'sha512'
alg = "ES512"
hashalg = "sha512"
point_size = 66
curve = 'P-521'
curve = "P-521"
else:
raise KeyParsingError('unknown elliptic curve: {0}'.format(pk.curve.name))
raise KeyParsingError(
"unknown elliptic curve: {0}".format(pk.curve.name)
)
num_bytes = (bits + 7) // 8
return {
'key_obj': key,
'type': 'ec',
'alg': alg,
'jwk': {
"key_obj": key,
"type": "ec",
"alg": alg,
"jwk": {
"kty": "EC",
"crv": curve,
"x": nopad_b64(convert_int_to_bytes(pk.x, count=num_bytes)),
"y": nopad_b64(convert_int_to_bytes(pk.y, count=num_bytes)),
},
'hash': hashalg,
'point_size': point_size,
"hash": hashalg,
"point_size": point_size,
}
else:
raise KeyParsingError('unknown key type "{0}"'.format(type(key)))
def sign(self, payload64, protected64, key_data):
sign_payload = "{0}.{1}".format(protected64, payload64).encode('utf8')
if 'mac_obj' in key_data:
mac = key_data['mac_obj']()
sign_payload = "{0}.{1}".format(protected64, payload64).encode("utf8")
if "mac_obj" in key_data:
mac = key_data["mac_obj"]()
mac.update(sign_payload)
signature = mac.finalize()
elif isinstance(key_data['key_obj'], cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
elif isinstance(
key_data["key_obj"],
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
):
padding = cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
hashalg = cryptography.hazmat.primitives.hashes.SHA256
signature = key_data['key_obj'].sign(sign_payload, padding, hashalg())
elif isinstance(key_data['key_obj'], cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
if key_data['hash'] == 'sha256':
signature = key_data["key_obj"].sign(sign_payload, padding, hashalg())
elif isinstance(
key_data["key_obj"],
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
):
if key_data["hash"] == "sha256":
hashalg = cryptography.hazmat.primitives.hashes.SHA256
elif key_data['hash'] == 'sha384':
elif key_data["hash"] == "sha384":
hashalg = cryptography.hazmat.primitives.hashes.SHA384
elif key_data['hash'] == 'sha512':
elif key_data["hash"] == "sha512":
hashalg = cryptography.hazmat.primitives.hashes.SHA512
ecdsa = cryptography.hazmat.primitives.asymmetric.ec.ECDSA(hashalg())
r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(key_data['key_obj'].sign(sign_payload, ecdsa))
rr = convert_int_to_hex(r, 2 * key_data['point_size'])
ss = convert_int_to_hex(s, 2 * key_data['point_size'])
r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(
key_data["key_obj"].sign(sign_payload, ecdsa)
)
rr = convert_int_to_hex(r, 2 * key_data["point_size"])
ss = convert_int_to_hex(s, 2 * key_data["point_size"])
signature = binascii.unhexlify(rr) + binascii.unhexlify(ss)
return {
@@ -273,44 +314,50 @@ class CryptographyBackend(CryptoBackend):
}
def create_mac_key(self, alg, key):
'''Create a MAC key.'''
if alg == 'HS256':
"""Create a MAC key."""
if alg == "HS256":
hashalg = cryptography.hazmat.primitives.hashes.SHA256
hashbytes = 32
elif alg == 'HS384':
elif alg == "HS384":
hashalg = cryptography.hazmat.primitives.hashes.SHA384
hashbytes = 48
elif alg == 'HS512':
elif alg == "HS512":
hashalg = cryptography.hazmat.primitives.hashes.SHA512
hashbytes = 64
else:
raise BackendException('Unsupported MAC key algorithm for cryptography backend: {0}'.format(alg))
raise BackendException(
"Unsupported MAC key algorithm for cryptography backend: {0}".format(
alg
)
)
key_bytes = base64.urlsafe_b64decode(key)
if len(key_bytes) < hashbytes:
raise BackendException(
'{0} key must be at least {1} bytes long (after Base64 decoding)'.format(alg, hashbytes))
"{0} key must be at least {1} bytes long (after Base64 decoding)".format(
alg, hashbytes
)
)
return {
'mac_obj': lambda: cryptography.hazmat.primitives.hmac.HMAC(
key_bytes,
hashalg(),
_cryptography_backend),
'type': 'hmac',
'alg': alg,
'jwk': {
'kty': 'oct',
'k': key,
"mac_obj": lambda: cryptography.hazmat.primitives.hmac.HMAC(
key_bytes, hashalg(), _cryptography_backend
),
"type": "hmac",
"alg": alg,
"jwk": {
"kty": "oct",
"k": key,
},
}
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
'''
"""
Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
'dns' or 'ip'.
The list is deduplicated, and if a CNAME is present, it will be returned
as the first element in the result.
'''
"""
if csr_content is None:
csr_content = read_file(csr_filename)
else:
@@ -328,34 +375,43 @@ class CryptographyBackend(CryptoBackend):
for sub in csr.subject:
if sub.oid == cryptography.x509.oid.NameOID.COMMON_NAME:
add_identifier(('dns', sub.value))
add_identifier(("dns", sub.value))
for extension in csr.extensions:
if extension.oid == cryptography.x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
if (
extension.oid
== cryptography.x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME
):
for name in extension.value:
if isinstance(name, cryptography.x509.DNSName):
add_identifier(('dns', name.value))
add_identifier(("dns", name.value))
elif isinstance(name, cryptography.x509.IPAddress):
add_identifier(('ip', name.value.compressed))
add_identifier(("ip", name.value.compressed))
else:
raise BackendException('Found unsupported SAN identifier {0}'.format(name))
raise BackendException(
"Found unsupported SAN identifier {0}".format(name)
)
return result
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
'''
"""
Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
'dns' or 'ip'.
'''
return set(self.get_ordered_csr_identifiers(csr_filename=csr_filename, csr_content=csr_content))
"""
return set(
self.get_ordered_csr_identifiers(
csr_filename=csr_filename, csr_content=csr_content
)
)
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
'''
"""
Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one
certificate, only the first one will be considered.
If now is not specified, datetime.datetime.now() is used.
'''
"""
if cert_filename is not None:
cert_content = None
if os.path.exists(cert_filename):
@@ -367,14 +423,18 @@ class CryptographyBackend(CryptoBackend):
return -1
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or '')
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
try:
cert = cryptography.x509.load_pem_x509_certificate(cert_content, _cryptography_backend)
cert = cryptography.x509.load_pem_x509_certificate(
cert_content, _cryptography_backend
)
except Exception as e:
if cert_filename is None:
raise BackendException('Cannot parse certificate: {0}'.format(e))
raise BackendException('Cannot parse certificate {0}: {1}'.format(cert_filename, e))
raise BackendException("Cannot parse certificate: {0}".format(e))
raise BackendException(
"Cannot parse certificate {0}: {1}".format(cert_filename, e)
)
if now is None:
now = self.get_now()
@@ -383,40 +443,48 @@ class CryptographyBackend(CryptoBackend):
return (get_not_valid_after(cert) - now).days
def create_chain_matcher(self, criterium):
'''
"""
Given a Criterium object, creates a ChainMatcher object.
'''
"""
return CryptographyChainMatcher(criterium, self.module)
def get_cert_information(self, cert_filename=None, cert_content=None):
'''
"""
Return some information on a X.509 certificate as a CertificateInformation object.
'''
"""
if cert_filename is not None:
cert_content = read_file(cert_filename)
else:
cert_content = to_bytes(cert_content)
# Make sure we have at most one PEM. Otherwise cryptography 36.0.0 will barf.
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or '')
cert_content = to_bytes(extract_first_pem(to_text(cert_content)) or "")
try:
cert = cryptography.x509.load_pem_x509_certificate(cert_content, _cryptography_backend)
cert = cryptography.x509.load_pem_x509_certificate(
cert_content, _cryptography_backend
)
except Exception as e:
if cert_filename is None:
raise BackendException('Cannot parse certificate: {0}'.format(e))
raise BackendException('Cannot parse certificate {0}: {1}'.format(cert_filename, e))
raise BackendException("Cannot parse certificate: {0}".format(e))
raise BackendException(
"Cannot parse certificate {0}: {1}".format(cert_filename, e)
)
ski = None
try:
ext = cert.extensions.get_extension_for_class(cryptography.x509.SubjectKeyIdentifier)
ext = cert.extensions.get_extension_for_class(
cryptography.x509.SubjectKeyIdentifier
)
ski = ext.value.digest
except cryptography.x509.ExtensionNotFound:
pass
aki = None
try:
ext = cert.extensions.get_extension_for_class(cryptography.x509.AuthorityKeyIdentifier)
ext = cert.extensions.get_extension_for_class(
cryptography.x509.AuthorityKeyIdentifier
)
aki = ext.value.key_identifier
except cryptography.x509.ExtensionNotFound:
pass

View File

@@ -45,7 +45,7 @@ except ImportError:
pass
_OPENSSL_ENVIRONMENT_UPDATE = dict(LANG='C', LC_ALL='C', LC_MESSAGES='C', LC_CTYPE='C')
_OPENSSL_ENVIRONMENT_UPDATE = dict(LANG="C", LC_ALL="C", LC_MESSAGES="C", LC_CTYPE="C")
def _extract_date(out_text, name, cert_filename_suffix=""):
@@ -55,11 +55,17 @@ def _extract_date(out_text, name, cert_filename_suffix=""):
# even though the information is there and a supported timezone for all supported
# Python implementations (GMT). So we have to modify the datetime object by
# replacing it by UTC.
return ensure_utc_timezone(datetime.datetime.strptime(date_str, '%b %d %H:%M:%S %Y %Z'))
return ensure_utc_timezone(
datetime.datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
)
except AttributeError:
raise BackendException("No '{0}' date found{1}".format(name, cert_filename_suffix))
raise BackendException(
"No '{0}' date found{1}".format(name, cert_filename_suffix)
)
except ValueError as exc:
raise BackendException("Failed to parse '{0}' date{1}: {2}".format(name, cert_filename_suffix, exc))
raise BackendException(
"Failed to parse '{0}' date{1}: {2}".format(name, cert_filename_suffix, exc)
)
def _decode_octets(octets_text):
@@ -69,7 +75,11 @@ def _decode_octets(octets_text):
def _extract_octets(out_text, name, required=True, potential_prefixes=None):
regexp = r"\s+%s:\s*\n\s+%s([A-Fa-f0-9]{2}(?::[A-Fa-f0-9]{2})*)\s*\n" % (
name,
('(?:%s)' % '|'.join(re.escape(pp) for pp in potential_prefixes)) if potential_prefixes else '',
(
("(?:%s)" % "|".join(re.escape(pp) for pp in potential_prefixes))
if potential_prefixes
else ""
),
)
match = re.search(regexp, out_text, re.MULTILINE | re.DOTALL)
if match is not None:
@@ -83,36 +93,41 @@ class OpenSSLCLIBackend(CryptoBackend):
def __init__(self, module, openssl_binary=None):
super(OpenSSLCLIBackend, self).__init__(module, with_timezone=True)
if openssl_binary is None:
openssl_binary = module.get_bin_path('openssl', True)
openssl_binary = module.get_bin_path("openssl", True)
self.openssl_binary = openssl_binary
def parse_key(self, key_file=None, key_content=None, passphrase=None):
'''
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors.
'''
"""
if passphrase is not None:
raise KeyParsingError('openssl backend does not support key passphrases')
raise KeyParsingError("openssl backend does not support key passphrases")
# If key_file is not given, but key_content, write that to a temporary file
if key_file is None:
fd, tmpsrc = tempfile.mkstemp()
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
f = os.fdopen(fd, 'wb')
f = os.fdopen(fd, "wb")
try:
f.write(key_content.encode('utf-8'))
f.write(key_content.encode("utf-8"))
key_file = tmpsrc
except Exception as err:
try:
f.close()
except Exception:
pass
raise KeyParsingError("failed to create temporary content file: %s" % to_native(err), exception=traceback.format_exc())
raise KeyParsingError(
"failed to create temporary content file: %s" % to_native(err),
exception=traceback.format_exc(),
)
f.close()
# Parse key
account_key_type = None
with open(key_file, "rt") as f:
for line in f:
m = re.match(r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line)
m = re.match(
r"^\s*-{5,}BEGIN\s+(EC|RSA)\s+PRIVATE\s+KEY-{5,}\s*$", line
)
if m is not None:
account_key_type = m.group(1).lower()
break
@@ -125,111 +140,162 @@ class OpenSSLCLIBackend(CryptoBackend):
if account_key_type not in ("rsa", "ec"):
raise KeyParsingError('unknown key type "%s"' % account_key_type)
openssl_keydump_cmd = [self.openssl_binary, account_key_type, "-in", key_file, "-noout", "-text"]
openssl_keydump_cmd = [
self.openssl_binary,
account_key_type,
"-in",
key_file,
"-noout",
"-text",
]
rc, out, err = self.module.run_command(
openssl_keydump_cmd, check_rc=False, environ_update=_OPENSSL_ENVIRONMENT_UPDATE)
openssl_keydump_cmd,
check_rc=False,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
if rc != 0:
raise BackendException('Error while running {cmd}: {stderr}'.format(cmd=' '.join(openssl_keydump_cmd), stderr=to_text(err)))
raise BackendException(
"Error while running {cmd}: {stderr}".format(
cmd=" ".join(openssl_keydump_cmd), stderr=to_text(err)
)
)
out_text = to_text(out, errors='surrogate_or_strict')
out_text = to_text(out, errors="surrogate_or_strict")
if account_key_type == 'rsa':
pub_hex = re.search(r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent", out_text, re.MULTILINE | re.DOTALL).group(1)
if account_key_type == "rsa":
pub_hex = re.search(
r"modulus:\n\s+00:([a-f0-9\:\s]+?)\npublicExponent",
out_text,
re.MULTILINE | re.DOTALL,
).group(1)
pub_exp = re.search(r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL).group(1)
pub_exp = re.search(
r"\npublicExponent: ([0-9]+)", out_text, re.MULTILINE | re.DOTALL
).group(1)
pub_exp = "{0:x}".format(int(pub_exp))
if len(pub_exp) % 2:
pub_exp = "0{0}".format(pub_exp)
return {
'key_file': key_file,
'type': 'rsa',
'alg': 'RS256',
'jwk': {
"key_file": key_file,
"type": "rsa",
"alg": "RS256",
"jwk": {
"kty": "RSA",
"e": nopad_b64(binascii.unhexlify(pub_exp.encode("utf-8"))),
"n": nopad_b64(_decode_octets(pub_hex)),
},
'hash': 'sha256',
"hash": "sha256",
}
elif account_key_type == 'ec':
elif account_key_type == "ec":
pub_data = re.search(
r"pub:\s*\n\s+04:([a-f0-9\:\s]+?)\nASN1 OID: (\S+)(?:\nNIST CURVE: (\S+))?",
out_text,
re.MULTILINE | re.DOTALL,
)
if pub_data is None:
raise KeyParsingError('cannot parse elliptic curve key')
raise KeyParsingError("cannot parse elliptic curve key")
pub_hex = _decode_octets(pub_data.group(1))
asn1_oid_curve = pub_data.group(2).lower()
nist_curve = pub_data.group(3).lower() if pub_data.group(3) else None
if asn1_oid_curve == 'prime256v1' or nist_curve == 'p-256':
if asn1_oid_curve == "prime256v1" or nist_curve == "p-256":
bits = 256
alg = 'ES256'
hashalg = 'sha256'
alg = "ES256"
hashalg = "sha256"
point_size = 32
curve = 'P-256'
elif asn1_oid_curve == 'secp384r1' or nist_curve == 'p-384':
curve = "P-256"
elif asn1_oid_curve == "secp384r1" or nist_curve == "p-384":
bits = 384
alg = 'ES384'
hashalg = 'sha384'
alg = "ES384"
hashalg = "sha384"
point_size = 48
curve = 'P-384'
elif asn1_oid_curve == 'secp521r1' or nist_curve == 'p-521':
curve = "P-384"
elif asn1_oid_curve == "secp521r1" or nist_curve == "p-521":
# Not yet supported on Let's Encrypt side, see
# https://github.com/letsencrypt/boulder/issues/2217
bits = 521
alg = 'ES512'
hashalg = 'sha512'
alg = "ES512"
hashalg = "sha512"
point_size = 66
curve = 'P-521'
curve = "P-521"
else:
raise KeyParsingError('unknown elliptic curve: %s / %s' % (asn1_oid_curve, nist_curve))
raise KeyParsingError(
"unknown elliptic curve: %s / %s" % (asn1_oid_curve, nist_curve)
)
num_bytes = (bits + 7) // 8
if len(pub_hex) != 2 * num_bytes:
raise KeyParsingError('bad elliptic curve point (%s / %s)' % (asn1_oid_curve, nist_curve))
raise KeyParsingError(
"bad elliptic curve point (%s / %s)" % (asn1_oid_curve, nist_curve)
)
return {
'key_file': key_file,
'type': 'ec',
'alg': alg,
'jwk': {
"key_file": key_file,
"type": "ec",
"alg": alg,
"jwk": {
"kty": "EC",
"crv": curve,
"x": nopad_b64(pub_hex[:num_bytes]),
"y": nopad_b64(pub_hex[num_bytes:]),
},
'hash': hashalg,
'point_size': point_size,
"hash": hashalg,
"point_size": point_size,
}
def sign(self, payload64, protected64, key_data):
sign_payload = "{0}.{1}".format(protected64, payload64).encode('utf8')
if key_data['type'] == 'hmac':
hex_key = to_native(binascii.hexlify(base64.urlsafe_b64decode(key_data['jwk']['k'])))
cmd_postfix = ["-mac", "hmac", "-macopt", "hexkey:{0}".format(hex_key), "-binary"]
sign_payload = "{0}.{1}".format(protected64, payload64).encode("utf8")
if key_data["type"] == "hmac":
hex_key = to_native(
binascii.hexlify(base64.urlsafe_b64decode(key_data["jwk"]["k"]))
)
cmd_postfix = [
"-mac",
"hmac",
"-macopt",
"hexkey:{0}".format(hex_key),
"-binary",
]
else:
cmd_postfix = ["-sign", key_data['key_file']]
openssl_sign_cmd = [self.openssl_binary, "dgst", "-{0}".format(key_data['hash'])] + cmd_postfix
cmd_postfix = ["-sign", key_data["key_file"]]
openssl_sign_cmd = [
self.openssl_binary,
"dgst",
"-{0}".format(key_data["hash"]),
] + cmd_postfix
rc, out, err = self.module.run_command(
openssl_sign_cmd, data=sign_payload, check_rc=False, binary_data=True, environ_update=_OPENSSL_ENVIRONMENT_UPDATE)
openssl_sign_cmd,
data=sign_payload,
check_rc=False,
binary_data=True,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
if rc != 0:
raise BackendException('Error while running {cmd}: {stderr}'.format(cmd=' '.join(openssl_sign_cmd), stderr=to_text(err)))
raise BackendException(
"Error while running {cmd}: {stderr}".format(
cmd=" ".join(openssl_sign_cmd), stderr=to_text(err)
)
)
if key_data['type'] == 'ec':
if key_data["type"] == "ec":
dummy, der_out, dummy = self.module.run_command(
[self.openssl_binary, "asn1parse", "-inform", "DER"],
data=out, binary_data=True, environ_update=_OPENSSL_ENVIRONMENT_UPDATE)
expected_len = 2 * key_data['point_size']
data=out,
binary_data=True,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
expected_len = 2 * key_data["point_size"]
sig = re.findall(
r"prim:\s+INTEGER\s+:([0-9A-F]{1,%s})\n" % expected_len,
to_text(der_out, errors='surrogate_or_strict'))
to_text(der_out, errors="surrogate_or_strict"),
)
if len(sig) != 2:
raise BackendException(
"failed to generate Elliptic Curve signature; cannot parse DER output: {0}".format(
to_text(der_out, errors='surrogate_or_strict')))
sig[0] = (expected_len - len(sig[0])) * '0' + sig[0]
sig[1] = (expected_len - len(sig[1])) * '0' + sig[1]
to_text(der_out, errors="surrogate_or_strict")
)
)
sig[0] = (expected_len - len(sig[0])) * "0" + sig[0]
sig[1] = (expected_len - len(sig[1])) * "0" + sig[1]
out = binascii.unhexlify(sig[0]) + binascii.unhexlify(sig[1])
return {
@@ -239,30 +305,35 @@ class OpenSSLCLIBackend(CryptoBackend):
}
def create_mac_key(self, alg, key):
'''Create a MAC key.'''
if alg == 'HS256':
hashalg = 'sha256'
"""Create a MAC key."""
if alg == "HS256":
hashalg = "sha256"
hashbytes = 32
elif alg == 'HS384':
hashalg = 'sha384'
elif alg == "HS384":
hashalg = "sha384"
hashbytes = 48
elif alg == 'HS512':
hashalg = 'sha512'
elif alg == "HS512":
hashalg = "sha512"
hashbytes = 64
else:
raise BackendException('Unsupported MAC key algorithm for OpenSSL backend: {0}'.format(alg))
raise BackendException(
"Unsupported MAC key algorithm for OpenSSL backend: {0}".format(alg)
)
key_bytes = base64.urlsafe_b64decode(key)
if len(key_bytes) < hashbytes:
raise BackendException(
'{0} key must be at least {1} bytes long (after Base64 decoding)'.format(alg, hashbytes))
"{0} key must be at least {1} bytes long (after Base64 decoding)".format(
alg, hashbytes
)
)
return {
'type': 'hmac',
'alg': alg,
'jwk': {
'kty': 'oct',
'k': key,
"type": "hmac",
"alg": alg,
"jwk": {
"kty": "oct",
"k": key,
},
'hash': hashalg,
"hash": hashalg,
}
@staticmethod
@@ -274,25 +345,41 @@ class OpenSSLCLIBackend(CryptoBackend):
return ip
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
'''
"""
Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
'dns' or 'ip'.
The list is deduplicated, and if a CNAME is present, it will be returned
as the first element in the result.
'''
"""
filename = csr_filename
data = None
if csr_content is not None:
filename = '/dev/stdin'
data = csr_content.encode('utf-8')
filename = "/dev/stdin"
data = csr_content.encode("utf-8")
openssl_csr_cmd = [self.openssl_binary, "req", "-in", filename, "-noout", "-text"]
openssl_csr_cmd = [
self.openssl_binary,
"req",
"-in",
filename,
"-noout",
"-text",
]
rc, out, err = self.module.run_command(
openssl_csr_cmd, data=data, check_rc=False, binary_data=True, environ_update=_OPENSSL_ENVIRONMENT_UPDATE)
openssl_csr_cmd,
data=data,
check_rc=False,
binary_data=True,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
if rc != 0:
raise BackendException('Error while running {cmd}: {stderr}'.format(cmd=' '.join(openssl_csr_cmd), stderr=to_text(err)))
raise BackendException(
"Error while running {cmd}: {stderr}".format(
cmd=" ".join(openssl_csr_cmd), stderr=to_text(err)
)
)
identifiers = set()
result = []
@@ -303,61 +390,90 @@ class OpenSSLCLIBackend(CryptoBackend):
identifiers.add(identifier)
result.append(identifier)
common_name = re.search(r"Subject:.* CN\s?=\s?([^\s,;/]+)", to_text(out, errors='surrogate_or_strict'))
common_name = re.search(
r"Subject:.* CN\s?=\s?([^\s,;/]+)",
to_text(out, errors="surrogate_or_strict"),
)
if common_name is not None:
add_identifier(('dns', common_name.group(1)))
add_identifier(("dns", common_name.group(1)))
subject_alt_names = re.search(
r"X509v3 Subject Alternative Name: (?:critical)?\n +([^\n]+)\n",
to_text(out, errors='surrogate_or_strict'), re.MULTILINE | re.DOTALL)
to_text(out, errors="surrogate_or_strict"),
re.MULTILINE | re.DOTALL,
)
if subject_alt_names is not None:
for san in subject_alt_names.group(1).split(", "):
if san.lower().startswith("dns:"):
add_identifier(('dns', san[4:]))
add_identifier(("dns", san[4:]))
elif san.lower().startswith("ip:"):
add_identifier(('ip', self._normalize_ip(san[3:])))
add_identifier(("ip", self._normalize_ip(san[3:])))
elif san.lower().startswith("ip address:"):
add_identifier(('ip', self._normalize_ip(san[11:])))
add_identifier(("ip", self._normalize_ip(san[11:])))
else:
raise BackendException('Found unsupported SAN identifier "{0}"'.format(san))
raise BackendException(
'Found unsupported SAN identifier "{0}"'.format(san)
)
return result
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
'''
"""
Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
'dns' or 'ip'.
'''
return set(self.get_ordered_csr_identifiers(csr_filename=csr_filename, csr_content=csr_content))
"""
return set(
self.get_ordered_csr_identifiers(
csr_filename=csr_filename, csr_content=csr_content
)
)
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
'''
"""
Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one
certificate, only the first one will be considered.
If now is not specified, datetime.datetime.now() is used.
'''
"""
filename = cert_filename
data = None
if cert_content is not None:
filename = '/dev/stdin'
data = cert_content.encode('utf-8')
cert_filename_suffix = ''
filename = "/dev/stdin"
data = cert_content.encode("utf-8")
cert_filename_suffix = ""
elif cert_filename is not None:
if not os.path.exists(cert_filename):
return -1
cert_filename_suffix = ' in {0}'.format(cert_filename)
cert_filename_suffix = " in {0}".format(cert_filename)
else:
return -1
openssl_cert_cmd = [self.openssl_binary, "x509", "-in", filename, "-noout", "-text"]
openssl_cert_cmd = [
self.openssl_binary,
"x509",
"-in",
filename,
"-noout",
"-text",
]
rc, out, err = self.module.run_command(
openssl_cert_cmd, data=data, check_rc=False, binary_data=True, environ_update=_OPENSSL_ENVIRONMENT_UPDATE)
openssl_cert_cmd,
data=data,
check_rc=False,
binary_data=True,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
if rc != 0:
raise BackendException('Error while running {cmd}: {stderr}'.format(cmd=' '.join(openssl_cert_cmd), stderr=to_text(err)))
raise BackendException(
"Error while running {cmd}: {stderr}".format(
cmd=" ".join(openssl_cert_cmd), stderr=to_text(err)
)
)
out_text = to_text(out, errors='surrogate_or_strict')
not_after = _extract_date(out_text, 'Not After', cert_filename_suffix=cert_filename_suffix)
out_text = to_text(out, errors="surrogate_or_strict")
not_after = _extract_date(
out_text, "Not After", cert_filename_suffix=cert_filename_suffix
)
if now is None:
now = self.get_now()
else:
@@ -365,45 +481,76 @@ class OpenSSLCLIBackend(CryptoBackend):
return (not_after - now).days
def create_chain_matcher(self, criterium):
'''
"""
Given a Criterium object, creates a ChainMatcher object.
'''
raise BackendException('Alternate chain matching can only be used with the "cryptography" backend.')
"""
raise BackendException(
'Alternate chain matching can only be used with the "cryptography" backend.'
)
def get_cert_information(self, cert_filename=None, cert_content=None):
'''
"""
Return some information on a X.509 certificate as a CertificateInformation object.
'''
"""
filename = cert_filename
data = None
if cert_filename is not None:
cert_filename_suffix = ' in {0}'.format(cert_filename)
cert_filename_suffix = " in {0}".format(cert_filename)
else:
filename = '/dev/stdin'
filename = "/dev/stdin"
data = to_bytes(cert_content)
cert_filename_suffix = ''
cert_filename_suffix = ""
openssl_cert_cmd = [self.openssl_binary, "x509", "-in", filename, "-noout", "-text"]
openssl_cert_cmd = [
self.openssl_binary,
"x509",
"-in",
filename,
"-noout",
"-text",
]
rc, out, err = self.module.run_command(
openssl_cert_cmd, data=data, check_rc=False, binary_data=True, environ_update=_OPENSSL_ENVIRONMENT_UPDATE)
openssl_cert_cmd,
data=data,
check_rc=False,
binary_data=True,
environ_update=_OPENSSL_ENVIRONMENT_UPDATE,
)
if rc != 0:
raise BackendException('Error while running {cmd}: {stderr}'.format(cmd=' '.join(openssl_cert_cmd), stderr=to_text(err)))
raise BackendException(
"Error while running {cmd}: {stderr}".format(
cmd=" ".join(openssl_cert_cmd), stderr=to_text(err)
)
)
out_text = to_text(out, errors='surrogate_or_strict')
out_text = to_text(out, errors="surrogate_or_strict")
not_after = _extract_date(out_text, 'Not After', cert_filename_suffix=cert_filename_suffix)
not_before = _extract_date(out_text, 'Not Before', cert_filename_suffix=cert_filename_suffix)
not_after = _extract_date(
out_text, "Not After", cert_filename_suffix=cert_filename_suffix
)
not_before = _extract_date(
out_text, "Not Before", cert_filename_suffix=cert_filename_suffix
)
sn = re.search(
r" Serial Number: ([0-9]+)",
to_text(out, errors='surrogate_or_strict'), re.MULTILINE | re.DOTALL)
to_text(out, errors="surrogate_or_strict"),
re.MULTILINE | re.DOTALL,
)
if sn:
serial = int(sn.group(1))
else:
serial = convert_bytes_to_int(_extract_octets(out_text, 'Serial Number', required=True))
serial = convert_bytes_to_int(
_extract_octets(out_text, "Serial Number", required=True)
)
ski = _extract_octets(out_text, 'X509v3 Subject Key Identifier', required=False)
aki = _extract_octets(out_text, 'X509v3 Authority Key Identifier', required=False, potential_prefixes=['keyid:', ''])
ski = _extract_octets(out_text, "X509v3 Subject Key Identifier", required=False)
aki = _extract_octets(
out_text,
"X509v3 Authority Key Identifier",
required=False,
potential_prefixes=["keyid:", ""],
)
return CertificateInformation(
not_valid_after=not_after,

View File

@@ -36,18 +36,20 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
CertificateInformation = namedtuple(
'CertificateInformation',
"CertificateInformation",
(
'not_valid_after',
'not_valid_before',
'serial_number',
'subject_key_identifier',
'authority_key_identifier',
"not_valid_after",
"not_valid_before",
"serial_number",
"subject_key_identifier",
"authority_key_identifier",
),
)
_FRACTIONAL_MATCHER = re.compile(r'^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(|\.\d+)(Z|[+-]\d{2}:?\d{2}.*)$')
_FRACTIONAL_MATCHER = re.compile(
r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(|\.\d+)(Z|[+-]\d{2}:?\d{2}.*)$"
)
def _reduce_fractional_digits(timestamp_str):
@@ -57,13 +59,15 @@ def _reduce_fractional_digits(timestamp_str):
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
m = _FRACTIONAL_MATCHER.match(timestamp_str)
if not m:
raise BackendException('Cannot parse ISO 8601 timestamp {0!r}'.format(timestamp_str))
raise BackendException(
"Cannot parse ISO 8601 timestamp {0!r}".format(timestamp_str)
)
timestamp, fractional, timezone = m.groups()
if len(fractional) > 7:
# Python does not support anything smaller than microseconds
# (Golang supports nanoseconds, Boulder often emits more fractional digits, which Python chokes on)
fractional = fractional[:7]
return '%s%s%s' % (timestamp, fractional, timezone)
return "%s%s%s" % (timestamp, fractional, timezone)
def _parse_acme_timestamp(timestamp_str, with_timezone):
@@ -72,15 +76,26 @@ def _parse_acme_timestamp(timestamp_str, with_timezone):
"""
# RFC 3339 (https://www.rfc-editor.org/info/rfc3339)
timestamp_str = _reduce_fractional_digits(timestamp_str)
for format in ('%Y-%m-%dT%H:%M:%SZ', '%Y-%m-%dT%H:%M:%S.%fZ', '%Y-%m-%dT%H:%M:%S%z', '%Y-%m-%dT%H:%M:%S.%f%z'):
for format in (
"%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%fZ",
"%Y-%m-%dT%H:%M:%S%z",
"%Y-%m-%dT%H:%M:%S.%f%z",
):
# Note that %z will not work with Python 2... https://stackoverflow.com/a/27829491
try:
result = datetime.datetime.strptime(timestamp_str, format)
except ValueError:
pass
else:
return ensure_utc_timezone(result) if with_timezone else remove_timezone(result)
raise BackendException('Cannot parse ISO 8601 timestamp {0!r}'.format(timestamp_str))
return (
ensure_utc_timezone(result)
if with_timezone
else remove_timezone(result)
)
raise BackendException(
"Cannot parse ISO 8601 timestamp {0!r}".format(timestamp_str)
)
@six.add_metaclass(abc.ABCMeta)
@@ -98,30 +113,34 @@ class CryptoBackend(object):
def parse_module_parameter(self, value, name):
try:
return get_relative_time_option(value, name, backend='cryptography', with_timezone=self._with_timezone)
return get_relative_time_option(
value, name, backend="cryptography", with_timezone=self._with_timezone
)
except OpenSSLObjectError as exc:
raise BackendException(to_native(exc))
def interpolate_timestamp(self, timestamp_start, timestamp_end, percentage):
start = get_epoch_seconds(timestamp_start)
end = get_epoch_seconds(timestamp_end)
return from_epoch_seconds(start + percentage * (end - start), with_timezone=self._with_timezone)
return from_epoch_seconds(
start + percentage * (end - start), with_timezone=self._with_timezone
)
def get_utc_datetime(self, *args, **kwargs):
kwargs_ext = dict(kwargs)
if self._with_timezone and ('tzinfo' not in kwargs_ext and len(args) < 8):
kwargs_ext['tzinfo'] = UTC
if self._with_timezone and ("tzinfo" not in kwargs_ext and len(args) < 8):
kwargs_ext["tzinfo"] = UTC
result = datetime.datetime(*args, **kwargs_ext)
if self._with_timezone and ('tzinfo' in kwargs or len(args) >= 8):
if self._with_timezone and ("tzinfo" in kwargs or len(args) >= 8):
result = ensure_utc_timezone(result)
return result
@abc.abstractmethod
def parse_key(self, key_file=None, key_content=None, passphrase=None):
'''
"""
Parses an RSA or Elliptic Curve key file in PEM format and returns key_data.
Raises KeyParsingError in case of errors.
'''
"""
@abc.abstractmethod
def sign(self, payload64, protected64, key_data):
@@ -129,54 +148,56 @@ class CryptoBackend(object):
@abc.abstractmethod
def create_mac_key(self, alg, key):
'''Create a MAC key.'''
"""Create a MAC key."""
def get_ordered_csr_identifiers(self, csr_filename=None, csr_content=None):
'''
"""
Return a list of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
'dns' or 'ip'.
The list is deduplicated, and if a CNAME is present, it will be returned
as the first element in the result.
'''
"""
self.module.deprecate(
"Every backend must override the get_ordered_csr_identifiers() method."
" The default implementation will be removed in 3.0.0 and this method will be marked as `abstractmethod` by then.",
version='3.0.0',
collection_name='community.crypto',
version="3.0.0",
collection_name="community.crypto",
)
return sorted(
self.get_csr_identifiers(csr_filename=csr_filename, csr_content=csr_content)
)
return sorted(self.get_csr_identifiers(csr_filename=csr_filename, csr_content=csr_content))
@abc.abstractmethod
def get_csr_identifiers(self, csr_filename=None, csr_content=None):
'''
"""
Return a set of requested identifiers (CN and SANs) for the CSR.
Each identifier is a pair (type, identifier), where type is either
'dns' or 'ip'.
'''
"""
@abc.abstractmethod
def get_cert_days(self, cert_filename=None, cert_content=None, now=None):
'''
"""
Return the days the certificate in cert_filename remains valid and -1
if the file was not found. If cert_filename contains more than one
certificate, only the first one will be considered.
If now is not specified, datetime.datetime.now() is used.
'''
"""
@abc.abstractmethod
def create_chain_matcher(self, criterium):
'''
"""
Given a Criterium object, creates a ChainMatcher object.
'''
"""
def get_cert_information(self, cert_filename=None, cert_content=None):
'''
"""
Return some information on a X.509 certificate as a CertificateInformation object.
'''
"""
# Not implementing this method in a backend is DEPRECATED and will be
# disallowed in community.crypto 3.0.0. This method will be marked as
# @abstractmethod by then.
raise BackendException('This backend does not support get_cert_information()')
raise BackendException("This backend does not support get_cert_information()")

View File

@@ -37,38 +37,44 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.utils import
class ACMECertificateClient(object):
'''
"""
ACME v2 client class. Uses an ACME account object and a CSR to
start and validate ACME challenges and download the respective
certificates.
'''
"""
def __init__(self, module, backend, client=None, account=None):
self.module = module
self.version = module.params['acme_version']
self.csr = module.params.get('csr')
self.csr_content = module.params.get('csr_content')
self.version = module.params["acme_version"]
self.csr = module.params.get("csr")
self.csr_content = module.params.get("csr_content")
if client is None:
client = ACMEClient(module, backend)
self.client = client
if account is None:
account = ACMEAccount(self.client)
self.account = account
self.order_uri = module.params.get('order_uri')
self.order_creation_error_strategy = module.params.get('order_creation_error_strategy', 'auto')
self.order_creation_max_retries = module.params.get('order_creation_max_retries', 3)
self.order_uri = module.params.get("order_uri")
self.order_creation_error_strategy = module.params.get(
"order_creation_error_strategy", "auto"
)
self.order_creation_max_retries = module.params.get(
"order_creation_max_retries", 3
)
# Make sure account exists
dummy, account_data = self.account.setup_account(allow_creation=False)
if account_data is None:
raise ModuleFailException(msg='Account does not exist or is deactivated.')
raise ModuleFailException(msg="Account does not exist or is deactivated.")
if self.csr is not None and not os.path.exists(self.csr):
raise ModuleFailException("CSR %s not found" % (self.csr))
# Extract list of identifiers from CSR
if self.csr is not None or self.csr_content is not None:
self.identifiers = self.client.backend.get_ordered_csr_identifiers(csr_filename=self.csr, csr_content=self.csr_content)
self.identifiers = self.client.backend.get_ordered_csr_identifiers(
csr_filename=self.csr, csr_content=self.csr_content
)
else:
self.identifiers = None
@@ -78,24 +84,31 @@ class ACMECertificateClient(object):
for criterium_idx, criterium in enumerate(select_chain):
try:
select_chain_matcher.append(
self.client.backend.create_chain_matcher(Criterium(criterium, index=criterium_idx)))
self.client.backend.create_chain_matcher(
Criterium(criterium, index=criterium_idx)
)
)
except ValueError as exc:
self.module.warn('Error while parsing criterium: {error}. Ignoring criterium.'.format(error=exc))
self.module.warn(
"Error while parsing criterium: {error}. Ignoring criterium.".format(
error=exc
)
)
return select_chain_matcher
def load_order(self):
if not self.order_uri:
raise ModuleFailException('The order URI has not been provided')
raise ModuleFailException("The order URI has not been provided")
order = Order.from_url(self.client, self.order_uri)
order.load_authorizations(self.client)
return order
def create_order(self, replaces_cert_id=None, profile=None):
'''
"""
Create a new order.
'''
"""
if self.identifiers is None:
raise ModuleFailException('No identifiers have been provided')
raise ModuleFailException("No identifiers have been provided")
order = Order.create_with_error_handling(
self.client,
self.identifiers,
@@ -110,64 +123,78 @@ class ACMECertificateClient(object):
return order
def get_challenges_data(self, order):
'''
"""
Get challenge details.
Return a tuple of generic challenge details, and specialized DNS challenge details.
'''
"""
# Get general challenge data
data = []
for authz in order.authorizations.values():
# Skip valid authentications: their challenges are already valid
# and do not need to be returned
if authz.status == 'valid':
if authz.status == "valid":
continue
data.append(dict(
identifier=authz.identifier,
identifier_type=authz.identifier_type,
challenges=authz.get_challenge_data(self.client),
))
data.append(
dict(
identifier=authz.identifier,
identifier_type=authz.identifier_type,
challenges=authz.get_challenge_data(self.client),
)
)
# Get DNS challenge data
data_dns = {}
dns_challenge_type = 'dns-01'
dns_challenge_type = "dns-01"
for entry in data:
dns_challenge = entry['challenges'].get(dns_challenge_type)
dns_challenge = entry["challenges"].get(dns_challenge_type)
if dns_challenge:
values = data_dns.get(dns_challenge['record'])
values = data_dns.get(dns_challenge["record"])
if values is None:
values = []
data_dns[dns_challenge['record']] = values
values.append(dns_challenge['resource_value'])
data_dns[dns_challenge["record"]] = values
values.append(dns_challenge["resource_value"])
return data, data_dns
def check_that_authorizations_can_be_used(self, order):
bad_authzs = []
for authz in order.authorizations.values():
if authz.status not in ('valid', 'pending'):
bad_authzs.append('{authz} (status={status!r})'.format(
authz=authz.combined_identifier,
status=authz.status,
))
if authz.status not in ("valid", "pending"):
bad_authzs.append(
"{authz} (status={status!r})".format(
authz=authz.combined_identifier,
status=authz.status,
)
)
if bad_authzs:
raise ModuleFailException(
'Some of the authorizations for the order are in a bad state, so the order'
' can no longer be satisfied: {bad_authzs}'.format(
bad_authzs=', '.join(sorted(bad_authzs)),
"Some of the authorizations for the order are in a bad state, so the order"
" can no longer be satisfied: {bad_authzs}".format(
bad_authzs=", ".join(sorted(bad_authzs)),
),
)
def collect_invalid_authzs(self, order):
return [authz for authz in order.authorizations.values() if authz.status == 'invalid']
return [
authz
for authz in order.authorizations.values()
if authz.status == "invalid"
]
def collect_pending_authzs(self, order):
return [authz for authz in order.authorizations.values() if authz.status == 'pending']
return [
authz
for authz in order.authorizations.values()
if authz.status == "pending"
]
def call_validate(self, pending_authzs, get_challenge, wait=True):
authzs_with_challenges_to_wait_for = []
for authz in pending_authzs:
challenge_type = get_challenge(authz)
authz.call_validate(self.client, challenge_type, wait=wait)
authzs_with_challenges_to_wait_for.append((authz, challenge_type, authz.find_challenge(challenge_type)))
authzs_with_challenges_to_wait_for.append(
(authz, challenge_type, authz.find_challenge(challenge_type))
)
return authzs_with_challenges_to_wait_for
def wait_for_validation(self, authzs_to_wait_for):
@@ -179,27 +206,45 @@ class ACMECertificateClient(object):
try:
alt_cert = CertificateChain.download(self.client, alternate)
except ModuleFailException as e:
self.module.warn('Error while downloading alternative certificate {0}: {1}'.format(alternate, e))
self.module.warn(
"Error while downloading alternative certificate {0}: {1}".format(
alternate, e
)
)
continue
if alt_cert.cert is not None:
alternate_chains.append(alt_cert)
else:
self.module.warn('Error while downloading alternative certificate {0}: no certificate found'.format(alternate))
self.module.warn(
"Error while downloading alternative certificate {0}: no certificate found".format(
alternate
)
)
return alternate_chains
def download_certificate(self, order, download_all_chains=True):
'''
"""
Download certificate from a valid oder.
'''
if order.status != 'valid':
raise ModuleFailException('The order must be valid, but has state {state!r}!'.format(state=order.state))
"""
if order.status != "valid":
raise ModuleFailException(
"The order must be valid, but has state {state!r}!".format(
state=order.state
)
)
if not order.certificate_uri:
raise ModuleFailException("Order's crtificate URL {url!r} is empty!".format(url=order.certificate_uri))
raise ModuleFailException(
"Order's crtificate URL {url!r} is empty!".format(
url=order.certificate_uri
)
)
cert = CertificateChain.download(self.client, order.certificate_uri)
if cert.cert is None:
raise ModuleFailException('Certificate at {url} is empty!'.format(url=order.certificate_uri))
raise ModuleFailException(
"Certificate at {url} is empty!".format(url=order.certificate_uri)
)
alternate_chains = None
if download_all_chains:
@@ -208,15 +253,18 @@ class ACMECertificateClient(object):
return cert, alternate_chains
def get_certificate(self, order, download_all_chains=True):
'''
"""
Request a new certificate and downloads it, and optionally all certificate chains.
First verifies whether all authorizations are valid; if not, aborts with an error.
'''
"""
if self.csr is None and self.csr_content is None:
raise ModuleFailException('No CSR has been provided')
raise ModuleFailException("No CSR has been provided")
for identifier, authz in order.authorizations.items():
if authz.status != 'valid':
authz.raise_error('Status is {status!r} and not "valid"'.format(status=authz.status), module=self.module)
if authz.status != "valid":
authz.raise_error(
'Status is {status!r} and not "valid"'.format(status=authz.status),
module=self.module,
)
order.finalize(self.client, pem_to_der(self.csr, self.csr_content))
@@ -226,30 +274,40 @@ class ACMECertificateClient(object):
for criterium_idx, matcher in enumerate(select_chain_matcher):
for chain in chains:
if matcher.match(chain):
self.module.debug('Found matching chain for criterium {0}'.format(criterium_idx))
self.module.debug(
"Found matching chain for criterium {0}".format(criterium_idx)
)
return chain
return None
def write_cert_chain(self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None):
def write_cert_chain(
self, cert, cert_dest=None, fullchain_dest=None, chain_dest=None
):
changed = False
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode('utf8')):
if cert_dest and write_file(self.module, cert_dest, cert.cert.encode("utf8")):
changed = True
if fullchain_dest and write_file(self.module, fullchain_dest, (cert.cert + "\n".join(cert.chain)).encode('utf8')):
if fullchain_dest and write_file(
self.module,
fullchain_dest,
(cert.cert + "\n".join(cert.chain)).encode("utf8"),
):
changed = True
if chain_dest and write_file(self.module, chain_dest, ("\n".join(cert.chain)).encode('utf8')):
if chain_dest and write_file(
self.module, chain_dest, ("\n".join(cert.chain)).encode("utf8")
):
changed = True
return changed
def deactivate_authzs(self, order):
'''
"""
Deactivates all valid authz's. Does not raise exceptions.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
'''
"""
if len(order.authorization_uris) > len(order.authorizations):
for authz_uri in order.authorization_uris:
authz = None
@@ -258,8 +316,12 @@ class ACMECertificateClient(object):
except Exception:
# ignore errors
pass
if authz is None or authz.status != 'deactivated':
self.module.warn(warning='Could not deactivate authz object {0}.'.format(authz_uri))
if authz is None or authz.status != "deactivated":
self.module.warn(
warning="Could not deactivate authz object {0}.".format(
authz_uri
)
)
else:
for authz in order.authorizations.values():
try:
@@ -267,5 +329,9 @@ class ACMECertificateClient(object):
except Exception:
# ignore errors
pass
if authz.status != 'deactivated':
self.module.warn(warning='Could not deactivate authz object {0}.'.format(authz.url))
if authz.status != "deactivated":
self.module.warn(
warning="Could not deactivate authz object {0}.".format(
authz.url
)
)

View File

@@ -28,10 +28,10 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.pem import
class CertificateChain(object):
'''
"""
Download and parse the certificate chain.
https://tools.ietf.org/html/rfc8555#section-7.4.2
'''
"""
def __init__(self, url):
self.url = url
@@ -41,86 +41,106 @@ class CertificateChain(object):
@classmethod
def download(cls, client, url):
content, info = client.get_request(url, parse_json_result=False, headers={'Accept': 'application/pem-certificate-chain'})
content, info = client.get_request(
url,
parse_json_result=False,
headers={"Accept": "application/pem-certificate-chain"},
)
if not content or not info['content-type'].startswith('application/pem-certificate-chain'):
if not content or not info["content-type"].startswith(
"application/pem-certificate-chain"
):
raise ModuleFailException(
"Cannot download certificate chain from {0}, as content type is not application/pem-certificate-chain: {1} (headers: {2})".format(
url, content, info))
url, content, info
)
)
result = cls(url)
# Parse data
certs = split_pem_list(content.decode('utf-8'), keep_inbetween=True)
certs = split_pem_list(content.decode("utf-8"), keep_inbetween=True)
if certs:
result.cert = certs[0]
result.chain = certs[1:]
process_links(info, lambda link, relation: result._process_links(client, link, relation))
process_links(
info, lambda link, relation: result._process_links(client, link, relation)
)
if result.cert is None:
raise ModuleFailException("Failed to parse certificate chain download from {0}: {1} (headers: {2})".format(url, content, info))
raise ModuleFailException(
"Failed to parse certificate chain download from {0}: {1} (headers: {2})".format(
url, content, info
)
)
return result
def _process_links(self, client, link, relation):
if relation == 'up':
if relation == "up":
# Process link-up headers if there was no chain in reply
if not self.chain:
chain_result, chain_info = client.get_request(link, parse_json_result=False)
if chain_info['status'] in [200, 201]:
chain_result, chain_info = client.get_request(
link, parse_json_result=False
)
if chain_info["status"] in [200, 201]:
self.chain.append(der_to_pem(chain_result))
elif relation == 'alternate':
elif relation == "alternate":
self.alternates.append(link)
def to_json(self):
cert = self.cert.encode('utf8')
chain = ('\n'.join(self.chain)).encode('utf8')
cert = self.cert.encode("utf8")
chain = ("\n".join(self.chain)).encode("utf8")
return {
'cert': cert,
'chain': chain,
'full_chain': cert + chain,
"cert": cert,
"chain": chain,
"full_chain": cert + chain,
}
class Criterium(object):
def __init__(self, criterium, index=None):
self.index = index
self.test_certificates = criterium['test_certificates']
self.subject = criterium['subject']
self.issuer = criterium['issuer']
self.subject_key_identifier = criterium['subject_key_identifier']
self.authority_key_identifier = criterium['authority_key_identifier']
self.test_certificates = criterium["test_certificates"]
self.subject = criterium["subject"]
self.issuer = criterium["issuer"]
self.subject_key_identifier = criterium["subject_key_identifier"]
self.authority_key_identifier = criterium["authority_key_identifier"]
@six.add_metaclass(abc.ABCMeta)
class ChainMatcher(object):
@abc.abstractmethod
def match(self, certificate):
'''
"""
Check whether a certificate chain (CertificateChain instance) matches.
'''
"""
def retrieve_acme_v1_certificate(client, csr_der):
'''
"""
Create a new certificate based on the CSR (ACME v1 protocol).
Return the certificate object as dict
https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-6.5
'''
"""
new_cert = {
"resource": "new-cert",
"csr": nopad_b64(csr_der),
}
result, info = client.send_signed_request(
client.directory['new-cert'], new_cert, error_msg='Failed to receive certificate', expected_status_codes=[200, 201])
cert = CertificateChain(info['location'])
client.directory["new-cert"],
new_cert,
error_msg="Failed to receive certificate",
expected_status_codes=[200, 201],
)
cert = CertificateChain(info["location"])
cert.cert = der_to_pem(result)
def f(link, relation):
if relation == 'up':
if relation == "up":
chain_result, chain_info = client.get_request(link, parse_json_result=False)
if chain_info['status'] in [200, 201]:
if chain_info["status"] in [200, 201]:
del cert.chain[:]
cert.chain.append(der_to_pem(chain_result))

View File

@@ -35,17 +35,19 @@ except ImportError:
def create_key_authorization(client, token):
'''
"""
Returns the key authorization for the given token
https://tools.ietf.org/html/rfc8555#section-8.1
'''
accountkey_json = json.dumps(client.account_jwk, sort_keys=True, separators=(',', ':'))
thumbprint = nopad_b64(hashlib.sha256(accountkey_json.encode('utf8')).digest())
"""
accountkey_json = json.dumps(
client.account_jwk, sort_keys=True, separators=(",", ":")
)
thumbprint = nopad_b64(hashlib.sha256(accountkey_json.encode("utf8")).digest())
return "{0}.{1}".format(token, thumbprint)
def combine_identifier(identifier_type, identifier):
return '{type}:{identifier}'.format(type=identifier_type, identifier=identifier)
return "{type}:{identifier}".format(type=identifier_type, identifier=identifier)
def normalize_combined_identifier(identifier):
@@ -56,10 +58,13 @@ def normalize_combined_identifier(identifier):
def split_identifier(identifier):
parts = identifier.split(':', 1)
parts = identifier.split(":", 1)
if len(parts) != 2:
raise ModuleFailException(
'Identifier "{identifier}" is not of the form <type>:<identifier>'.format(identifier=identifier))
'Identifier "{identifier}" is not of the form <type>:<identifier>'.format(
identifier=identifier
)
)
return parts
@@ -67,27 +72,27 @@ class Challenge(object):
def __init__(self, data, url):
self.data = data
self.type = data['type']
self.type = data["type"]
self.url = url
self.status = data['status']
self.token = data.get('token')
self.status = data["status"]
self.token = data.get("token")
@classmethod
def from_json(cls, client, data, url=None):
return cls(data, url or (data['uri'] if client.version == 1 else data['url']))
return cls(data, url or (data["uri"] if client.version == 1 else data["url"]))
def call_validate(self, client):
challenge_response = {}
if client.version == 1:
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
key_authorization = create_key_authorization(client, token)
challenge_response['resource'] = 'challenge'
challenge_response['keyAuthorization'] = key_authorization
challenge_response['type'] = self.type
challenge_response["resource"] = "challenge"
challenge_response["keyAuthorization"] = key_authorization
challenge_response["type"] = self.type
client.send_signed_request(
self.url,
challenge_response,
error_msg='Failed to validate challenge',
error_msg="Failed to validate challenge",
expected_status_codes=[200, 202],
)
@@ -98,40 +103,44 @@ class Challenge(object):
token = re.sub(r"[^A-Za-z0-9_\-]", "_", self.token)
key_authorization = create_key_authorization(client, token)
if self.type == 'http-01':
if self.type == "http-01":
# https://tools.ietf.org/html/rfc8555#section-8.3
return {
'resource': '.well-known/acme-challenge/{token}'.format(token=token),
'resource_value': key_authorization,
"resource": ".well-known/acme-challenge/{token}".format(token=token),
"resource_value": key_authorization,
}
if self.type == 'dns-01':
if identifier_type != 'dns':
if self.type == "dns-01":
if identifier_type != "dns":
return None
# https://tools.ietf.org/html/rfc8555#section-8.4
resource = '_acme-challenge'
resource = "_acme-challenge"
value = nopad_b64(hashlib.sha256(to_bytes(key_authorization)).digest())
record = '{0}.{1}'.format(resource, identifier[2:] if identifier.startswith('*.') else identifier)
record = "{0}.{1}".format(
resource, identifier[2:] if identifier.startswith("*.") else identifier
)
return {
'resource': resource,
'resource_value': value,
'record': record,
"resource": resource,
"resource_value": value,
"record": record,
}
if self.type == 'tls-alpn-01':
if self.type == "tls-alpn-01":
# https://www.rfc-editor.org/rfc/rfc8737.html#section-3
if identifier_type == 'ip':
if identifier_type == "ip":
# IPv4/IPv6 address: use reverse mapping (RFC1034, RFC3596)
resource = ipaddress.ip_address(identifier).reverse_pointer
if not resource.endswith('.'):
resource += '.'
if not resource.endswith("."):
resource += "."
else:
resource = identifier
value = base64.b64encode(hashlib.sha256(to_bytes(key_authorization)).digest())
value = base64.b64encode(
hashlib.sha256(to_bytes(key_authorization)).digest()
)
return {
'resource': resource,
'resource_original': combine_identifier(identifier_type, identifier),
'resource_value': value,
"resource": resource,
"resource_original": combine_identifier(identifier_type, identifier),
"resource_value": value,
}
# Unknown challenge type: ignore
@@ -140,25 +149,28 @@ class Challenge(object):
class Authorization(object):
def _setup(self, client, data):
data['uri'] = self.url
data["uri"] = self.url
self.data = data
# While 'challenges' is a required field, apparently not every CA cares
# (https://github.com/ansible-collections/community.crypto/issues/824)
if data.get('challenges'):
self.challenges = [Challenge.from_json(client, challenge) for challenge in data['challenges']]
if data.get("challenges"):
self.challenges = [
Challenge.from_json(client, challenge)
for challenge in data["challenges"]
]
else:
self.challenges = []
if client.version == 1 and 'status' not in data:
if client.version == 1 and "status" not in data:
# https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-6.1.2
# "status (required, string): ...
# If this field is missing, then the default value is "pending"."
self.status = 'pending'
self.status = "pending"
else:
self.status = data['status']
self.identifier = data['identifier']['value']
self.identifier_type = data['identifier']['type']
if data.get('wildcard', False):
self.identifier = '*.{0}'.format(self.identifier)
self.status = data["status"]
self.identifier = data["identifier"]["value"]
self.identifier_type = data["identifier"]["type"]
if data.get("wildcard", False):
self.identifier = "*.{0}".format(self.identifier)
def __init__(self, url):
self.url = url
@@ -183,11 +195,11 @@ class Authorization(object):
@classmethod
def create(cls, client, identifier_type, identifier):
'''
"""
Create a new authorization for the given identifier.
Return the authorization object of the new authorization
https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-6.4
'''
"""
new_authz = {
"identifier": {
"type": identifier_type,
@@ -195,16 +207,22 @@ class Authorization(object):
},
}
if client.version == 1:
url = client.directory['new-authz']
url = client.directory["new-authz"]
new_authz["resource"] = "new-authz"
else:
if 'newAuthz' not in client.directory.directory:
raise ACMEProtocolException(client.module, 'ACME endpoint does not support pre-authorization')
url = client.directory['newAuthz']
if "newAuthz" not in client.directory.directory:
raise ACMEProtocolException(
client.module, "ACME endpoint does not support pre-authorization"
)
url = client.directory["newAuthz"]
result, info = client.send_signed_request(
url, new_authz, error_msg='Failed to request challenges', expected_status_codes=[200, 201])
return cls.from_json(client, result, info['location'])
url,
new_authz,
error_msg="Failed to request challenges",
expected_status_codes=[200, 201],
)
return cls.from_json(client, result, info["location"])
@property
def combined_identifier(self):
@@ -220,39 +238,44 @@ class Authorization(object):
return changed
def get_challenge_data(self, client):
'''
"""
Returns a dict with the data for all proposed (and supported) challenges
of the given authorization.
'''
"""
data = {}
for challenge in self.challenges:
validation_data = challenge.get_validation_data(client, self.identifier_type, self.identifier)
validation_data = challenge.get_validation_data(
client, self.identifier_type, self.identifier
)
if validation_data is not None:
data[challenge.type] = validation_data
return data
def raise_error(self, error_msg, module=None):
'''
"""
Aborts with a specific error for a challenge.
'''
"""
error_details = []
# multiple challenges could have failed at this point, gather error
# details for all of them before failing
for challenge in self.challenges:
if challenge.status == 'invalid':
msg = 'Challenge {type}'.format(type=challenge.type)
if 'error' in challenge.data:
msg = '{msg}: {problem}'.format(
if challenge.status == "invalid":
msg = "Challenge {type}".format(type=challenge.type)
if "error" in challenge.data:
msg = "{msg}: {problem}".format(
msg=msg,
problem=format_error_problem(challenge.data['error'], subproblem_prefix='{0}.'.format(challenge.type)),
problem=format_error_problem(
challenge.data["error"],
subproblem_prefix="{0}.".format(challenge.type),
),
)
error_details.append(msg)
raise ACMEProtocolException(
module,
'Failed to validate challenge for {identifier}: {error}. {details}'.format(
"Failed to validate challenge for {identifier}: {error}. {details}".format(
identifier=self.combined_identifier,
error=error_msg,
details='; '.join(error_details),
details="; ".join(error_details),
),
extras=dict(
identifier=self.combined_identifier,
@@ -269,88 +292,90 @@ class Authorization(object):
def wait_for_validation(self, client, callenge_type):
while True:
self.refresh(client)
if self.status in ['valid', 'invalid', 'revoked']:
if self.status in ["valid", "invalid", "revoked"]:
break
time.sleep(2)
if self.status == 'invalid':
if self.status == "invalid":
self.raise_error('Status is "invalid"', module=client.module)
return self.status == 'valid'
return self.status == "valid"
def call_validate(self, client, challenge_type, wait=True):
'''
"""
Validate the authorization provided in the auth dict. Returns True
when the validation was successful and False when it was not.
'''
"""
challenge = self.find_challenge(challenge_type)
if challenge is None:
raise ModuleFailException('Found no challenge of type "{challenge}" for identifier {identifier}!'.format(
challenge=challenge_type,
identifier=self.combined_identifier,
))
raise ModuleFailException(
'Found no challenge of type "{challenge}" for identifier {identifier}!'.format(
challenge=challenge_type,
identifier=self.combined_identifier,
)
)
challenge.call_validate(client)
if not wait:
return self.status == 'valid'
return self.status == "valid"
return self.wait_for_validation(client, challenge_type)
def can_deactivate(self):
'''
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
'''
return self.status in ('valid', 'pending')
"""
return self.status in ("valid", "pending")
def deactivate(self, client):
'''
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
'''
"""
if not self.can_deactivate():
return
authz_deactivate = {
'status': 'deactivated'
}
authz_deactivate = {"status": "deactivated"}
if client.version == 1:
authz_deactivate['resource'] = 'authz'
result, info = client.send_signed_request(self.url, authz_deactivate, fail_on_error=False)
if 200 <= info['status'] < 300 and result.get('status') == 'deactivated':
self.status = 'deactivated'
authz_deactivate["resource"] = "authz"
result, info = client.send_signed_request(
self.url, authz_deactivate, fail_on_error=False
)
if 200 <= info["status"] < 300 and result.get("status") == "deactivated":
self.status = "deactivated"
return True
return False
@classmethod
def deactivate_url(cls, client, url):
'''
"""
Deactivates this authorization.
https://community.letsencrypt.org/t/authorization-deactivation/19860/2
https://tools.ietf.org/html/rfc8555#section-7.5.2
'''
"""
authz = cls(url)
authz_deactivate = {
'status': 'deactivated'
}
authz_deactivate = {"status": "deactivated"}
if client.version == 1:
authz_deactivate['resource'] = 'authz'
result, info = client.send_signed_request(url, authz_deactivate, fail_on_error=True)
authz_deactivate["resource"] = "authz"
result, info = client.send_signed_request(
url, authz_deactivate, fail_on_error=True
)
authz._setup(client, result)
return authz
def wait_for_validation(authzs, client):
'''
"""
Wait until a list of authz is valid. Fail if at least one of them is invalid or revoked.
'''
"""
while authzs:
authzs_next = []
for authz in authzs:
authz.refresh(client)
if authz.status in ['valid', 'invalid', 'revoked']:
if authz.status != 'valid':
if authz.status in ["valid", "invalid", "revoked"]:
if authz.status != "valid":
authz.raise_error('Status is not "valid"', module=client.module)
else:
authzs_next.append(authz)

View File

@@ -19,37 +19,42 @@ def format_http_status(status_code):
expl = http_responses.get(status_code)
if not expl:
return str(status_code)
return '%d %s' % (status_code, expl)
return "%d %s" % (status_code, expl)
def format_error_problem(problem, subproblem_prefix=''):
error_type = problem.get('type', 'about:blank') # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
if 'title' in problem:
def format_error_problem(problem, subproblem_prefix=""):
error_type = problem.get(
"type", "about:blank"
) # https://www.rfc-editor.org/rfc/rfc7807#section-3.1
if "title" in problem:
msg = 'Error "{title}" ({type})'.format(
type=error_type,
title=problem['title'],
title=problem["title"],
)
else:
msg = 'Error {type}'.format(type=error_type)
if 'detail' in problem:
msg += ': "{detail}"'.format(detail=problem['detail'])
subproblems = problem.get('subproblems')
msg = "Error {type}".format(type=error_type)
if "detail" in problem:
msg += ': "{detail}"'.format(detail=problem["detail"])
subproblems = problem.get("subproblems")
if subproblems is not None:
msg = '{msg} Subproblems:'.format(msg=msg)
msg = "{msg} Subproblems:".format(msg=msg)
for index, problem in enumerate(subproblems):
index_str = '{prefix}{index}'.format(prefix=subproblem_prefix, index=index)
msg = '{msg}\n({index}) {problem}'.format(
index_str = "{prefix}{index}".format(prefix=subproblem_prefix, index=index)
msg = "{msg}\n({index}) {problem}".format(
msg=msg,
index=index_str,
problem=format_error_problem(problem, subproblem_prefix='{0}.'.format(index_str)),
problem=format_error_problem(
problem, subproblem_prefix="{0}.".format(index_str)
),
)
return msg
class ModuleFailException(Exception):
'''
"""
If raised, module.fail_json() will be called with the given parameters after cleanup.
'''
"""
def __init__(self, msg, **args):
super(ModuleFailException, self).__init__(self, msg)
self.msg = msg
@@ -60,7 +65,16 @@ class ModuleFailException(Exception):
class ACMEProtocolException(ModuleFailException):
def __init__(self, module, msg=None, info=None, response=None, content=None, content_json=None, extras=None):
def __init__(
self,
module,
msg=None,
info=None,
response=None,
content=None,
content_json=None,
extras=None,
):
# Try to get hold of content, if response is given and content is not provided
if content is None and content_json is None and response is not None:
try:
@@ -70,7 +84,7 @@ class ACMEProtocolException(ModuleFailException):
raise TypeError
content = response.read()
except (AttributeError, TypeError):
content = info.pop('body', None)
content = info.pop("body", None)
# Make sure that content_json is None or a dictionary
if content_json is not None and not isinstance(content_json, dict):
@@ -90,53 +104,71 @@ class ACMEProtocolException(ModuleFailException):
error_type = None
if msg is None:
msg = 'ACME request failed'
add_msg = ''
msg = "ACME request failed"
add_msg = ""
if info is not None:
url = info['url']
code = info['status']
extras['http_url'] = url
extras['http_status'] = code
url = info["url"]
code = info["status"]
extras["http_url"] = url
extras["http_status"] = code
error_code = code
if code is not None and code >= 400 and content_json is not None and 'type' in content_json:
error_type = content_json['type']
if 'status' in content_json and content_json['status'] != code:
code_msg = 'status {problem_code} (HTTP status: {http_code})'.format(
http_code=format_http_status(code), problem_code=content_json['status'])
if (
code is not None
and code >= 400
and content_json is not None
and "type" in content_json
):
error_type = content_json["type"]
if "status" in content_json and content_json["status"] != code:
code_msg = (
"status {problem_code} (HTTP status: {http_code})".format(
http_code=format_http_status(code),
problem_code=content_json["status"],
)
)
else:
code_msg = 'status {problem_code}'.format(problem_code=format_http_status(code))
if code == -1 and info.get('msg'):
code_msg = 'error: {msg}'.format(msg=info['msg'])
subproblems = content_json.pop('subproblems', None)
add_msg = ' {problem}.'.format(problem=format_error_problem(content_json))
extras['problem'] = content_json
extras['subproblems'] = subproblems or []
code_msg = "status {problem_code}".format(
problem_code=format_http_status(code)
)
if code == -1 and info.get("msg"):
code_msg = "error: {msg}".format(msg=info["msg"])
subproblems = content_json.pop("subproblems", None)
add_msg = " {problem}.".format(
problem=format_error_problem(content_json)
)
extras["problem"] = content_json
extras["subproblems"] = subproblems or []
if subproblems is not None:
add_msg = '{add_msg} Subproblems:'.format(add_msg=add_msg)
add_msg = "{add_msg} Subproblems:".format(add_msg=add_msg)
for index, problem in enumerate(subproblems):
add_msg = '{add_msg}\n({index}) {problem}.'.format(
add_msg = "{add_msg}\n({index}) {problem}.".format(
add_msg=add_msg,
index=index,
problem=format_error_problem(problem, subproblem_prefix='{0}.'.format(index)),
problem=format_error_problem(
problem, subproblem_prefix="{0}.".format(index)
),
)
else:
code_msg = 'HTTP status {code}'.format(code=format_http_status(code))
if code == -1 and info.get('msg'):
code_msg = 'error: {msg}'.format(msg=info['msg'])
code_msg = "HTTP status {code}".format(code=format_http_status(code))
if code == -1 and info.get("msg"):
code_msg = "error: {msg}".format(msg=info["msg"])
if content_json is not None:
add_msg = ' The JSON error result: {content}'.format(content=content_json)
add_msg = " The JSON error result: {content}".format(
content=content_json
)
elif content is not None:
add_msg = ' The raw error result: {content}'.format(content=to_text(content))
msg = '{msg} for {url} with {code}'.format(msg=msg, url=url, code=code_msg)
add_msg = " The raw error result: {content}".format(
content=to_text(content)
)
msg = "{msg} for {url} with {code}".format(msg=msg, url=url, code=code_msg)
elif content_json is not None:
add_msg = ' The JSON result: {content}'.format(content=content_json)
add_msg = " The JSON result: {content}".format(content=content_json)
elif content is not None:
add_msg = ' The raw result: {content}'.format(content=to_text(content))
add_msg = " The raw result: {content}".format(content=to_text(content))
super(ACMEProtocolException, self).__init__(
'{msg}.{add_msg}'.format(msg=msg, add_msg=add_msg),
**extras
"{msg}.{add_msg}".format(msg=msg, add_msg=add_msg), **extras
)
self.problem = {}
self.subproblems = []

View File

@@ -23,9 +23,9 @@ from ansible_collections.community.crypto.plugins.module_utils.acme.errors impor
)
def read_file(fn, mode='b'):
def read_file(fn, mode="b"):
try:
with open(fn, 'r' + mode) as f:
with open(fn, "r" + mode) as f:
return f.read()
except Exception as e:
raise ModuleFailException('Error while reading file "{0}": {1}'.format(fn, e))
@@ -33,14 +33,14 @@ def read_file(fn, mode='b'):
# This function was adapted from an earlier version of https://github.com/ansible/ansible/blob/devel/lib/ansible/modules/uri.py
def write_file(module, dest, content):
'''
"""
Write content to destination file dest, only if the content
has changed.
'''
"""
changed = False
# create a tempfile
fd, tmpsrc = tempfile.mkstemp(text=False)
f = os.fdopen(fd, 'wb')
f = os.fdopen(fd, "wb")
try:
f.write(content)
except Exception as err:
@@ -49,7 +49,10 @@ def write_file(module, dest, content):
except Exception:
pass
os.remove(tmpsrc)
raise ModuleFailException("failed to create temporary content file: %s" % to_native(err), exception=traceback.format_exc())
raise ModuleFailException(
"failed to create temporary content file: %s" % to_native(err),
exception=traceback.format_exc(),
)
f.close()
checksum_src = None
checksum_dest = None
@@ -75,7 +78,7 @@ def write_file(module, dest, content):
raise ModuleFailException("Destination %s not readable" % (dest))
checksum_dest = module.sha1(dest)
else:
dirname = os.path.dirname(dest) or '.'
dirname = os.path.dirname(dest) or "."
if not os.access(dirname, os.W_OK):
os.remove(tmpsrc)
raise ModuleFailException("Destination dir %s not writable" % (dirname))
@@ -85,6 +88,9 @@ def write_file(module, dest, content):
changed = True
except Exception as err:
os.remove(tmpsrc)
raise ModuleFailException("failed to copy %s to %s: %s" % (tmpsrc, dest, to_native(err)), exception=traceback.format_exc())
raise ModuleFailException(
"failed to copy %s to %s: %s" % (tmpsrc, dest, to_native(err)),
exception=traceback.format_exc(),
)
os.remove(tmpsrc)
return changed

View File

@@ -29,14 +29,14 @@ class Order(object):
def _setup(self, client, data):
self.data = data
self.status = data['status']
self.status = data["status"]
self.identifiers = []
for identifier in data['identifiers']:
self.identifiers.append((identifier['type'], identifier['value']))
self.replaces_cert_id = data.get('replaces')
self.finalize_uri = data.get('finalize')
self.certificate_uri = data.get('certificate')
self.authorization_uris = data['authorizations']
for identifier in data["identifiers"]:
self.identifiers.append((identifier["type"], identifier["value"]))
self.replaces_cert_id = data.get("replaces")
self.finalize_uri = data.get("finalize")
self.certificate_uri = data.get("certificate")
self.authorization_uris = data["authorizations"]
self.authorizations = {}
def __init__(self, url):
@@ -66,33 +66,37 @@ class Order(object):
@classmethod
def create(cls, client, identifiers, replaces_cert_id=None, profile=None):
'''
"""
Start a new certificate order (ACME v2 protocol).
https://tools.ietf.org/html/rfc8555#section-7.4
'''
"""
acme_identifiers = []
for identifier_type, identifier in identifiers:
acme_identifiers.append({
'type': identifier_type,
'value': identifier,
})
new_order = {
"identifiers": acme_identifiers
}
acme_identifiers.append(
{
"type": identifier_type,
"value": identifier,
}
)
new_order = {"identifiers": acme_identifiers}
if replaces_cert_id is not None:
new_order["replaces"] = replaces_cert_id
if profile is not None:
new_order["profile"] = profile
result, info = client.send_signed_request(
client.directory['newOrder'], new_order, error_msg='Failed to start new order', expected_status_codes=[201])
return cls.from_json(client, result, info['location'])
client.directory["newOrder"],
new_order,
error_msg="Failed to start new order",
expected_status_codes=[201],
)
return cls.from_json(client, result, info["location"])
@classmethod
def create_with_error_handling(
cls,
client,
identifiers,
error_strategy='auto',
error_strategy="auto",
error_max_retries=3,
replaces_cert_id=None,
profile=None,
@@ -113,20 +117,29 @@ class Order(object):
while True:
tries += 1
try:
return cls.create(client, identifiers, replaces_cert_id=replaces_cert_id, profile=profile)
return cls.create(
client,
identifiers,
replaces_cert_id=replaces_cert_id,
profile=profile,
)
except ACMEProtocolException as exc:
if tries <= error_max_retries + 1 and error_strategy != 'fail':
if error_strategy == 'always':
if tries <= error_max_retries + 1 and error_strategy != "fail":
if error_strategy == "always":
continue
if (
error_strategy in ('auto', 'retry_without_replaces_cert_id') and
replaces_cert_id is not None and
not (exc.error_code == 409 and exc.error_type == 'urn:ietf:params:acme:error:alreadyReplaced')
error_strategy in ("auto", "retry_without_replaces_cert_id")
and replaces_cert_id is not None
and not (
exc.error_code == 409
and exc.error_type
== "urn:ietf:params:acme:error:alreadyReplaced"
)
):
if message_callback:
message_callback(
'Stop passing `replaces={replaces}` due to error {code} {type} when creating ACME order'.format(
"Stop passing `replaces={replaces}` due to error {code} {type} when creating ACME order".format(
code=exc.error_code,
type=exc.error_type,
replaces=replaces_cert_id,
@@ -146,32 +159,41 @@ class Order(object):
def load_authorizations(self, client):
for auth_uri in self.authorization_uris:
authz = Authorization.from_url(client, auth_uri)
self.authorizations[normalize_combined_identifier(authz.combined_identifier)] = authz
self.authorizations[
normalize_combined_identifier(authz.combined_identifier)
] = authz
def wait_for_finalization(self, client):
while True:
self.refresh(client)
if self.status in ['valid', 'invalid', 'pending', 'ready']:
if self.status in ["valid", "invalid", "pending", "ready"]:
break
time.sleep(2)
if self.status != 'valid':
if self.status != "valid":
raise ACMEProtocolException(
client.module,
'Failed to wait for order to complete; got status "{status}"'.format(status=self.status),
content_json=self.data)
'Failed to wait for order to complete; got status "{status}"'.format(
status=self.status
),
content_json=self.data,
)
def finalize(self, client, csr_der, wait=True):
'''
"""
Create a new certificate based on the csr.
Return the certificate object as dict
https://tools.ietf.org/html/rfc8555#section-7.4
'''
"""
new_cert = {
"csr": nopad_b64(csr_der),
}
result, info = client.send_signed_request(
self.finalize_uri, new_cert, error_msg='Failed to finalizing order', expected_status_codes=[200])
self.finalize_uri,
new_cert,
error_msg="Failed to finalizing order",
expected_status_codes=[200],
)
# It is not clear from the RFC whether the finalize call returns the order object or not.
# Instead of using the result, we call self.refresh(client) below.
@@ -179,9 +201,12 @@ class Order(object):
self.wait_for_finalization(client)
else:
self.refresh(client)
if self.status not in ['procesing', 'valid', 'invalid']:
if self.status not in ["procesing", "valid", "invalid"]:
raise ACMEProtocolException(
client.module,
'Failed to finalize order; got status "{status}"'.format(status=self.status),
'Failed to finalize order; got status "{status}"'.format(
status=self.status
),
info=info,
content_json=result)
content_json=result,
)

View File

@@ -31,23 +31,24 @@ from ansible_collections.community.crypto.plugins.module_utils.time import (
def nopad_b64(data):
return base64.urlsafe_b64encode(data).decode('utf8').replace("=", "")
return base64.urlsafe_b64encode(data).decode("utf8").replace("=", "")
def der_to_pem(der_cert):
'''
"""
Convert the DER format certificate in der_cert to a PEM format certificate and return it.
'''
"""
return """-----BEGIN CERTIFICATE-----\n{0}\n-----END CERTIFICATE-----\n""".format(
"\n".join(textwrap.wrap(base64.b64encode(der_cert).decode('utf8'), 64)))
"\n".join(textwrap.wrap(base64.b64encode(der_cert).decode("utf8"), 64))
)
def pem_to_der(pem_filename=None, pem_content=None):
'''
"""
Load PEM file, or use PEM file's content, and convert to DER.
If PEM contains multiple entities, the first entity will be used.
'''
"""
certificate_lines = []
if pem_content is not None:
lines = pem_content.splitlines()
@@ -56,12 +57,17 @@ def pem_to_der(pem_filename=None, pem_content=None):
with open(pem_filename, "rt") as f:
lines = list(f)
except Exception as err:
raise ModuleFailException("cannot load PEM file {0}: {1}".format(pem_filename, to_native(err)), exception=traceback.format_exc())
raise ModuleFailException(
"cannot load PEM file {0}: {1}".format(pem_filename, to_native(err)),
exception=traceback.format_exc(),
)
else:
raise ModuleFailException('One of pem_filename and pem_content must be provided')
raise ModuleFailException(
"One of pem_filename and pem_content must be provided"
)
header_line_count = 0
for line in lines:
if line.startswith('-----'):
if line.startswith("-----"):
header_line_count += 1
if header_line_count == 2:
# If certificate file contains other certs appended
@@ -69,27 +75,27 @@ def pem_to_der(pem_filename=None, pem_content=None):
break
continue
certificate_lines.append(line.strip())
return base64.b64decode(''.join(certificate_lines))
return base64.b64decode("".join(certificate_lines))
def process_links(info, callback):
'''
"""
Process link header, calls callback for every link header with the URL and relation as options.
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
'''
if 'link' in info:
link = info['link']
"""
if "link" in info:
link = info["link"]
for url, relation in re.findall(r'<([^>]+)>;\s*rel="(\w+)"', link):
callback(unquote(url), relation)
def parse_retry_after(value, relative_with_timezone=True, now=None):
'''
"""
Parse the value of a Retry-After header and return a timestamp.
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
'''
"""
# First try a number of seconds
try:
delta = datetime.timedelta(seconds=int(value))
@@ -100,11 +106,11 @@ def parse_retry_after(value, relative_with_timezone=True, now=None):
pass
try:
return datetime.datetime.strptime(value, '%a, %d %b %Y %H:%M:%S GMT')
return datetime.datetime.strptime(value, "%a, %d %b %Y %H:%M:%S GMT")
except ValueError:
pass
raise ValueError('Cannot parse Retry-After header value %s' % repr(value))
raise ValueError("Cannot parse Retry-After header value %s" % repr(value))
def compute_cert_id(
@@ -116,20 +122,26 @@ def compute_cert_id(
):
# Obtain certificate info if not provided
if cert_info is None:
cert_info = backend.get_cert_information(cert_filename=cert_filename, cert_content=cert_content)
cert_info = backend.get_cert_information(
cert_filename=cert_filename, cert_content=cert_content
)
# Convert Authority Key Identifier to string
if cert_info.authority_key_identifier is None:
if none_if_required_information_is_missing:
return None
raise ModuleFailException('Certificate has no Authority Key Identifier extension')
aki = to_native(base64.urlsafe_b64encode(cert_info.authority_key_identifier)).replace('=', '')
raise ModuleFailException(
"Certificate has no Authority Key Identifier extension"
)
aki = to_native(
base64.urlsafe_b64encode(cert_info.authority_key_identifier)
).replace("=", "")
# Convert serial number to string
serial_bytes = convert_int_to_bytes(cert_info.serial_number)
if ord(serial_bytes[:1]) >= 128:
serial_bytes = b'\x00' + serial_bytes
serial = to_native(base64.urlsafe_b64encode(serial_bytes)).replace('=', '')
serial_bytes = b"\x00" + serial_bytes
serial = to_native(base64.urlsafe_b64encode(serial_bytes)).replace("=", "")
# Compose cert ID
return '{aki}.{serial}'.format(aki=aki, serial=serial)
return "{aki}.{serial}".format(aki=aki, serial=serial)

View File

@@ -20,7 +20,15 @@ def _ensure_list(value):
class ArgumentSpec:
def __init__(self, argument_spec=None, mutually_exclusive=None, required_together=None, required_one_of=None, required_if=None, required_by=None):
def __init__(
self,
argument_spec=None,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
):
self.argument_spec = argument_spec or {}
self.mutually_exclusive = _ensure_list(mutually_exclusive)
self.required_together = _ensure_list(required_together)
@@ -32,7 +40,14 @@ class ArgumentSpec:
self.argument_spec.update(kwargs)
return self
def update(self, mutually_exclusive=None, required_together=None, required_one_of=None, required_if=None, required_by=None):
def update(
self,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
):
if mutually_exclusive:
self.mutually_exclusive.extend(mutually_exclusive)
if required_together:
@@ -68,10 +83,11 @@ class ArgumentSpec:
required_one_of=self.required_one_of,
required_if=self.required_if,
required_by=self.required_by,
**kwargs)
**kwargs
)
def create_ansible_module(self, **kwargs):
return self.create_ansible_module_helper(AnsibleModule, (), **kwargs)
__all__ = ('ArgumentSpec', )
__all__ = ("ArgumentSpec",)

View File

@@ -31,8 +31,10 @@ type:
value:
The value to encode, the format of this value depends on the <type> specified.
"""
ASN1_STRING_REGEX = re.compile(r'^((?P<tag_type>IMPLICIT|EXPLICIT):(?P<tag_number>\d+)(?P<tag_class>U|A|P|C)?,)?'
r'(?P<value_type>[\w\d]+):(?P<value>.*)')
ASN1_STRING_REGEX = re.compile(
r"^((?P<tag_type>IMPLICIT|EXPLICIT):(?P<tag_number>\d+)(?P<tag_class>U|A|P|C)?,)?"
r"(?P<value_type>[\w\d]+):(?P<value>.*)"
)
class TagClass:
@@ -48,7 +50,7 @@ class TagNumber:
def _pack_octet_integer(value):
""" Packs an integer value into 1 or multiple octets. """
"""Packs an integer value into 1 or multiple octets."""
# NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value.
octets = bytearray()
@@ -70,37 +72,41 @@ def _pack_octet_integer(value):
def serialize_asn1_string_as_der(value):
""" Deserializes an ASN.1 string to a DER encoded byte string. """
"""Deserializes an ASN.1 string to a DER encoded byte string."""
asn1_match = ASN1_STRING_REGEX.match(value)
if not asn1_match:
raise ValueError("The ASN.1 serialized string must be in the format [modifier,]type[:value]")
raise ValueError(
"The ASN.1 serialized string must be in the format [modifier,]type[:value]"
)
tag_type = asn1_match.group('tag_type')
tag_number = asn1_match.group('tag_number')
tag_class = asn1_match.group('tag_class') or 'C'
value_type = asn1_match.group('value_type')
asn1_value = asn1_match.group('value')
tag_type = asn1_match.group("tag_type")
tag_number = asn1_match.group("tag_number")
tag_class = asn1_match.group("tag_class") or "C"
value_type = asn1_match.group("value_type")
asn1_value = asn1_match.group("value")
if value_type != 'UTF8':
raise ValueError('The ASN.1 serialized string is not a known type "{0}", only UTF8 types are '
'supported'.format(value_type))
if value_type != "UTF8":
raise ValueError(
'The ASN.1 serialized string is not a known type "{0}", only UTF8 types are '
"supported".format(value_type)
)
b_value = to_bytes(asn1_value, encoding='utf-8', errors='surrogate_or_strict')
b_value = to_bytes(asn1_value, encoding="utf-8", errors="surrogate_or_strict")
# We should only do a universal type tag if not IMPLICITLY tagged or the tag class is not universal.
if not tag_type or (tag_type == 'EXPLICIT' and tag_class != 'U'):
if not tag_type or (tag_type == "EXPLICIT" and tag_class != "U"):
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
if tag_type:
tag_class = {
'U': TagClass.universal,
'A': TagClass.application,
'P': TagClass.private,
'C': TagClass.context_specific,
"U": TagClass.universal,
"A": TagClass.application,
"P": TagClass.private,
"C": TagClass.context_specific,
}[tag_class]
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
constructed = tag_type == 'EXPLICIT' and tag_class != TagClass.universal
constructed = tag_type == "EXPLICIT" and tag_class != TagClass.universal
b_value = pack_asn1(tag_class, constructed, int(tag_number), b_value)
return b_value
@@ -121,7 +127,7 @@ def pack_asn1(tag_class, constructed, tag_number, b_data):
# Bit 8 and 7 denotes the class.
identifier_octets = tag_class << 6
# Bit 6 denotes whether the value is primitive or constructed.
identifier_octets |= ((1 if constructed else 0) << 5)
identifier_octets |= (1 if constructed else 0) << 5
# Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits
# then they are set and another octet(s) is used to denote the tag number.

View File

@@ -36,6 +36,7 @@ __metaclass__ = type
# It must **ONLY** be used in compatibility code for older
# cryptography versions!
def obj2txt(openssl_lib, openssl_ffi, obj):
# Set to 80 on the recommendation of
# https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values

View File

@@ -21,17 +21,19 @@ for dotted, names in OID_MAP.items():
for name in names:
if name in NORMALIZE_NAMES and OID_LOOKUP[name] != dotted:
raise AssertionError(
'Name collision during setup: "{0}" for OIDs {1} and {2}'
.format(name, dotted, OID_LOOKUP[name])
'Name collision during setup: "{0}" for OIDs {1} and {2}'.format(
name, dotted, OID_LOOKUP[name]
)
)
NORMALIZE_NAMES[name] = names[0]
NORMALIZE_NAMES_SHORT[name] = names[-1]
OID_LOOKUP[name] = dotted
for alias, original in [('userID', 'userId')]:
for alias, original in [("userID", "userId")]:
if alias in NORMALIZE_NAMES:
raise AssertionError(
'Name collision during adding aliases: "{0}" (alias for "{1}") is already mapped to OID {2}'
.format(alias, original, OID_LOOKUP[alias])
'Name collision during adding aliases: "{0}" (alias for "{1}") is already mapped to OID {2}'.format(
alias, original, OID_LOOKUP[alias]
)
)
NORMALIZE_NAMES[alias] = original
NORMALIZE_NAMES_SHORT[alias] = NORMALIZE_NAMES_SHORT[original]

File diff suppressed because it is too large Load Diff

View File

@@ -27,7 +27,7 @@ try:
# actually doing that in x509_certificate, and potentially in other code,
# we need to monkey-patch __hash__ for these classes to make sure our code
# works fine.
if LooseVersion(cryptography.__version__) < LooseVersion('2.1'):
if LooseVersion(cryptography.__version__) < LooseVersion("2.1"):
# A very simply hash function which relies on the representation
# of an object to be implemented. This is the case since at least
# cryptography 1.0, see
@@ -44,7 +44,7 @@ try:
x509.OtherName.__hash__ = simple_hash
x509.RegisteredID.__hash__ = simple_hash
if LooseVersion(cryptography.__version__) < LooseVersion('1.2'):
if LooseVersion(cryptography.__version__) < LooseVersion("1.2"):
# The hash functions for the following types were added for cryptography 1.2:
# https://github.com/pyca/cryptography/commit/b642deed88a8696e5f01ce6855ccf89985fc35d0
# https://github.com/pyca/cryptography/commit/d1b5681f6db2bde7a14625538bd7907b08dfb486
@@ -55,6 +55,7 @@ try:
try:
# added in 0.5 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/dsa/
import cryptography.hazmat.primitives.asymmetric.dsa
CRYPTOGRAPHY_HAS_DSA = True
try:
# added later in 1.5
@@ -68,6 +69,7 @@ try:
try:
# added in 2.6 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ed25519/
import cryptography.hazmat.primitives.asymmetric.ed25519
CRYPTOGRAPHY_HAS_ED25519 = True
try:
# added with the primitive in 2.6
@@ -81,6 +83,7 @@ try:
try:
# added in 2.6 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ed448/
import cryptography.hazmat.primitives.asymmetric.ed448
CRYPTOGRAPHY_HAS_ED448 = True
try:
# added with the primitive in 2.6
@@ -94,6 +97,7 @@ try:
try:
# added in 0.5 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
import cryptography.hazmat.primitives.asymmetric.ec
CRYPTOGRAPHY_HAS_EC = True
try:
# added later in 1.5
@@ -107,6 +111,7 @@ try:
try:
# added in 0.5 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/
import cryptography.hazmat.primitives.asymmetric.rsa
CRYPTOGRAPHY_HAS_RSA = True
try:
# added later in 1.4
@@ -120,6 +125,7 @@ try:
try:
# added in 2.0 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/x25519/
import cryptography.hazmat.primitives.asymmetric.x25519
CRYPTOGRAPHY_HAS_X25519 = True
try:
# added later in 2.5
@@ -133,6 +139,7 @@ try:
try:
# added in 2.5 - https://cryptography.io/en/latest/hazmat/primitives/asymmetric/x448/
import cryptography.hazmat.primitives.asymmetric.x448
CRYPTOGRAPHY_HAS_X448 = True
except ImportError:
CRYPTOGRAPHY_HAS_X448 = False

View File

@@ -32,23 +32,25 @@ from .cryptography_support import CRYPTOGRAPHY_TIMEZONE, cryptography_decode_nam
# (https://github.com/pyca/cryptography/issues/10818)
CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE = False
if HAS_CRYPTOGRAPHY:
CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE = _LooseVersion(cryptography.__version__) >= _LooseVersion('43.0.0')
CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE = _LooseVersion(
cryptography.__version__
) >= _LooseVersion("43.0.0")
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
if HAS_CRYPTOGRAPHY:
REVOCATION_REASON_MAP = {
'unspecified': x509.ReasonFlags.unspecified,
'key_compromise': x509.ReasonFlags.key_compromise,
'ca_compromise': x509.ReasonFlags.ca_compromise,
'affiliation_changed': x509.ReasonFlags.affiliation_changed,
'superseded': x509.ReasonFlags.superseded,
'cessation_of_operation': x509.ReasonFlags.cessation_of_operation,
'certificate_hold': x509.ReasonFlags.certificate_hold,
'privilege_withdrawn': x509.ReasonFlags.privilege_withdrawn,
'aa_compromise': x509.ReasonFlags.aa_compromise,
'remove_from_crl': x509.ReasonFlags.remove_from_crl,
"unspecified": x509.ReasonFlags.unspecified,
"key_compromise": x509.ReasonFlags.key_compromise,
"ca_compromise": x509.ReasonFlags.ca_compromise,
"affiliation_changed": x509.ReasonFlags.affiliation_changed,
"superseded": x509.ReasonFlags.superseded,
"cessation_of_operation": x509.ReasonFlags.cessation_of_operation,
"certificate_hold": x509.ReasonFlags.certificate_hold,
"privilege_withdrawn": x509.ReasonFlags.privilege_withdrawn,
"aa_compromise": x509.ReasonFlags.aa_compromise,
"remove_from_crl": x509.ReasonFlags.remove_from_crl,
}
REVOCATION_REASON_MAP_INVERSE = dict()
for k, v in REVOCATION_REASON_MAP.items():
@@ -61,50 +63,61 @@ else:
def cryptography_decode_revoked_certificate(cert):
result = {
'serial_number': cert.serial_number,
'revocation_date': get_revocation_date(cert),
'issuer': None,
'issuer_critical': False,
'reason': None,
'reason_critical': False,
'invalidity_date': None,
'invalidity_date_critical': False,
"serial_number": cert.serial_number,
"revocation_date": get_revocation_date(cert),
"issuer": None,
"issuer_critical": False,
"reason": None,
"reason_critical": False,
"invalidity_date": None,
"invalidity_date_critical": False,
}
try:
ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
result['issuer'] = list(ext.value)
result['issuer_critical'] = ext.critical
result["issuer"] = list(ext.value)
result["issuer_critical"] = ext.critical
except x509.ExtensionNotFound:
pass
try:
ext = cert.extensions.get_extension_for_class(x509.CRLReason)
result['reason'] = ext.value.reason
result['reason_critical'] = ext.critical
result["reason"] = ext.value.reason
result["reason_critical"] = ext.critical
except x509.ExtensionNotFound:
pass
try:
ext = cert.extensions.get_extension_for_class(x509.InvalidityDate)
result['invalidity_date'] = get_invalidity_date(ext.value)
result['invalidity_date_critical'] = ext.critical
result["invalidity_date"] = get_invalidity_date(ext.value)
result["invalidity_date_critical"] = ext.critical
except x509.ExtensionNotFound:
pass
return result
def cryptography_dump_revoked(entry, idn_rewrite='ignore'):
def cryptography_dump_revoked(entry, idn_rewrite="ignore"):
return {
'serial_number': entry['serial_number'],
'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT),
'issuer':
[cryptography_decode_name(issuer, idn_rewrite=idn_rewrite) for issuer in entry['issuer']]
if entry['issuer'] is not None else None,
'issuer_critical': entry['issuer_critical'],
'reason': REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None,
'reason_critical': entry['reason_critical'],
'invalidity_date':
entry['invalidity_date'].strftime(TIMESTAMP_FORMAT)
if entry['invalidity_date'] is not None else None,
'invalidity_date_critical': entry['invalidity_date_critical'],
"serial_number": entry["serial_number"],
"revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT),
"issuer": (
[
cryptography_decode_name(issuer, idn_rewrite=idn_rewrite)
for issuer in entry["issuer"]
]
if entry["issuer"] is not None
else None
),
"issuer_critical": entry["issuer_critical"],
"reason": (
REVOCATION_REASON_MAP_INVERSE.get(entry["reason"])
if entry["reason"] is not None
else None
),
"reason_critical": entry["reason_critical"],
"invalidity_date": (
entry["invalidity_date"].strftime(TIMESTAMP_FORMAT)
if entry["invalidity_date"] is not None
else None
),
"invalidity_date_critical": entry["invalidity_date_critical"],
}
@@ -114,9 +127,7 @@ def cryptography_get_signature_algorithm_oid_from_crl(crl):
except AttributeError:
# Older cryptography versions do not have signature_algorithm_oid yet
dotted = obj2txt(
crl._backend._lib,
crl._backend._ffi,
crl._x509_crl.sig_alg.algorithm
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm
)
return x509.oid.ObjectIdentifier(dotted)

View File

@@ -38,6 +38,7 @@ try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
_HAS_CRYPTOGRAPHY = True
except ImportError:
_HAS_CRYPTOGRAPHY = False
@@ -112,10 +113,12 @@ from .basic import (
CRYPTOGRAPHY_TIMEZONE = False
if _HAS_CRYPTOGRAPHY:
CRYPTOGRAPHY_TIMEZONE = LooseVersion(cryptography.__version__) >= LooseVersion('42.0.0')
CRYPTOGRAPHY_TIMEZONE = LooseVersion(cryptography.__version__) >= LooseVersion(
"42.0.0"
)
DOTTED_OID = re.compile(r'^\d+(?:\.\d+)+$')
DOTTED_OID = re.compile(r"^\d+(?:\.\d+)+$")
def cryptography_get_extensions_from_cert(cert):
@@ -150,7 +153,11 @@ def cryptography_get_extensions_from_cert(cert):
value=to_native(base64.b64encode(der)),
)
try:
oid = obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext))
oid = obj2txt(
backend._lib,
backend._ffi,
backend._lib.X509_EXTENSION_get_object(ext),
)
except AttributeError:
oid = exts[i].oid.dotted_string
result[oid] = entry
@@ -189,8 +196,10 @@ def cryptography_get_extensions_from_csr(csr):
extensions,
lambda ext: backend._lib.sk_X509_EXTENSION_pop_free(
ext,
backend._ffi.addressof(backend._lib._original_lib, "X509_EXTENSION_free")
)
backend._ffi.addressof(
backend._lib._original_lib, "X509_EXTENSION_free"
),
),
)
# With cryptography 35.0.0, we can no longer use obj2txt. Unfortunately it still does
@@ -210,7 +219,11 @@ def cryptography_get_extensions_from_csr(csr):
value=to_native(base64.b64encode(der)),
)
try:
oid = obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext))
oid = obj2txt(
backend._lib,
backend._ffi,
backend._lib.X509_EXTENSION_get_object(ext),
)
except AttributeError:
oid = exts[i].oid.dotted_string
result[oid] = entry
@@ -246,7 +259,7 @@ def cryptography_oid_to_name(oid, short=False):
name = names[0]
else:
name = oid._name
if name == 'Unknown OID':
if name == "Unknown OID":
name = dotted_string
if short:
return NORMALIZE_NAMES_SHORT.get(name, name)
@@ -258,104 +271,128 @@ def _get_hex(bytesstr):
if bytesstr is None:
return bytesstr
data = binascii.hexlify(bytesstr)
data = to_text(b':'.join(data[i:i + 2] for i in range(0, len(data), 2)))
data = to_text(b":".join(data[i : i + 2] for i in range(0, len(data), 2)))
return data
def _parse_hex(bytesstr):
if bytesstr is None:
return bytesstr
data = ''.join([('0' * (2 - len(p)) + p) if len(p) < 2 else p for p in to_text(bytesstr).split(':')])
data = "".join(
[
("0" * (2 - len(p)) + p) if len(p) < 2 else p
for p in to_text(bytesstr).split(":")
]
)
data = binascii.unhexlify(data)
return data
DN_COMPONENT_START_RE = re.compile(b'^ *([a-zA-z0-9.]+) *= *')
DN_HEX_LETTER = b'0123456789abcdef'
DN_COMPONENT_START_RE = re.compile(b"^ *([a-zA-z0-9.]+) *= *")
DN_HEX_LETTER = b"0123456789abcdef"
if sys.version_info[0] < 3:
_int_to_byte = chr
else:
def _int_to_byte(value):
return bytes((value, ))
return bytes((value,))
def _parse_dn_component(name, sep=b',', decode_remainder=True):
def _parse_dn_component(name, sep=b",", decode_remainder=True):
m = DN_COMPONENT_START_RE.match(name)
if not m:
raise OpenSSLObjectError(u'cannot start part in "{0}"'.format(to_text(name)))
oid = cryptography_name_to_oid(to_text(m.group(1)))
idx = len(m.group(0))
decoded_name = []
sep_str = sep + b'\\'
sep_str = sep + b"\\"
if decode_remainder:
length = len(name)
if length > idx and name[idx:idx + 1] == b'#':
if length > idx and name[idx : idx + 1] == b"#":
# Decoding a hex string
idx += 1
while idx + 1 < length:
ch1 = name[idx:idx + 1]
ch2 = name[idx + 1:idx + 2]
ch1 = name[idx : idx + 1]
ch2 = name[idx + 1 : idx + 2]
idx1 = DN_HEX_LETTER.find(ch1.lower())
idx2 = DN_HEX_LETTER.find(ch2.lower())
if idx1 < 0 or idx2 < 0:
raise OpenSSLObjectError(u'Invalid hex sequence entry "{0}"'.format(to_text(ch1 + ch2)))
raise OpenSSLObjectError(
u'Invalid hex sequence entry "{0}"'.format(to_text(ch1 + ch2))
)
idx += 2
decoded_name.append(_int_to_byte(idx1 * 16 + idx2))
else:
# Decoding a regular string
while idx < length:
i = idx
while i < length and name[i:i + 1] not in sep_str:
while i < length and name[i : i + 1] not in sep_str:
i += 1
if i > idx:
decoded_name.append(name[idx:i])
idx = i
while idx + 1 < length and name[idx:idx + 1] == b'\\':
ch = name[idx + 1:idx + 2]
while idx + 1 < length and name[idx : idx + 1] == b"\\":
ch = name[idx + 1 : idx + 2]
idx1 = DN_HEX_LETTER.find(ch.lower())
if idx1 >= 0:
if idx + 2 >= length:
raise OpenSSLObjectError(u'Hex escape sequence "\\{0}" incomplete at end of string'.format(to_text(ch)))
ch2 = name[idx + 2:idx + 3]
raise OpenSSLObjectError(
u'Hex escape sequence "\\{0}" incomplete at end of string'.format(
to_text(ch)
)
)
ch2 = name[idx + 2 : idx + 3]
idx2 = DN_HEX_LETTER.find(ch2.lower())
if idx2 < 0:
raise OpenSSLObjectError(u'Hex escape sequence "\\{0}" has invalid second letter'.format(to_text(ch + ch2)))
raise OpenSSLObjectError(
u'Hex escape sequence "\\{0}" has invalid second letter'.format(
to_text(ch + ch2)
)
)
ch = _int_to_byte(idx1 * 16 + idx2)
idx += 1
idx += 2
decoded_name.append(ch)
if idx < length and name[idx:idx + 1] == sep:
if idx < length and name[idx : idx + 1] == sep:
break
else:
decoded_name.append(name[idx:])
idx = len(name)
return x509.NameAttribute(oid, to_text(b''.join(decoded_name))), name[idx:]
return x509.NameAttribute(oid, to_text(b"".join(decoded_name))), name[idx:]
def _parse_dn(name):
'''
"""
Parse a Distinguished Name.
Can be of the form ``CN=Test, O = Something`` or ``CN = Test,O= Something``.
'''
"""
original_name = name
name = name.lstrip()
sep = b','
if name.startswith(b'/'):
sep = b'/'
sep = b","
if name.startswith(b"/"):
sep = b"/"
name = name[1:]
result = []
while name:
try:
attribute, name = _parse_dn_component(name, sep=sep)
except OpenSSLObjectError as e:
raise OpenSSLObjectError(u'Error while parsing distinguished name "{0}": {1}'.format(to_text(original_name), e))
raise OpenSSLObjectError(
u'Error while parsing distinguished name "{0}": {1}'.format(
to_text(original_name), e
)
)
result.append(attribute)
if name:
if name[0:1] != sep or len(name) < 2:
raise OpenSSLObjectError(u'Error while parsing distinguished name "{0}": unexpected end of string'.format(to_text(original_name)))
raise OpenSSLObjectError(
u'Error while parsing distinguished name "{0}": unexpected end of string'.format(
to_text(original_name)
)
)
name = name[1:]
return result
@@ -366,12 +403,16 @@ def cryptography_parse_relative_distinguished_name(rdn):
try:
names.append(_parse_dn_component(to_bytes(part), decode_remainder=False)[0])
except OpenSSLObjectError as e:
raise OpenSSLObjectError(u'Error while parsing relative distinguished name "{0}": {1}'.format(part, e))
raise OpenSSLObjectError(
u'Error while parsing relative distinguished name "{0}": {1}'.format(
part, e
)
)
return cryptography.x509.RelativeDistinguishedName(names)
def _is_ascii(value):
'''Check whether the Unicode string `value` contains only ASCII characters.'''
"""Check whether the Unicode string `value` contains only ASCII characters."""
try:
value.encode("ascii")
return True
@@ -380,195 +421,244 @@ def _is_ascii(value):
def _adjust_idn(value, idn_rewrite):
if idn_rewrite == 'ignore' or not value:
if idn_rewrite == "ignore" or not value:
return value
if idn_rewrite == 'idna' and _is_ascii(value):
if idn_rewrite == "idna" and _is_ascii(value):
return value
if idn_rewrite not in ('idna', 'unicode'):
if idn_rewrite not in ("idna", "unicode"):
raise ValueError('Invalid value for idn_rewrite: "{0}"'.format(idn_rewrite))
if not HAS_IDNA:
raise OpenSSLObjectError(
missing_required_lib('idna', reason='to transform {what} DNS name "{name}" to {dest}'.format(
name=value,
what='IDNA' if idn_rewrite == 'unicode' else 'Unicode',
dest='Unicode' if idn_rewrite == 'unicode' else 'IDNA',
)))
missing_required_lib(
"idna",
reason='to transform {what} DNS name "{name}" to {dest}'.format(
name=value,
what="IDNA" if idn_rewrite == "unicode" else "Unicode",
dest="Unicode" if idn_rewrite == "unicode" else "IDNA",
),
)
)
# Since IDNA does not like '*' or empty labels (except one empty label at the end),
# we split and let IDNA only handle labels that are neither empty or '*'.
parts = value.split(u'.')
parts = value.split(u".")
for index, part in enumerate(parts):
if part in (u'', u'*'):
if part in (u"", u"*"):
continue
try:
if idn_rewrite == 'idna':
parts[index] = idna.encode(part).decode('ascii')
elif idn_rewrite == 'unicode' and part.startswith(u'xn--'):
if idn_rewrite == "idna":
parts[index] = idna.encode(part).decode("ascii")
elif idn_rewrite == "unicode" and part.startswith(u"xn--"):
parts[index] = idna.decode(part)
except idna.IDNAError as exc2008:
try:
if idn_rewrite == 'idna':
parts[index] = part.encode('idna').decode('ascii')
elif idn_rewrite == 'unicode' and part.startswith(u'xn--'):
parts[index] = part.encode('ascii').decode('idna')
if idn_rewrite == "idna":
parts[index] = part.encode("idna").decode("ascii")
elif idn_rewrite == "unicode" and part.startswith(u"xn--"):
parts[index] = part.encode("ascii").decode("idna")
except Exception as exc2003:
raise OpenSSLObjectError(
u'Error while transforming part "{part}" of {what} DNS name "{name}" to {dest}.'
u' IDNA2008 transformation resulted in "{exc2008}", IDNA2003 transformation resulted in "{exc2003}".'.format(
part=part,
name=value,
what='IDNA' if idn_rewrite == 'unicode' else 'Unicode',
dest='Unicode' if idn_rewrite == 'unicode' else 'IDNA',
what="IDNA" if idn_rewrite == "unicode" else "Unicode",
dest="Unicode" if idn_rewrite == "unicode" else "IDNA",
exc2003=exc2003,
exc2008=exc2008,
))
return u'.'.join(parts)
)
)
return u".".join(parts)
def _adjust_idn_email(value, idn_rewrite):
idx = value.find(u'@')
idx = value.find(u"@")
if idx < 0:
return value
return u'{0}@{1}'.format(value[:idx], _adjust_idn(value[idx + 1:], idn_rewrite))
return u"{0}@{1}".format(value[:idx], _adjust_idn(value[idx + 1 :], idn_rewrite))
def _adjust_idn_url(value, idn_rewrite):
url = urlparse(value)
host = _adjust_idn(url.hostname, idn_rewrite)
if url.username is not None and url.password is not None:
host = u'{0}:{1}@{2}'.format(url.username, url.password, host)
host = u"{0}:{1}@{2}".format(url.username, url.password, host)
elif url.username is not None:
host = u'{0}@{1}'.format(url.username, host)
host = u"{0}@{1}".format(url.username, host)
if url.port is not None:
host = u'{0}:{1}'.format(host, url.port)
host = u"{0}:{1}".format(host, url.port)
return urlunparse(
ParseResult(scheme=url.scheme, netloc=host, path=url.path, params=url.params, query=url.query, fragment=url.fragment))
ParseResult(
scheme=url.scheme,
netloc=host,
path=url.path,
params=url.params,
query=url.query,
fragment=url.fragment,
)
)
def cryptography_get_name(name, what='Subject Alternative Name'):
'''
def cryptography_get_name(name, what="Subject Alternative Name"):
"""
Given a name string, returns a cryptography x509.GeneralName object.
Raises an OpenSSLObjectError if the name is unknown or cannot be parsed.
'''
"""
try:
if name.startswith('DNS:'):
return x509.DNSName(_adjust_idn(to_text(name[4:]), 'idna'))
if name.startswith('IP:'):
if name.startswith("DNS:"):
return x509.DNSName(_adjust_idn(to_text(name[4:]), "idna"))
if name.startswith("IP:"):
address = to_text(name[3:])
if '/' in address:
if "/" in address:
return x509.IPAddress(ipaddress.ip_network(address))
return x509.IPAddress(ipaddress.ip_address(address))
if name.startswith('email:'):
return x509.RFC822Name(_adjust_idn_email(to_text(name[6:]), 'idna'))
if name.startswith('URI:'):
return x509.UniformResourceIdentifier(_adjust_idn_url(to_text(name[4:]), 'idna'))
if name.startswith('RID:'):
m = re.match(r'^([0-9]+(?:\.[0-9]+)*)$', to_text(name[4:]))
if name.startswith("email:"):
return x509.RFC822Name(_adjust_idn_email(to_text(name[6:]), "idna"))
if name.startswith("URI:"):
return x509.UniformResourceIdentifier(
_adjust_idn_url(to_text(name[4:]), "idna")
)
if name.startswith("RID:"):
m = re.match(r"^([0-9]+(?:\.[0-9]+)*)$", to_text(name[4:]))
if not m:
raise OpenSSLObjectError('Cannot parse {what} "{name}"'.format(name=name, what=what))
raise OpenSSLObjectError(
'Cannot parse {what} "{name}"'.format(name=name, what=what)
)
return x509.RegisteredID(x509.oid.ObjectIdentifier(m.group(1)))
if name.startswith('otherName:'):
if name.startswith("otherName:"):
# otherName can either be a raw ASN.1 hex string or in the format that OpenSSL works with.
m = re.match(r'^([0-9]+(?:\.[0-9]+)*);([0-9a-fA-F]{1,2}(?::[0-9a-fA-F]{1,2})*)$', to_text(name[10:]))
m = re.match(
r"^([0-9]+(?:\.[0-9]+)*);([0-9a-fA-F]{1,2}(?::[0-9a-fA-F]{1,2})*)$",
to_text(name[10:]),
)
if m:
return x509.OtherName(x509.oid.ObjectIdentifier(m.group(1)), _parse_hex(m.group(2)))
return x509.OtherName(
x509.oid.ObjectIdentifier(m.group(1)), _parse_hex(m.group(2))
)
# See https://www.openssl.org/docs/man1.0.2/man5/x509v3_config.html - Subject Alternative Name for more
# defailts on the format expected.
name = to_text(name[10:], errors='surrogate_or_strict')
if ';' not in name:
raise OpenSSLObjectError('Cannot parse {what} otherName "{name}", must be in the '
'format "otherName:<OID>;<ASN.1 OpenSSL Encoded String>" or '
'"otherName:<OID>;<hex string>"'.format(name=name, what=what))
name = to_text(name[10:], errors="surrogate_or_strict")
if ";" not in name:
raise OpenSSLObjectError(
'Cannot parse {what} otherName "{name}", must be in the '
'format "otherName:<OID>;<ASN.1 OpenSSL Encoded String>" or '
'"otherName:<OID>;<hex string>"'.format(name=name, what=what)
)
oid, value = name.split(';', 1)
oid, value = name.split(";", 1)
b_value = serialize_asn1_string_as_der(value)
return x509.OtherName(x509.ObjectIdentifier(oid), b_value)
if name.startswith('dirName:'):
return x509.DirectoryName(x509.Name(reversed(_parse_dn(to_bytes(name[8:])))))
if name.startswith("dirName:"):
return x509.DirectoryName(
x509.Name(reversed(_parse_dn(to_bytes(name[8:]))))
)
except Exception as e:
raise OpenSSLObjectError('Cannot parse {what} "{name}": {error}'.format(name=name, what=what, error=e))
if ':' not in name:
raise OpenSSLObjectError('Cannot parse {what} "{name}" (forgot "DNS:" prefix?)'.format(name=name, what=what))
raise OpenSSLObjectError('Cannot parse {what} "{name}" (potentially unsupported by cryptography backend)'.format(name=name, what=what))
raise OpenSSLObjectError(
'Cannot parse {what} "{name}": {error}'.format(
name=name, what=what, error=e
)
)
if ":" not in name:
raise OpenSSLObjectError(
'Cannot parse {what} "{name}" (forgot "DNS:" prefix?)'.format(
name=name, what=what
)
)
raise OpenSSLObjectError(
'Cannot parse {what} "{name}" (potentially unsupported by cryptography backend)'.format(
name=name, what=what
)
)
def _dn_escape_value(value):
'''
"""
Escape Distinguished Name's attribute value.
'''
value = value.replace(u'\\', u'\\\\')
for ch in [u',', u'+', u'<', u'>', u';', u'"']:
value = value.replace(ch, u'\\%s' % ch)
value = value.replace(u'\0', u'\\00')
if value.startswith((u' ', u'#')):
value = u'\\%s' % value[0] + value[1:]
if value.endswith(u' '):
value = value[:-1] + u'\\ '
"""
value = value.replace(u"\\", u"\\\\")
for ch in [u",", u"+", u"<", u">", u";", u'"']:
value = value.replace(ch, u"\\%s" % ch)
value = value.replace(u"\0", u"\\00")
if value.startswith((u" ", u"#")):
value = u"\\%s" % value[0] + value[1:]
if value.endswith(u" "):
value = value[:-1] + u"\\ "
return value
def cryptography_decode_name(name, idn_rewrite='ignore'):
'''
def cryptography_decode_name(name, idn_rewrite="ignore"):
"""
Given a cryptography x509.GeneralName object, returns a string.
Raises an OpenSSLObjectError if the name is not supported.
'''
if idn_rewrite not in ('ignore', 'idna', 'unicode'):
raise AssertionError('idn_rewrite must be one of "ignore", "idna", or "unicode"')
"""
if idn_rewrite not in ("ignore", "idna", "unicode"):
raise AssertionError(
'idn_rewrite must be one of "ignore", "idna", or "unicode"'
)
if isinstance(name, x509.DNSName):
return u'DNS:{0}'.format(_adjust_idn(name.value, idn_rewrite))
return u"DNS:{0}".format(_adjust_idn(name.value, idn_rewrite))
if isinstance(name, x509.IPAddress):
if isinstance(name.value, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
return u'IP:{0}/{1}'.format(name.value.network_address.compressed, name.value.prefixlen)
return u'IP:{0}'.format(name.value.compressed)
return u"IP:{0}/{1}".format(
name.value.network_address.compressed, name.value.prefixlen
)
return u"IP:{0}".format(name.value.compressed)
if isinstance(name, x509.RFC822Name):
return u'email:{0}'.format(_adjust_idn_email(name.value, idn_rewrite))
return u"email:{0}".format(_adjust_idn_email(name.value, idn_rewrite))
if isinstance(name, x509.UniformResourceIdentifier):
return u'URI:{0}'.format(_adjust_idn_url(name.value, idn_rewrite))
return u"URI:{0}".format(_adjust_idn_url(name.value, idn_rewrite))
if isinstance(name, x509.DirectoryName):
# According to https://datatracker.ietf.org/doc/html/rfc4514.html#section-2.1 the
# list needs to be reversed, and joined by commas
return u'dirName:' + ','.join([
u'{0}={1}'.format(to_text(cryptography_oid_to_name(attribute.oid, short=True)), _dn_escape_value(attribute.value))
for attribute in reversed(list(name.value))
])
return u"dirName:" + ",".join(
[
u"{0}={1}".format(
to_text(cryptography_oid_to_name(attribute.oid, short=True)),
_dn_escape_value(attribute.value),
)
for attribute in reversed(list(name.value))
]
)
if isinstance(name, x509.RegisteredID):
return u'RID:{0}'.format(name.value.dotted_string)
return u"RID:{0}".format(name.value.dotted_string)
if isinstance(name, x509.OtherName):
return u'otherName:{0};{1}'.format(name.type_id.dotted_string, _get_hex(name.value))
return u"otherName:{0};{1}".format(
name.type_id.dotted_string, _get_hex(name.value)
)
raise OpenSSLObjectError('Cannot decode name "{0}"'.format(name))
def _cryptography_get_keyusage(usage):
'''
"""
Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage().
Raises an OpenSSLObjectError if the identifier is unknown.
'''
if usage in ('Digital Signature', 'digitalSignature'):
return 'digital_signature'
if usage in ('Non Repudiation', 'nonRepudiation'):
return 'content_commitment'
if usage in ('Key Encipherment', 'keyEncipherment'):
return 'key_encipherment'
if usage in ('Data Encipherment', 'dataEncipherment'):
return 'data_encipherment'
if usage in ('Key Agreement', 'keyAgreement'):
return 'key_agreement'
if usage in ('Certificate Sign', 'keyCertSign'):
return 'key_cert_sign'
if usage in ('CRL Sign', 'cRLSign'):
return 'crl_sign'
if usage in ('Encipher Only', 'encipherOnly'):
return 'encipher_only'
if usage in ('Decipher Only', 'decipherOnly'):
return 'decipher_only'
"""
if usage in ("Digital Signature", "digitalSignature"):
return "digital_signature"
if usage in ("Non Repudiation", "nonRepudiation"):
return "content_commitment"
if usage in ("Key Encipherment", "keyEncipherment"):
return "key_encipherment"
if usage in ("Data Encipherment", "dataEncipherment"):
return "data_encipherment"
if usage in ("Key Agreement", "keyAgreement"):
return "key_agreement"
if usage in ("Certificate Sign", "keyCertSign"):
return "key_cert_sign"
if usage in ("CRL Sign", "cRLSign"):
return "crl_sign"
if usage in ("Encipher Only", "encipherOnly"):
return "encipher_only"
if usage in ("Decipher Only", "decipherOnly"):
return "decipher_only"
raise OpenSSLObjectError('Unknown key usage "{0}"'.format(usage))
def cryptography_parse_key_usage_params(usages):
'''
"""
Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage().
Raises an OpenSSLObjectError if an identifier is unknown.
'''
"""
params = dict(
digital_signature=False,
content_commitment=False,
@@ -586,40 +676,52 @@ def cryptography_parse_key_usage_params(usages):
def cryptography_get_basic_constraints(constraints):
'''
"""
Given a list of constraints, returns a tuple (ca, path_length).
Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed.
'''
"""
ca = False
path_length = None
if constraints:
for constraint in constraints:
if constraint.startswith('CA:'):
if constraint == 'CA:TRUE':
if constraint.startswith("CA:"):
if constraint == "CA:TRUE":
ca = True
elif constraint == 'CA:FALSE':
elif constraint == "CA:FALSE":
ca = False
else:
raise OpenSSLObjectError('Unknown basic constraint value "{0}" for CA'.format(constraint[3:]))
elif constraint.startswith('pathlen:'):
v = constraint[len('pathlen:'):]
raise OpenSSLObjectError(
'Unknown basic constraint value "{0}" for CA'.format(
constraint[3:]
)
)
elif constraint.startswith("pathlen:"):
v = constraint[len("pathlen:") :]
try:
path_length = int(v)
except Exception as e:
raise OpenSSLObjectError('Cannot parse path length constraint "{0}" ({1})'.format(v, e))
raise OpenSSLObjectError(
'Cannot parse path length constraint "{0}" ({1})'.format(v, e)
)
else:
raise OpenSSLObjectError('Unknown basic constraint "{0}"'.format(constraint))
raise OpenSSLObjectError(
'Unknown basic constraint "{0}"'.format(constraint)
)
return ca, path_length
def cryptography_key_needs_digest_for_signing(key):
'''Tests whether the given private key requires a digest algorithm for signing.
"""Tests whether the given private key requires a digest algorithm for signing.
Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm.
'''
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey):
"""
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey
):
return False
if CRYPTOGRAPHY_HAS_ED448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
if CRYPTOGRAPHY_HAS_ED448 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey
):
return False
return True
@@ -637,16 +739,22 @@ def _compare_public_keys(key1, key2, clazz):
def cryptography_compare_public_keys(key1, key2):
'''Tests whether two public keys are the same.
"""Tests whether two public keys are the same.
Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers().
'''
"""
if CRYPTOGRAPHY_HAS_ED25519:
res = _compare_public_keys(key1, key2, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey)
res = _compare_public_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
)
if res is not None:
return res
if CRYPTOGRAPHY_HAS_ED448:
res = _compare_public_keys(key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey)
res = _compare_public_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
)
if res is not None:
return res
return key1.public_numbers() == key2.public_numbers()
@@ -663,41 +771,61 @@ def _compare_private_keys(key1, key2, clazz, has_no_private_bytes=False):
# We do not have the private_bytes() function - compare associated public keys
return cryptography_compare_public_keys(a.public_key(), b.public_key())
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
a = key1.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, encryption_algorithm=encryption_algorithm)
b = key2.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, encryption_algorithm=encryption_algorithm)
a = key1.private_bytes(
serialization.Encoding.Raw,
serialization.PrivateFormat.Raw,
encryption_algorithm=encryption_algorithm,
)
b = key2.private_bytes(
serialization.Encoding.Raw,
serialization.PrivateFormat.Raw,
encryption_algorithm=encryption_algorithm,
)
return a == b
def cryptography_compare_private_keys(key1, key2):
'''Tests whether two private keys are the same.
"""Tests whether two private keys are the same.
Needs special logic for Ed25519, X25519, and Ed448 keys, since they do not have private_numbers().
'''
"""
if CRYPTOGRAPHY_HAS_ED25519:
res = _compare_private_keys(key1, key2, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey)
res = _compare_private_keys(
key1,
key2,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
)
if res is not None:
return res
if CRYPTOGRAPHY_HAS_X25519:
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey, has_no_private_bytes=not CRYPTOGRAPHY_HAS_X25519_FULL)
key1,
key2,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
has_no_private_bytes=not CRYPTOGRAPHY_HAS_X25519_FULL,
)
if res is not None:
return res
if CRYPTOGRAPHY_HAS_ED448:
res = _compare_private_keys(key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey)
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey
)
if res is not None:
return res
if CRYPTOGRAPHY_HAS_X448:
res = _compare_private_keys(key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey)
res = _compare_private_keys(
key1, key2, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey
)
if res is not None:
return res
return key1.private_numbers() == key2.private_numbers()
def cryptography_serial_number_of_cert(cert):
'''Returns cert.serial_number.
"""Returns cert.serial_number.
Also works for old versions of cryptography.
'''
"""
try:
return cert.serial_number
except AttributeError:
@@ -706,10 +834,11 @@ def cryptography_serial_number_of_cert(cert):
def parse_pkcs12(pkcs12_bytes, passphrase=None):
'''Returns a tuple (private_key, certificate, additional_certificates, friendly_name).
'''
"""Returns a tuple (private_key, certificate, additional_certificates, friendly_name)."""
if _load_pkcs12 is None and _load_key_and_certificates is None:
raise ValueError('neither load_pkcs12() nor load_key_and_certificates() present in the current cryptography version')
raise ValueError(
"neither load_pkcs12() nor load_key_and_certificates() present in the current cryptography version"
)
if passphrase is not None:
passphrase = to_bytes(passphrase)
@@ -718,7 +847,7 @@ def parse_pkcs12(pkcs12_bytes, passphrase=None):
if _load_pkcs12 is not None:
return _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase)
if LooseVersion(cryptography.__version__) >= LooseVersion('35.0'):
if LooseVersion(cryptography.__version__) >= LooseVersion("35.0"):
return _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase)
return _parse_pkcs12_legacy(pkcs12_bytes, passphrase)
@@ -739,7 +868,9 @@ def _parse_pkcs12_36_0_0(pkcs12_bytes, passphrase=None):
def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
# Backwards compatibility code for cryptography 35.x
private_key, certificate, additional_certificates = _load_key_and_certificates(pkcs12_bytes, passphrase)
private_key, certificate, additional_certificates = _load_key_and_certificates(
pkcs12_bytes, passphrase
)
friendly_name = None
if certificate:
@@ -749,18 +880,26 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
# This code basically does what load_key_and_certificates() does, but without error-checking.
# Since load_key_and_certificates succeeded, it should not fail.
pkcs12 = backend._ffi.gc(
backend._lib.d2i_PKCS12_bio(backend._bytes_to_bio(pkcs12_bytes).bio, backend._ffi.NULL),
backend._lib.PKCS12_free)
backend._lib.d2i_PKCS12_bio(
backend._bytes_to_bio(pkcs12_bytes).bio, backend._ffi.NULL
),
backend._lib.PKCS12_free,
)
certificate_x509_ptr = backend._ffi.new("X509 **")
with backend._zeroed_null_terminated_buf(to_bytes(passphrase) if passphrase is not None else None) as passphrase_buffer:
with backend._zeroed_null_terminated_buf(
to_bytes(passphrase) if passphrase is not None else None
) as passphrase_buffer:
backend._lib.PKCS12_parse(
pkcs12,
passphrase_buffer,
backend._ffi.new("EVP_PKEY **"),
certificate_x509_ptr,
backend._ffi.new("Cryptography_STACK_OF_X509 **"))
backend._ffi.new("Cryptography_STACK_OF_X509 **"),
)
if certificate_x509_ptr[0] != backend._ffi.NULL:
maybe_name = backend._lib.X509_alias_get0(certificate_x509_ptr[0], backend._ffi.NULL)
maybe_name = backend._lib.X509_alias_get0(
certificate_x509_ptr[0], backend._ffi.NULL
)
if maybe_name != backend._ffi.NULL:
friendly_name = backend._ffi.string(maybe_name)
@@ -769,7 +908,9 @@ def _parse_pkcs12_35_0_0(pkcs12_bytes, passphrase=None):
def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
# Backwards compatibility code for cryptography < 35.0.0
private_key, certificate, additional_certificates = _load_key_and_certificates(pkcs12_bytes, passphrase)
private_key, certificate, additional_certificates = _load_key_and_certificates(
pkcs12_bytes, passphrase
)
friendly_name = None
if certificate:
@@ -782,39 +923,62 @@ def _parse_pkcs12_legacy(pkcs12_bytes, passphrase=None):
def cryptography_verify_signature(signature, data, hash_algorithm, signer_public_key):
'''
"""
Check whether the given signature of the given data was signed by the given public key object.
'''
"""
try:
if CRYPTOGRAPHY_HAS_RSA_SIGN and isinstance(signer_public_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
signer_public_key.verify(signature, data, padding.PKCS1v15(), hash_algorithm)
if CRYPTOGRAPHY_HAS_RSA_SIGN and isinstance(
signer_public_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey,
):
signer_public_key.verify(
signature, data, padding.PKCS1v15(), hash_algorithm
)
return True
if CRYPTOGRAPHY_HAS_EC_SIGN and isinstance(signer_public_key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey):
signer_public_key.verify(signature, data, cryptography.hazmat.primitives.asymmetric.ec.ECDSA(hash_algorithm))
if CRYPTOGRAPHY_HAS_EC_SIGN and isinstance(
signer_public_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey,
):
signer_public_key.verify(
signature,
data,
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(hash_algorithm),
)
return True
if CRYPTOGRAPHY_HAS_DSA_SIGN and isinstance(signer_public_key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
if CRYPTOGRAPHY_HAS_DSA_SIGN and isinstance(
signer_public_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey,
):
signer_public_key.verify(signature, data, hash_algorithm)
return True
if CRYPTOGRAPHY_HAS_ED25519_SIGN and isinstance(signer_public_key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey):
if CRYPTOGRAPHY_HAS_ED25519_SIGN and isinstance(
signer_public_key,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey,
):
signer_public_key.verify(signature, data)
return True
if CRYPTOGRAPHY_HAS_ED448_SIGN and isinstance(signer_public_key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey):
if CRYPTOGRAPHY_HAS_ED448_SIGN and isinstance(
signer_public_key,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey,
):
signer_public_key.verify(signature, data)
return True
raise OpenSSLObjectError(u'Unsupported public key type {0}'.format(type(signer_public_key)))
raise OpenSSLObjectError(
u"Unsupported public key type {0}".format(type(signer_public_key))
)
except InvalidSignature:
return False
def cryptography_verify_certificate_signature(certificate, signer_public_key):
'''
"""
Check whether the given X509 certificate object was signed by the given public key object.
'''
"""
return cryptography_verify_signature(
certificate.signature,
certificate.tbs_certificate_bytes,
certificate.signature_hash_algorithm,
signer_public_key
signer_public_key,
)

View File

@@ -14,7 +14,7 @@ import sys
def binary_exp_mod(f, e, m):
'''Computes f^e mod m in O(log e) multiplications modulo m.'''
"""Computes f^e mod m in O(log e) multiplications modulo m."""
# Compute len_e = floor(log_2(e))
len_e = -1
x = e
@@ -31,18 +31,18 @@ def binary_exp_mod(f, e, m):
def simple_gcd(a, b):
'''Compute GCD of its two inputs.'''
"""Compute GCD of its two inputs."""
while b != 0:
a, b = b, a % b
return a
def quick_is_not_prime(n):
'''Does some quick checks to see if we can poke a hole into the primality of n.
"""Does some quick checks to see if we can poke a hole into the primality of n.
A result of `False` does **not** mean that the number is prime; it just means
that we could not detect quickly whether it is not prime.
'''
"""
if n <= 2:
return n < 2
# The constant in the next line is the product of all primes < 200
@@ -52,9 +52,52 @@ def quick_is_not_prime(n):
if n < 200 and gcd == n:
# Explicitly check for all primes < 200
return n not in (
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
181, 191, 193, 197, 199,
2,
3,
5,
7,
11,
13,
17,
19,
23,
29,
31,
37,
41,
43,
47,
53,
59,
61,
67,
71,
73,
79,
83,
89,
97,
101,
103,
107,
109,
113,
127,
131,
137,
139,
149,
151,
157,
163,
167,
173,
179,
181,
191,
193,
197,
199,
)
return True
# TODO: maybe do some iterations of Miller-Rabin to increase confidence
@@ -83,6 +126,7 @@ if python_version >= (2, 7) or python_version >= (3, 1):
if no == 0:
return 0
return no.bit_length()
else:
# Slow, but works
def count_bytes(no):
@@ -107,25 +151,27 @@ else:
count += 1
return count
if sys.version_info[0] >= 3:
# Python 3 (and newer)
def _convert_int_to_bytes(count, no):
return no.to_bytes(count, byteorder='big')
return no.to_bytes(count, byteorder="big")
def _convert_bytes_to_int(data):
return int.from_bytes(data, byteorder='big', signed=False)
return int.from_bytes(data, byteorder="big", signed=False)
def _to_hex(no):
return hex(no)[2:]
else:
# Python 2
def _convert_int_to_bytes(count, n):
if n == 0 and count == 0:
return ''
h = '%x' % n
return ""
h = "%x" % n
if len(h) > 2 * count:
raise Exception('Number {1} needs more than {0} bytes!'.format(count, n))
return ('0' * (2 * count - len(h)) + h).decode('hex')
raise Exception("Number {1} needs more than {0} bytes!".format(count, n))
return ("0" * (2 * count - len(h)) + h).decode("hex")
def _convert_bytes_to_int(data):
v = 0
@@ -134,7 +180,7 @@ else:
return v
def _to_hex(no):
return '%x' % no
return "%x" % no
def convert_int_to_bytes(no, count=None):
@@ -164,7 +210,7 @@ def convert_int_to_hex(no, digits=None):
no = abs(no)
value = _to_hex(no)
if digits is not None and len(value) < digits:
value = '0' * (digits - len(value)) + value
value = "0" * (digits - len(value)) + value
return value

View File

@@ -41,13 +41,14 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.6'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.6"
CRYPTOGRAPHY_IMP_ERR = None
CRYPTOGRAPHY_VERSION = None
try:
import cryptography
from cryptography import x509
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -66,21 +67,21 @@ class CertificateBackend(object):
self.module = module
self.backend = backend
self.force = module.params['force']
self.ignore_timestamps = module.params['ignore_timestamps']
self.privatekey_path = module.params['privatekey_path']
self.privatekey_content = module.params['privatekey_content']
self.force = module.params["force"]
self.ignore_timestamps = module.params["ignore_timestamps"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode('utf-8')
self.privatekey_passphrase = module.params['privatekey_passphrase']
self.csr_path = module.params['csr_path']
self.csr_content = module.params['csr_content']
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.csr_path = module.params["csr_path"]
self.csr_content = module.params["csr_content"]
if self.csr_content is not None:
self.csr_content = self.csr_content.encode('utf-8')
self.csr_content = self.csr_content.encode("utf-8")
# The following are default values which make sure check() works as
# before if providers do not explicitly change these properties.
self.create_subject_key_identifier = 'never_create'
self.create_subject_key_identifier = "never_create"
self.create_authority_key_identifier = False
self.privatekey = None
@@ -99,8 +100,10 @@ class CertificateBackend(object):
if data is None:
return dict()
try:
result = get_certificate_info(self.module, self.backend, data, prefer_one_fingerprint=True)
result['can_parse_certificate'] = True
result = get_certificate_info(
self.module, self.backend, data, prefer_one_fingerprint=True
)
result["can_parse_certificate"] = True
return result
except Exception:
return dict(can_parse_certificate=False)
@@ -118,7 +121,9 @@ class CertificateBackend(object):
def set_existing(self, certificate_bytes):
"""Set existing certificate bytes. None indicates that the key does not exist."""
self.existing_certificate_bytes = certificate_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_certificate_bytes)
self.diff_after = self.diff_before = self._get_info(
self.existing_certificate_bytes
)
def has_existing(self):
"""Query whether an existing certificate is/has been there."""
@@ -166,33 +171,60 @@ class CertificateBackend(object):
def _check_privatekey(self):
"""Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated."""
if self.backend == 'cryptography':
return cryptography_compare_public_keys(self.existing_certificate.public_key(), self.privatekey.public_key())
if self.backend == "cryptography":
return cryptography_compare_public_keys(
self.existing_certificate.public_key(), self.privatekey.public_key()
)
def _check_csr(self):
"""Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated."""
if self.backend == 'cryptography':
if self.backend == "cryptography":
# Verify that CSR is signed by certificate's private key
if not self.csr.is_signature_valid:
return False
if not cryptography_compare_public_keys(self.csr.public_key(), self.existing_certificate.public_key()):
if not cryptography_compare_public_keys(
self.csr.public_key(), self.existing_certificate.public_key()
):
return False
# Check subject
if self.check_csr_subject and self.csr.subject != self.existing_certificate.subject:
if (
self.check_csr_subject
and self.csr.subject != self.existing_certificate.subject
):
return False
# Check extensions
if not self.check_csr_extensions:
return True
cert_exts = list(self.existing_certificate.extensions)
csr_exts = list(self.csr.extensions)
if self.create_subject_key_identifier != 'never_create':
if self.create_subject_key_identifier != "never_create":
# Filter out SubjectKeyIdentifier extension before comparison
cert_exts = list(filter(lambda x: not isinstance(x.value, x509.SubjectKeyIdentifier), cert_exts))
csr_exts = list(filter(lambda x: not isinstance(x.value, x509.SubjectKeyIdentifier), csr_exts))
cert_exts = list(
filter(
lambda x: not isinstance(x.value, x509.SubjectKeyIdentifier),
cert_exts,
)
)
csr_exts = list(
filter(
lambda x: not isinstance(x.value, x509.SubjectKeyIdentifier),
csr_exts,
)
)
if self.create_authority_key_identifier:
# Filter out AuthorityKeyIdentifier extension before comparison
cert_exts = list(filter(lambda x: not isinstance(x.value, x509.AuthorityKeyIdentifier), cert_exts))
csr_exts = list(filter(lambda x: not isinstance(x.value, x509.AuthorityKeyIdentifier), csr_exts))
cert_exts = list(
filter(
lambda x: not isinstance(x.value, x509.AuthorityKeyIdentifier),
cert_exts,
)
)
csr_exts = list(
filter(
lambda x: not isinstance(x.value, x509.AuthorityKeyIdentifier),
csr_exts,
)
)
if len(cert_exts) != len(csr_exts):
return False
for cert_ext in cert_exts:
@@ -208,19 +240,28 @@ class CertificateBackend(object):
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate has been populated."""
# Get hold of certificate's SKI
try:
ext = self.existing_certificate.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
ext = self.existing_certificate.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
except cryptography.x509.ExtensionNotFound:
return False
# Get hold of CSR's SKI for 'create_if_not_provided'
csr_ext = None
if self.create_subject_key_identifier == 'create_if_not_provided':
if self.create_subject_key_identifier == "create_if_not_provided":
try:
csr_ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
csr_ext = self.csr.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
except cryptography.x509.ExtensionNotFound:
pass
if csr_ext is None:
# If CSR had no SKI, or we chose to ignore it ('always_create'), compare with created SKI
if ext.value.digest != x509.SubjectKeyIdentifier.from_public_key(self.existing_certificate.public_key()).digest:
if (
ext.value.digest
!= x509.SubjectKeyIdentifier.from_public_key(
self.existing_certificate.public_key()
).digest
):
return False
else:
# If CSR had SKI and we did not ignore it ('create_if_not_provided'), compare SKIs
@@ -249,7 +290,10 @@ class CertificateBackend(object):
return True
# Check SubjectKeyIdentifier
if self.create_subject_key_identifier != 'never_create' and not self._check_subject_key_identifier():
if (
self.create_subject_key_identifier != "never_create"
and not self._check_subject_key_identifier()
):
return True
# Check not before
@@ -265,10 +309,7 @@ class CertificateBackend(object):
def dump(self, include_certificate):
"""Serialize the object into a dictionary."""
result = {
'privatekey': self.privatekey_path,
'csr': self.csr_path
}
result = {"privatekey": self.privatekey_path, "csr": self.csr_path}
# Get hold of certificate bytes
certificate_bytes = self.existing_certificate_bytes
if self.cert is not None:
@@ -276,9 +317,11 @@ class CertificateBackend(object):
self.diff_after = self._get_info(certificate_bytes)
if include_certificate:
# Store result
result['certificate'] = certificate_bytes.decode('utf-8') if certificate_bytes else None
result["certificate"] = (
certificate_bytes.decode("utf-8") if certificate_bytes else None
)
result['diff'] = dict(
result["diff"] = dict(
before=self.diff_before,
after=self.diff_after,
)
@@ -311,26 +354,38 @@ def select_backend(module, backend, provider):
"""
provider.validate_module_args(module)
backend = module.params['select_crypto_backend']
if backend == 'auto':
backend = module.params["select_crypto_backend"]
if backend == "auto":
# Detect what backend we can use
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# If cryptography is available we'll use it
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Fail if no backend has been found
if backend == 'auto':
module.fail_json(msg=("Cannot detect the required Python library "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect the required Python library " "cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == 'cryptography':
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
if provider.needs_version_two_certs(module):
module.fail_json(msg='The cryptography backend does not support v2 certificates')
module.fail_json(
msg="The cryptography backend does not support v2 certificates"
)
return provider.create_backend(module, backend)
@@ -338,20 +393,26 @@ def select_backend(module, backend, provider):
def get_certificate_argument_spec():
return ArgumentSpec(
argument_spec=dict(
provider=dict(type='str', choices=[]), # choices will be filled by add_XXX_provider_to_argument_spec() in certificate_xxx.py
force=dict(type='bool', default=False,),
csr_path=dict(type='path'),
csr_content=dict(type='str'),
ignore_timestamps=dict(type='bool', default=True),
select_crypto_backend=dict(type='str', default='auto', choices=['auto', 'cryptography']),
provider=dict(
type="str", choices=[]
), # choices will be filled by add_XXX_provider_to_argument_spec() in certificate_xxx.py
force=dict(
type="bool",
default=False,
),
csr_path=dict(type="path"),
csr_content=dict(type="str"),
ignore_timestamps=dict(type="bool", default=True),
select_crypto_backend=dict(
type="str", default="auto", choices=["auto", "cryptography"]
),
# General properties of a certificate
privatekey_path=dict(type='path'),
privatekey_content=dict(type='str', no_log=True),
privatekey_passphrase=dict(type='str', no_log=True),
privatekey_path=dict(type="path"),
privatekey_content=dict(type="str", no_log=True),
privatekey_passphrase=dict(type="str", no_log=True),
),
mutually_exclusive=[
['csr_path', 'csr_content'],
['privatekey_path', 'privatekey_content'],
["csr_path", "csr_content"],
["privatekey_path", "privatekey_content"],
],
)

View File

@@ -26,44 +26,44 @@ from ansible_collections.community.crypto.plugins.module_utils.crypto.module_bac
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module, backend):
super(AcmeCertificateBackend, self).__init__(module, backend)
self.accountkey_path = module.params['acme_accountkey_path']
self.challenge_path = module.params['acme_challenge_path']
self.use_chain = module.params['acme_chain']
self.acme_directory = module.params['acme_directory']
self.accountkey_path = module.params["acme_accountkey_path"]
self.challenge_path = module.params["acme_challenge_path"]
self.use_chain = module.params["acme_chain"]
self.acme_directory = module.params["acme_directory"]
if self.csr_content is None and self.csr_path is None:
raise CertificateError(
'csr_path or csr_content is required for ownca provider'
"csr_path or csr_content is required for ownca provider"
)
if self.csr_content is None and not os.path.exists(self.csr_path):
raise CertificateError(
'The certificate signing request file %s does not exist' % self.csr_path
"The certificate signing request file %s does not exist" % self.csr_path
)
if not os.path.exists(self.accountkey_path):
raise CertificateError(
'The account key %s does not exist' % self.accountkey_path
"The account key %s does not exist" % self.accountkey_path
)
if not os.path.exists(self.challenge_path):
raise CertificateError(
'The challenge path %s does not exist' % self.challenge_path
"The challenge path %s does not exist" % self.challenge_path
)
self.acme_tiny_path = self.module.get_bin_path('acme-tiny', required=True)
self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True)
def generate_certificate(self):
"""(Re-)Generate certificate."""
command = [self.acme_tiny_path]
if self.use_chain:
command.append('--chain')
command.extend(['--account-key', self.accountkey_path])
command.append("--chain")
command.extend(["--account-key", self.accountkey_path])
if self.csr_content is not None:
# We need to temporarily write the CSR to disk
fd, tmpsrc = tempfile.mkstemp()
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
f = os.fdopen(fd, 'wb')
f = os.fdopen(fd, "wb")
try:
f.write(self.csr_content)
except Exception as err:
@@ -73,14 +73,14 @@ class AcmeCertificateBackend(CertificateBackend):
pass
self.module.fail_json(
msg="failed to create temporary CSR file: %s" % to_native(err),
exception=traceback.format_exc()
exception=traceback.format_exc(),
)
f.close()
command.extend(['--csr', tmpsrc])
command.extend(["--csr", tmpsrc])
else:
command.extend(['--csr', self.csr_path])
command.extend(['--acme-dir', self.challenge_path])
command.extend(['--directory-url', self.acme_directory])
command.extend(["--csr", self.csr_path])
command.extend(["--acme-dir", self.challenge_path])
command.extend(["--directory-url", self.acme_directory])
try:
self.cert = to_bytes(self.module.run_command(command, check_rc=True)[1])
@@ -93,16 +93,20 @@ class AcmeCertificateBackend(CertificateBackend):
def dump(self, include_certificate):
result = super(AcmeCertificateBackend, self).dump(include_certificate)
result['accountkey'] = self.accountkey_path
result["accountkey"] = self.accountkey_path
return result
class AcmeCertificateProvider(CertificateProvider):
def validate_module_args(self, module):
if module.params['acme_accountkey_path'] is None:
module.fail_json(msg='The acme_accountkey_path option must be specified for the acme provider.')
if module.params['acme_challenge_path'] is None:
module.fail_json(msg='The acme_challenge_path option must be specified for the acme provider.')
if module.params["acme_accountkey_path"] is None:
module.fail_json(
msg="The acme_accountkey_path option must be specified for the acme provider."
)
if module.params["acme_challenge_path"] is None:
module.fail_json(
msg="The acme_challenge_path option must be specified for the acme provider."
)
def needs_version_two_certs(self, module):
return False
@@ -112,10 +116,14 @@ class AcmeCertificateProvider(CertificateProvider):
def add_acme_provider_to_argument_spec(argument_spec):
argument_spec.argument_spec['provider']['choices'].append('acme')
argument_spec.argument_spec.update(dict(
acme_accountkey_path=dict(type='path'),
acme_challenge_path=dict(type='path'),
acme_chain=dict(type='bool', default=False),
acme_directory=dict(type='str', default="https://acme-v02.api.letsencrypt.org/directory"),
))
argument_spec.argument_spec["provider"]["choices"].append("acme")
argument_spec.argument_spec.update(
dict(
acme_accountkey_path=dict(type="path"),
acme_challenge_path=dict(type="path"),
acme_chain=dict(type="bool", default=False),
acme_directory=dict(
type="str", default="https://acme-v02.api.letsencrypt.org/directory"
),
)
)

View File

@@ -50,19 +50,21 @@ class EntrustCertificateBackend(CertificateBackend):
super(EntrustCertificateBackend, self).__init__(module, backend)
self.trackingId = None
self.notAfter = get_relative_time_option(
module.params['entrust_not_after'],
'entrust_not_after',
module.params["entrust_not_after"],
"entrust_not_after",
backend=self.backend,
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
if self.csr_content is None and self.csr_path is None:
raise CertificateError(
'csr_path or csr_content is required for entrust provider'
"csr_path or csr_content is required for entrust provider"
)
if self.csr_content is None and not os.path.exists(self.csr_path):
raise CertificateError(
'The certificate signing request file {0} does not exist'.format(self.csr_path)
"The certificate signing request file {0} does not exist".format(
self.csr_path
)
)
self._ensure_csr_loaded()
@@ -71,28 +73,42 @@ class EntrustCertificateBackend(CertificateBackend):
# We want to always force behavior of trying to use the organization provided in the CSR.
# To that end we need to parse out the organization from the CSR.
self.csr_org = None
if self.backend == 'cryptography':
csr_subject_orgs = self.csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
if self.backend == "cryptography":
csr_subject_orgs = self.csr.subject.get_attributes_for_oid(
NameOID.ORGANIZATION_NAME
)
if len(csr_subject_orgs) == 1:
self.csr_org = csr_subject_orgs[0].value
elif len(csr_subject_orgs) > 1:
self.module.fail_json(msg=("Entrust provider does not currently support multiple validated organizations. Multiple organizations found in "
"Subject DN: '{0}'. ".format(self.csr.subject)))
self.module.fail_json(
msg=(
"Entrust provider does not currently support multiple validated organizations. Multiple organizations found in "
"Subject DN: '{0}'. ".format(self.csr.subject)
)
)
# If no organization in the CSR, explicitly tell ECS that it should be blank in issued cert, not defaulted to
# organization tied to the account.
if self.csr_org is None:
self.csr_org = ''
self.csr_org = ""
try:
self.ecs_client = ECSClient(
entrust_api_user=self.module.params['entrust_api_user'],
entrust_api_key=self.module.params['entrust_api_key'],
entrust_api_cert=self.module.params['entrust_api_client_cert_path'],
entrust_api_cert_key=self.module.params['entrust_api_client_cert_key_path'],
entrust_api_specification_path=self.module.params['entrust_api_specification_path']
entrust_api_user=self.module.params["entrust_api_user"],
entrust_api_key=self.module.params["entrust_api_key"],
entrust_api_cert=self.module.params["entrust_api_client_cert_path"],
entrust_api_cert_key=self.module.params[
"entrust_api_client_cert_key_path"
],
entrust_api_specification_path=self.module.params[
"entrust_api_specification_path"
],
)
except SessionConfigurationException as e:
module.fail_json(msg='Failed to initialize Entrust Provider: {0}'.format(to_native(e.message)))
module.fail_json(
msg="Failed to initialize Entrust Provider: {0}".format(
to_native(e.message)
)
)
def generate_certificate(self):
"""(Re-)Generate certificate."""
@@ -101,12 +117,12 @@ class EntrustCertificateBackend(CertificateBackend):
# Read the CSR that was generated for us
if self.csr_content is not None:
# csr_content contains bytes
body['csr'] = to_native(self.csr_content)
body["csr"] = to_native(self.csr_content)
else:
with open(self.csr_path, 'r') as csr_file:
body['csr'] = csr_file.read()
with open(self.csr_path, "r") as csr_file:
body["csr"] = csr_file.read()
body['certType'] = self.module.params['entrust_cert_type']
body["certType"] = self.module.params["entrust_cert_type"]
# Handle expiration (30 days if not specified)
expiry = self.notAfter
@@ -115,22 +131,28 @@ class EntrustCertificateBackend(CertificateBackend):
expiry = gmt_now + datetime.timedelta(days=365)
expiry_iso3339 = expiry.strftime("%Y-%m-%dT%H:%M:%S.00Z")
body['certExpiryDate'] = expiry_iso3339
body['org'] = self.csr_org
body['tracking'] = {
'requesterName': self.module.params['entrust_requester_name'],
'requesterEmail': self.module.params['entrust_requester_email'],
'requesterPhone': self.module.params['entrust_requester_phone'],
body["certExpiryDate"] = expiry_iso3339
body["org"] = self.csr_org
body["tracking"] = {
"requesterName": self.module.params["entrust_requester_name"],
"requesterEmail": self.module.params["entrust_requester_email"],
"requesterPhone": self.module.params["entrust_requester_phone"],
}
try:
result = self.ecs_client.NewCertRequest(Body=body)
self.trackingId = result.get('trackingId')
self.trackingId = result.get("trackingId")
except RestOperationException as e:
self.module.fail_json(msg='Failed to request new certificate from Entrust Certificate Services (ECS): {0}'.format(to_native(e.message)))
self.module.fail_json(
msg="Failed to request new certificate from Entrust Certificate Services (ECS): {0}".format(
to_native(e.message)
)
)
self.cert_bytes = to_bytes(result.get('endEntityCert'))
self.cert = load_certificate(path=None, content=self.cert_bytes, backend=self.backend)
self.cert_bytes = to_bytes(result.get("endEntityCert"))
self.cert = load_certificate(
path=None, content=self.cert_bytes, backend=self.backend
)
def get_certificate_data(self):
"""Return bytes for self.cert."""
@@ -142,15 +164,23 @@ class EntrustCertificateBackend(CertificateBackend):
try:
cert_details = self._get_cert_details()
except RestOperationException as e:
self.module.fail_json(msg='Failed to get status of existing certificate from Entrust Certificate Services (ECS): {0}.'.format(to_native(e.message)))
self.module.fail_json(
msg="Failed to get status of existing certificate from Entrust Certificate Services (ECS): {0}.".format(
to_native(e.message)
)
)
# Always issue a new certificate if the certificate is expired, suspended or revoked
status = cert_details.get('status', False)
if status == 'EXPIRED' or status == 'SUSPENDED' or status == 'REVOKED':
status = cert_details.get("status", False)
if status == "EXPIRED" or status == "SUSPENDED" or status == "REVOKED":
return True
# If the requested cert type was specified and it is for a different certificate type than the initial certificate, a new one is needed
if self.module.params['entrust_cert_type'] and cert_details.get('certType') and self.module.params['entrust_cert_type'] != cert_details.get('certType'):
if (
self.module.params["entrust_cert_type"]
and cert_details.get("certType")
and self.module.params["entrust_cert_type"] != cert_details.get("certType")
):
return True
return parent_check
@@ -164,27 +194,33 @@ class EntrustCertificateBackend(CertificateBackend):
if self.existing_certificate:
serial_number = None
expiry = None
if self.backend == 'cryptography':
serial_number = "{0:X}".format(cryptography_serial_number_of_cert(self.existing_certificate))
if self.backend == "cryptography":
serial_number = "{0:X}".format(
cryptography_serial_number_of_cert(self.existing_certificate)
)
expiry = get_not_valid_after(self.existing_certificate)
# get some information about the expiry of this certificate
expiry_iso3339 = expiry.strftime("%Y-%m-%dT%H:%M:%S.00Z")
cert_details['expiresAfter'] = expiry_iso3339
cert_details["expiresAfter"] = expiry_iso3339
# If a trackingId is not already defined (from the result of a generate)
# use the serial number to identify the tracking Id
if self.trackingId is None and serial_number is not None:
cert_results = self.ecs_client.GetCertificates(serialNumber=serial_number).get('certificates', {})
cert_results = self.ecs_client.GetCertificates(
serialNumber=serial_number
).get("certificates", {})
# Finding 0 or more than 1 result is a very unlikely use case, it simply means we cannot perform additional checks
# on the 'state' as returned by Entrust Certificate Services (ECS). The general certificate validity is
# still checked as it is in the rest of the module.
if len(cert_results) == 1:
self.trackingId = cert_results[0].get('trackingId')
self.trackingId = cert_results[0].get("trackingId")
if self.trackingId is not None:
cert_details.update(self.ecs_client.GetCertificate(trackingId=self.trackingId))
cert_details.update(
self.ecs_client.GetCertificate(trackingId=self.trackingId)
)
return cert_details
@@ -201,23 +237,51 @@ class EntrustCertificateProvider(CertificateProvider):
def add_entrust_provider_to_argument_spec(argument_spec):
argument_spec.argument_spec['provider']['choices'].append('entrust')
argument_spec.argument_spec.update(dict(
entrust_cert_type=dict(type='str', default='STANDARD_SSL',
choices=['STANDARD_SSL', 'ADVANTAGE_SSL', 'UC_SSL', 'EV_SSL', 'WILDCARD_SSL',
'PRIVATE_SSL', 'PD_SSL', 'CDS_ENT_LITE', 'CDS_ENT_PRO', 'SMIME_ENT']),
entrust_requester_email=dict(type='str'),
entrust_requester_name=dict(type='str'),
entrust_requester_phone=dict(type='str'),
entrust_api_user=dict(type='str'),
entrust_api_key=dict(type='str', no_log=True),
entrust_api_client_cert_path=dict(type='path'),
entrust_api_client_cert_key_path=dict(type='path', no_log=True),
entrust_api_specification_path=dict(type='path', default='https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml'),
entrust_not_after=dict(type='str', default='+365d'),
))
argument_spec.required_if.append(
['provider', 'entrust', ['entrust_requester_email', 'entrust_requester_name', 'entrust_requester_phone',
'entrust_api_user', 'entrust_api_key', 'entrust_api_client_cert_path',
'entrust_api_client_cert_key_path']]
argument_spec.argument_spec["provider"]["choices"].append("entrust")
argument_spec.argument_spec.update(
dict(
entrust_cert_type=dict(
type="str",
default="STANDARD_SSL",
choices=[
"STANDARD_SSL",
"ADVANTAGE_SSL",
"UC_SSL",
"EV_SSL",
"WILDCARD_SSL",
"PRIVATE_SSL",
"PD_SSL",
"CDS_ENT_LITE",
"CDS_ENT_PRO",
"SMIME_ENT",
],
),
entrust_requester_email=dict(type="str"),
entrust_requester_name=dict(type="str"),
entrust_requester_phone=dict(type="str"),
entrust_api_user=dict(type="str"),
entrust_api_key=dict(type="str", no_log=True),
entrust_api_client_cert_path=dict(type="path"),
entrust_api_client_cert_key_path=dict(type="path", no_log=True),
entrust_api_specification_path=dict(
type="path",
default="https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml",
),
entrust_not_after=dict(type="str", default="+365d"),
)
)
argument_spec.required_if.append(
[
"provider",
"entrust",
[
"entrust_requester_email",
"entrust_requester_name",
"entrust_requester_phone",
"entrust_api_user",
"entrust_api_key",
"entrust_api_client_cert_path",
"entrust_api_client_cert_key_path",
],
]
)

View File

@@ -43,13 +43,14 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.6'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.6"
CRYPTOGRAPHY_IMP_ERR = None
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives import serialization
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -151,75 +152,97 @@ class CertificateInfoRetrieval(object):
def get_info(self, prefer_one_fingerprint=False, der_support_enabled=False):
result = dict()
self.cert = load_certificate(None, content=self.content, backend=self.backend, der_support_enabled=der_support_enabled)
self.cert = load_certificate(
None,
content=self.content,
backend=self.backend,
der_support_enabled=der_support_enabled,
)
result['signature_algorithm'] = self._get_signature_algorithm()
result["signature_algorithm"] = self._get_signature_algorithm()
subject = self._get_subject_ordered()
issuer = self._get_issuer_ordered()
result['subject'] = dict()
result["subject"] = dict()
for k, v in subject:
result['subject'][k] = v
result['subject_ordered'] = subject
result['issuer'] = dict()
result["subject"][k] = v
result["subject_ordered"] = subject
result["issuer"] = dict()
for k, v in issuer:
result['issuer'][k] = v
result['issuer_ordered'] = issuer
result['version'] = self._get_version()
result['key_usage'], result['key_usage_critical'] = self._get_key_usage()
result['extended_key_usage'], result['extended_key_usage_critical'] = self._get_extended_key_usage()
result['basic_constraints'], result['basic_constraints_critical'] = self._get_basic_constraints()
result['ocsp_must_staple'], result['ocsp_must_staple_critical'] = self._get_ocsp_must_staple()
result['subject_alt_name'], result['subject_alt_name_critical'] = self._get_subject_alt_name()
result["issuer"][k] = v
result["issuer_ordered"] = issuer
result["version"] = self._get_version()
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
not_before = self.get_not_before()
not_after = self.get_not_after()
result['not_before'] = not_before.strftime(TIMESTAMP_FORMAT)
result['not_after'] = not_after.strftime(TIMESTAMP_FORMAT)
result['expired'] = not_after < get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
result["not_before"] = not_before.strftime(TIMESTAMP_FORMAT)
result["not_after"] = not_after.strftime(TIMESTAMP_FORMAT)
result["expired"] = not_after < get_now_datetime(
with_timezone=CRYPTOGRAPHY_TIMEZONE
)
result['public_key'] = to_native(self._get_public_key_pem())
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
self.backend,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint)
result.update({
'public_key_type': public_key_info['type'],
'public_key_data': public_key_info['public_data'],
'public_key_fingerprints': public_key_info['fingerprints'],
})
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
result['fingerprints'] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint)
result["fingerprints"] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski = self._get_subject_key_identifier()
if ski is not None:
ski = to_native(binascii.hexlify(ski))
ski = ':'.join([ski[i:i + 2] for i in range(0, len(ski), 2)])
result['subject_key_identifier'] = ski
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski
aki, aci, acsn = self._get_authority_key_identifier()
if aki is not None:
aki = to_native(binascii.hexlify(aki))
aki = ':'.join([aki[i:i + 2] for i in range(0, len(aki), 2)])
result['authority_key_identifier'] = aki
result['authority_cert_issuer'] = aci
result['authority_cert_serial_number'] = acsn
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result['serial_number'] = self._get_serial_number()
result['extensions_by_oid'] = self._get_all_extensions()
result['ocsp_uri'] = self._get_ocsp_uri()
result['issuer_uri'] = self._get_issuer_uri()
result["serial_number"] = self._get_serial_number()
result["extensions_by_oid"] = self._get_all_extensions()
result["ocsp_uri"] = self._get_ocsp_uri()
result["issuer_uri"] = self._get_issuer_uri()
return result
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module, content):
super(CertificateInfoRetrievalCryptography, self).__init__(module, 'cryptography', content)
self.name_encoding = module.params.get('name_encoding', 'ignore')
super(CertificateInfoRetrievalCryptography, self).__init__(
module, "cryptography", content
)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self):
return self.cert.public_bytes(serialization.Encoding.DER)
@@ -248,7 +271,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def _get_key_usage(self):
try:
current_key_ext = self.cert.extensions.get_extension_for_class(x509.KeyUsage)
current_key_ext = self.cert.extensions.get_extension_for_class(
x509.KeyUsage
)
current_key_usage = current_key_ext.value
key_usage = dict(
digital_signature=current_key_usage.digital_signature,
@@ -261,45 +286,63 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
encipher_only=False,
decipher_only=False,
)
if key_usage['key_agreement']:
key_usage.update(dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only
))
if key_usage["key_agreement"]:
key_usage.update(
dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only,
)
)
key_usage_names = dict(
digital_signature='Digital Signature',
content_commitment='Non Repudiation',
key_encipherment='Key Encipherment',
data_encipherment='Data Encipherment',
key_agreement='Key Agreement',
key_cert_sign='Certificate Sign',
crl_sign='CRL Sign',
encipher_only='Encipher Only',
decipher_only='Decipher Only',
digital_signature="Digital Signature",
content_commitment="Non Repudiation",
key_encipherment="Key Encipherment",
data_encipherment="Data Encipherment",
key_agreement="Key Agreement",
key_cert_sign="Certificate Sign",
crl_sign="CRL Sign",
encipher_only="Encipher Only",
decipher_only="Decipher Only",
)
return (
sorted(
[
key_usage_names[name]
for name, value in key_usage.items()
if value
]
),
current_key_ext.critical,
)
return sorted([
key_usage_names[name] for name, value in key_usage.items() if value
]), current_key_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_extended_key_usage(self):
try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage)
return sorted([
cryptography_oid_to_name(eku) for eku in ext_keyusage_ext.value
]), ext_keyusage_ext.critical
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.ExtendedKeyUsage
)
return (
sorted(
[cryptography_oid_to_name(eku) for eku in ext_keyusage_ext.value]
),
ext_keyusage_ext.critical,
)
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_basic_constraints(self):
try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(x509.BasicConstraints)
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.BasicConstraints
)
result = []
result.append('CA:{0}'.format('TRUE' if ext_keyusage_ext.value.ca else 'FALSE'))
result.append(
"CA:{0}".format("TRUE" if ext_keyusage_ext.value.ca else "FALSE")
)
if ext_keyusage_ext.value.path_length is not None:
result.append('pathlen:{0}'.format(ext_keyusage_ext.value.path_length))
result.append("pathlen:{0}".format(ext_keyusage_ext.value.path_length))
return sorted(result), ext_keyusage_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
@@ -308,8 +351,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
try:
try:
# This only works with cryptography >= 2.1
tlsfeature_ext = self.cert.extensions.get_extension_for_class(x509.TLSFeature)
value = cryptography.x509.TLSFeatureType.status_request in tlsfeature_ext.value
tlsfeature_ext = self.cert.extensions.get_extension_for_class(
x509.TLSFeature
)
value = (
cryptography.x509.TLSFeatureType.status_request
in tlsfeature_ext.value
)
except AttributeError:
# Fallback for cryptography < 2.1
oid = x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.1.24")
@@ -321,8 +369,13 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def _get_subject_alt_name(self):
try:
san_ext = self.cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
result = [cryptography_decode_name(san, idn_rewrite=self.name_encoding) for san in san_ext.value]
san_ext = self.cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
result = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in san_ext.value
]
return result, san_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
@@ -344,18 +397,29 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def _get_subject_key_identifier(self):
try:
ext = self.cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
ext = self.cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
return ext.value.digest
except cryptography.x509.ExtensionNotFound:
return None
def _get_authority_key_identifier(self):
try:
ext = self.cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier)
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
issuer = None
if ext.value.authority_cert_issuer is not None:
issuer = [cryptography_decode_name(san, idn_rewrite=self.name_encoding) for san in ext.value.authority_cert_issuer]
return ext.value.key_identifier, issuer, ext.value.authority_cert_serial_number
issuer = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in ext.value.authority_cert_issuer
]
return (
ext.value.key_identifier,
issuer,
ext.value.authority_cert_serial_number,
)
except cryptography.x509.ExtensionNotFound:
return None, None, None
@@ -367,7 +431,9 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def _get_ocsp_uri(self):
try:
ext = self.cert.extensions.get_extension_for_class(x509.AuthorityInformationAccess)
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess
)
for desc in ext.value:
if desc.access_method == x509.oid.AuthorityInformationAccessOID.OCSP:
if isinstance(desc.access_location, x509.UniformResourceIdentifier):
@@ -378,9 +444,14 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def _get_issuer_uri(self):
try:
ext = self.cert.extensions.get_extension_for_class(x509.AuthorityInformationAccess)
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess
)
for desc in ext.value:
if desc.access_method == x509.oid.AuthorityInformationAccessOID.CA_ISSUERS:
if (
desc.access_method
== x509.oid.AuthorityInformationAccessOID.CA_ISSUERS
):
if isinstance(desc.access_location, x509.UniformResourceIdentifier):
return desc.access_location.value
except x509.ExtensionNotFound:
@@ -389,29 +460,40 @@ class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
def get_certificate_info(module, backend, content, prefer_one_fingerprint=False):
if backend == 'cryptography':
if backend == "cryptography":
info = CertificateInfoRetrievalCryptography(module, content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, backend, content):
if backend == 'auto':
if backend == "auto":
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# Try cryptography
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Success?
if backend == 'auto':
module.fail_json(msg=("Cannot detect any of the required Python libraries "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect any of the required Python libraries "
"cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == 'cryptography':
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return backend, CertificateInfoRetrievalCryptography(module, content)
else:
raise ValueError('Unsupported value for backend: {0}'.format(backend))
raise ValueError("Unsupported value for backend: {0}".format(backend))

View File

@@ -58,75 +58,90 @@ except ImportError:
class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module):
super(OwnCACertificateBackendCryptography, self).__init__(module, 'cryptography')
super(OwnCACertificateBackendCryptography, self).__init__(
module, "cryptography"
)
self.create_subject_key_identifier = module.params['ownca_create_subject_key_identifier']
self.create_authority_key_identifier = module.params['ownca_create_authority_key_identifier']
self.create_subject_key_identifier = module.params[
"ownca_create_subject_key_identifier"
]
self.create_authority_key_identifier = module.params[
"ownca_create_authority_key_identifier"
]
self.notBefore = get_relative_time_option(
module.params['ownca_not_before'],
'ownca_not_before',
module.params["ownca_not_before"],
"ownca_not_before",
backend=self.backend,
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params['ownca_not_after'],
'ownca_not_after',
module.params["ownca_not_after"],
"ownca_not_after",
backend=self.backend,
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params['ownca_digest'])
self.version = module.params['ownca_version']
self.digest = select_message_digest(module.params["ownca_digest"])
self.version = module.params["ownca_version"]
self.serial_number = x509.random_serial_number()
self.ca_cert_path = module.params['ownca_path']
self.ca_cert_content = module.params['ownca_content']
self.ca_cert_path = module.params["ownca_path"]
self.ca_cert_content = module.params["ownca_content"]
if self.ca_cert_content is not None:
self.ca_cert_content = self.ca_cert_content.encode('utf-8')
self.ca_privatekey_path = module.params['ownca_privatekey_path']
self.ca_privatekey_content = module.params['ownca_privatekey_content']
self.ca_cert_content = self.ca_cert_content.encode("utf-8")
self.ca_privatekey_path = module.params["ownca_privatekey_path"]
self.ca_privatekey_content = module.params["ownca_privatekey_content"]
if self.ca_privatekey_content is not None:
self.ca_privatekey_content = self.ca_privatekey_content.encode('utf-8')
self.ca_privatekey_passphrase = module.params['ownca_privatekey_passphrase']
self.ca_privatekey_content = self.ca_privatekey_content.encode("utf-8")
self.ca_privatekey_passphrase = module.params["ownca_privatekey_passphrase"]
if self.csr_content is None and self.csr_path is None:
raise CertificateError(
'csr_path or csr_content is required for ownca provider'
"csr_path or csr_content is required for ownca provider"
)
if self.csr_content is None and not os.path.exists(self.csr_path):
raise CertificateError(
'The certificate signing request file {0} does not exist'.format(self.csr_path)
"The certificate signing request file {0} does not exist".format(
self.csr_path
)
)
if self.ca_cert_content is None and not os.path.exists(self.ca_cert_path):
raise CertificateError(
'The CA certificate file {0} does not exist'.format(self.ca_cert_path)
"The CA certificate file {0} does not exist".format(self.ca_cert_path)
)
if self.ca_privatekey_content is None and not os.path.exists(self.ca_privatekey_path):
if self.ca_privatekey_content is None and not os.path.exists(
self.ca_privatekey_path
):
raise CertificateError(
'The CA private key file {0} does not exist'.format(self.ca_privatekey_path)
"The CA private key file {0} does not exist".format(
self.ca_privatekey_path
)
)
self._ensure_csr_loaded()
self.ca_cert = load_certificate(
path=self.ca_cert_path,
content=self.ca_cert_content,
backend=self.backend
path=self.ca_cert_path, content=self.ca_cert_content, backend=self.backend
)
try:
self.ca_private_key = load_privatekey(
path=self.ca_privatekey_path,
content=self.ca_privatekey_content,
passphrase=self.ca_privatekey_passphrase,
backend=self.backend
backend=self.backend,
)
except OpenSSLBadPassphraseError as exc:
module.fail_json(msg=str(exc))
if not cryptography_compare_public_keys(self.ca_cert.public_key(), self.ca_private_key.public_key()):
raise CertificateError('The CA private key does not belong to the CA certificate')
if not cryptography_compare_public_keys(
self.ca_cert.public_key(), self.ca_private_key.public_key()
):
raise CertificateError(
"The CA private key does not belong to the CA certificate"
)
if cryptography_key_needs_digest_for_signing(self.ca_private_key):
if self.digest is None:
raise CertificateError(
'The digest %s is not supported with the cryptography backend' % module.params['ownca_digest']
"The digest %s is not supported with the cryptography backend"
% module.params["ownca_digest"]
)
else:
self.digest = None
@@ -143,40 +158,60 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
has_ski = False
for extension in self.csr.extensions:
if isinstance(extension.value, x509.SubjectKeyIdentifier):
if self.create_subject_key_identifier == 'always_create':
if self.create_subject_key_identifier == "always_create":
continue
has_ski = True
if self.create_authority_key_identifier and isinstance(extension.value, x509.AuthorityKeyIdentifier):
if self.create_authority_key_identifier and isinstance(
extension.value, x509.AuthorityKeyIdentifier
):
continue
cert_builder = cert_builder.add_extension(extension.value, critical=extension.critical)
if not has_ski and self.create_subject_key_identifier != 'never_create':
cert_builder = cert_builder.add_extension(
extension.value, critical=extension.critical
)
if not has_ski and self.create_subject_key_identifier != "never_create":
cert_builder = cert_builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(self.csr.public_key()),
critical=False
critical=False,
)
if self.create_authority_key_identifier:
try:
ext = self.ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
ext = self.ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
cert_builder = cert_builder.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ext.value)
if CRYPTOGRAPHY_VERSION >= LooseVersion('2.7') else
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ext),
critical=False
(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext.value
)
if CRYPTOGRAPHY_VERSION >= LooseVersion("2.7")
else x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext
)
),
critical=False,
)
except cryptography.x509.ExtensionNotFound:
cert_builder = cert_builder.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key(self.ca_cert.public_key()),
critical=False
x509.AuthorityKeyIdentifier.from_issuer_public_key(
self.ca_cert.public_key()
),
critical=False,
)
try:
certificate = cert_builder.sign(
private_key=self.ca_private_key, algorithm=self.digest,
backend=default_backend()
private_key=self.ca_private_key,
algorithm=self.digest,
backend=default_backend(),
)
except TypeError as e:
if str(e) == 'Algorithm must be a registered hash algorithm.' and self.digest is None:
self.module.fail_json(msg='Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer.')
if (
str(e) == "Algorithm must be a registered hash algorithm."
and self.digest is None
):
self.module.fail_json(
msg="Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer."
)
raise
self.cert = certificate
@@ -186,13 +221,17 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(self):
if super(OwnCACertificateBackendCryptography, self).needs_regeneration(not_before=self.notBefore, not_after=self.notAfter):
if super(OwnCACertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter
):
return True
self._ensure_existing_certificate_loaded()
# Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature(self.existing_certificate, self.ca_cert.public_key()):
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.ca_cert.public_key()
):
return True
# Check subject
@@ -202,17 +241,27 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
# Check AuthorityKeyIdentifier
if self.create_authority_key_identifier:
try:
ext = self.ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
ext = self.ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
expected_ext = (
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ext.value)
if CRYPTOGRAPHY_VERSION >= LooseVersion('2.7') else
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ext)
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext.value
)
if CRYPTOGRAPHY_VERSION >= LooseVersion("2.7")
else x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext
)
)
except cryptography.x509.ExtensionNotFound:
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(self.ca_cert.public_key())
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
self.ca_cert.public_key()
)
try:
ext = self.existing_certificate.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier)
ext = self.existing_certificate.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
if ext.value != expected_ext:
return True
except cryptography.x509.ExtensionNotFound:
@@ -221,26 +270,38 @@ class OwnCACertificateBackendCryptography(CertificateBackend):
return False
def dump(self, include_certificate):
result = super(OwnCACertificateBackendCryptography, self).dump(include_certificate)
result.update({
'ca_cert': self.ca_cert_path,
'ca_privatekey': self.ca_privatekey_path,
})
result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate
)
result.update(
{
"ca_cert": self.ca_cert_path,
"ca_privatekey": self.ca_privatekey_path,
}
)
if self.module.check_mode:
result.update({
'notBefore': self.notBefore.strftime("%Y%m%d%H%M%SZ"),
'notAfter': self.notAfter.strftime("%Y%m%d%H%M%SZ"),
'serial_number': self.serial_number,
})
result.update(
{
"notBefore": self.notBefore.strftime("%Y%m%d%H%M%SZ"),
"notAfter": self.notAfter.strftime("%Y%m%d%H%M%SZ"),
"serial_number": self.serial_number,
}
)
else:
if self.cert is None:
self.cert = self.existing_certificate
result.update({
'notBefore': get_not_valid_before(self.cert).strftime("%Y%m%d%H%M%SZ"),
'notAfter': get_not_valid_after(self.cert).strftime("%Y%m%d%H%M%SZ"),
'serial_number': cryptography_serial_number_of_cert(self.cert),
})
result.update(
{
"notBefore": get_not_valid_before(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"notAfter": get_not_valid_after(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"serial_number": cryptography_serial_number_of_cert(self.cert),
}
)
return result
@@ -255,39 +316,53 @@ def generate_serial_number():
class OwnCACertificateProvider(CertificateProvider):
def validate_module_args(self, module):
if module.params['ownca_path'] is None and module.params['ownca_content'] is None:
module.fail_json(msg='One of ownca_path and ownca_content must be specified for the ownca provider.')
if module.params['ownca_privatekey_path'] is None and module.params['ownca_privatekey_content'] is None:
module.fail_json(msg='One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider.')
if (
module.params["ownca_path"] is None
and module.params["ownca_content"] is None
):
module.fail_json(
msg="One of ownca_path and ownca_content must be specified for the ownca provider."
)
if (
module.params["ownca_privatekey_path"] is None
and module.params["ownca_privatekey_content"] is None
):
module.fail_json(
msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider."
)
def needs_version_two_certs(self, module):
return module.params['ownca_version'] == 2
return module.params["ownca_version"] == 2
def create_backend(self, module, backend):
if backend == 'cryptography':
if backend == "cryptography":
return OwnCACertificateBackendCryptography(module)
def add_ownca_provider_to_argument_spec(argument_spec):
argument_spec.argument_spec['provider']['choices'].append('ownca')
argument_spec.argument_spec.update(dict(
ownca_path=dict(type='path'),
ownca_content=dict(type='str'),
ownca_privatekey_path=dict(type='path'),
ownca_privatekey_content=dict(type='str', no_log=True),
ownca_privatekey_passphrase=dict(type='str', no_log=True),
ownca_digest=dict(type='str', default='sha256'),
ownca_version=dict(type='int', default=3),
ownca_not_before=dict(type='str', default='+0s'),
ownca_not_after=dict(type='str', default='+3650d'),
ownca_create_subject_key_identifier=dict(
type='str',
default='create_if_not_provided',
choices=['create_if_not_provided', 'always_create', 'never_create']
),
ownca_create_authority_key_identifier=dict(type='bool', default=True),
))
argument_spec.mutually_exclusive.extend([
['ownca_path', 'ownca_content'],
['ownca_privatekey_path', 'ownca_privatekey_content'],
])
argument_spec.argument_spec["provider"]["choices"].append("ownca")
argument_spec.argument_spec.update(
dict(
ownca_path=dict(type="path"),
ownca_content=dict(type="str"),
ownca_privatekey_path=dict(type="path"),
ownca_privatekey_content=dict(type="str", no_log=True),
ownca_privatekey_passphrase=dict(type="str", no_log=True),
ownca_digest=dict(type="str", default="sha256"),
ownca_version=dict(type="int", default=3),
ownca_not_before=dict(type="str", default="+0s"),
ownca_not_after=dict(type="str", default="+3650d"),
ownca_create_subject_key_identifier=dict(
type="str",
default="create_if_not_provided",
choices=["create_if_not_provided", "always_create", "never_create"],
),
ownca_create_authority_key_identifier=dict(type="bool", default=True),
)
)
argument_spec.mutually_exclusive.extend(
[
["ownca_path", "ownca_content"],
["ownca_privatekey_path", "ownca_privatekey_content"],
]
)

View File

@@ -48,32 +48,38 @@ except ImportError:
class SelfSignedCertificateBackendCryptography(CertificateBackend):
def __init__(self, module):
super(SelfSignedCertificateBackendCryptography, self).__init__(module, 'cryptography')
super(SelfSignedCertificateBackendCryptography, self).__init__(
module, "cryptography"
)
self.create_subject_key_identifier = module.params['selfsigned_create_subject_key_identifier']
self.create_subject_key_identifier = module.params[
"selfsigned_create_subject_key_identifier"
]
self.notBefore = get_relative_time_option(
module.params['selfsigned_not_before'],
'selfsigned_not_before',
module.params["selfsigned_not_before"],
"selfsigned_not_before",
backend=self.backend,
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params['selfsigned_not_after'],
'selfsigned_not_after',
module.params["selfsigned_not_after"],
"selfsigned_not_after",
backend=self.backend,
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params['selfsigned_digest'])
self.version = module.params['selfsigned_version']
self.digest = select_message_digest(module.params["selfsigned_digest"])
self.version = module.params["selfsigned_version"]
self.serial_number = x509.random_serial_number()
if self.csr_path is not None and not os.path.exists(self.csr_path):
raise CertificateError(
'The certificate signing request file {0} does not exist'.format(self.csr_path)
"The certificate signing request file {0} does not exist".format(
self.csr_path
)
)
if self.privatekey_content is None and not os.path.exists(self.privatekey_path):
raise CertificateError(
'The private key file {0} does not exist'.format(self.privatekey_path)
"The private key file {0} does not exist".format(self.privatekey_path)
)
self._module = module
@@ -89,18 +95,28 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = self.digest
if digest is None:
self.module.fail_json(msg='Unsupported digest "{0}"'.format(module.params['selfsigned_digest']))
self.module.fail_json(
msg='Unsupported digest "{0}"'.format(
module.params["selfsigned_digest"]
)
)
try:
self.csr = csr.sign(self.privatekey, digest, default_backend())
except TypeError as e:
if str(e) == 'Algorithm must be a registered hash algorithm.' and digest is None:
self.module.fail_json(msg='Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer.')
if (
str(e) == "Algorithm must be a registered hash algorithm."
and digest is None
):
self.module.fail_json(
msg="Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer."
)
raise
if cryptography_key_needs_digest_for_signing(self.privatekey):
if self.digest is None:
raise CertificateError(
'The digest %s is not supported with the cryptography backend' % module.params['selfsigned_digest']
"The digest %s is not supported with the cryptography backend"
% module.params["selfsigned_digest"]
)
else:
self.digest = None
@@ -118,26 +134,36 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
has_ski = False
for extension in self.csr.extensions:
if isinstance(extension.value, x509.SubjectKeyIdentifier):
if self.create_subject_key_identifier == 'always_create':
if self.create_subject_key_identifier == "always_create":
continue
has_ski = True
cert_builder = cert_builder.add_extension(extension.value, critical=extension.critical)
if not has_ski and self.create_subject_key_identifier != 'never_create':
cert_builder = cert_builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(self.privatekey.public_key()),
critical=False
extension.value, critical=extension.critical
)
if not has_ski and self.create_subject_key_identifier != "never_create":
cert_builder = cert_builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
),
critical=False,
)
except ValueError as e:
raise CertificateError(str(e))
try:
certificate = cert_builder.sign(
private_key=self.privatekey, algorithm=self.digest,
backend=default_backend()
private_key=self.privatekey,
algorithm=self.digest,
backend=default_backend(),
)
except TypeError as e:
if str(e) == 'Algorithm must be a registered hash algorithm.' and self.digest is None:
self.module.fail_json(msg='Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer.')
if (
str(e) == "Algorithm must be a registered hash algorithm."
and self.digest is None
):
self.module.fail_json(
msg="Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer."
)
raise
self.cert = certificate
@@ -147,34 +173,48 @@ class SelfSignedCertificateBackendCryptography(CertificateBackend):
return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(self):
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(not_before=self.notBefore, not_after=self.notAfter):
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter
):
return True
self._ensure_existing_certificate_loaded()
# Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature(self.existing_certificate, self.privatekey.public_key()):
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.privatekey.public_key()
):
return True
return False
def dump(self, include_certificate):
result = super(SelfSignedCertificateBackendCryptography, self).dump(include_certificate)
result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate
)
if self.module.check_mode:
result.update({
'notBefore': self.notBefore.strftime("%Y%m%d%H%M%SZ"),
'notAfter': self.notAfter.strftime("%Y%m%d%H%M%SZ"),
'serial_number': self.serial_number,
})
result.update(
{
"notBefore": self.notBefore.strftime("%Y%m%d%H%M%SZ"),
"notAfter": self.notAfter.strftime("%Y%m%d%H%M%SZ"),
"serial_number": self.serial_number,
}
)
else:
if self.cert is None:
self.cert = self.existing_certificate
result.update({
'notBefore': get_not_valid_before(self.cert).strftime("%Y%m%d%H%M%SZ"),
'notAfter': get_not_valid_after(self.cert).strftime("%Y%m%d%H%M%SZ"),
'serial_number': cryptography_serial_number_of_cert(self.cert),
})
result.update(
{
"notBefore": get_not_valid_before(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"notAfter": get_not_valid_after(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"serial_number": cryptography_serial_number_of_cert(self.cert),
}
)
return result
@@ -189,27 +229,38 @@ def generate_serial_number():
class SelfSignedCertificateProvider(CertificateProvider):
def validate_module_args(self, module):
if module.params['privatekey_path'] is None and module.params['privatekey_content'] is None:
module.fail_json(msg='One of privatekey_path and privatekey_content must be specified for the selfsigned provider.')
if (
module.params["privatekey_path"] is None
and module.params["privatekey_content"] is None
):
module.fail_json(
msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider."
)
def needs_version_two_certs(self, module):
return module.params['selfsigned_version'] == 2
return module.params["selfsigned_version"] == 2
def create_backend(self, module, backend):
if backend == 'cryptography':
if backend == "cryptography":
return SelfSignedCertificateBackendCryptography(module)
def add_selfsigned_provider_to_argument_spec(argument_spec):
argument_spec.argument_spec['provider']['choices'].append('selfsigned')
argument_spec.argument_spec.update(dict(
selfsigned_version=dict(type='int', default=3),
selfsigned_digest=dict(type='str', default='sha256'),
selfsigned_not_before=dict(type='str', default='+0s', aliases=['selfsigned_notBefore']),
selfsigned_not_after=dict(type='str', default='+3650d', aliases=['selfsigned_notAfter']),
selfsigned_create_subject_key_identifier=dict(
type='str',
default='create_if_not_provided',
choices=['create_if_not_provided', 'always_create', 'never_create']
),
))
argument_spec.argument_spec["provider"]["choices"].append("selfsigned")
argument_spec.argument_spec.update(
dict(
selfsigned_version=dict(type="int", default=3),
selfsigned_digest=dict(type="str", default="sha256"),
selfsigned_not_before=dict(
type="str", default="+0s", aliases=["selfsigned_notBefore"]
),
selfsigned_not_after=dict(
type="str", default="+3650d", aliases=["selfsigned_notAfter"]
),
selfsigned_create_subject_key_identifier=dict(
type="str",
default="create_if_not_provided",
choices=["create_if_not_provided", "always_create", "never_create"],
),
)
)

View File

@@ -18,14 +18,16 @@ from ansible_collections.community.crypto.plugins.module_utils.argspec import (
class ArgumentSpec(_ArgumentSpec):
def create_ansible_module_helper(self, clazz, args, **kwargs):
result = super(ArgumentSpec, self).create_ansible_module_helper(clazz, args, **kwargs)
result = super(ArgumentSpec, self).create_ansible_module_helper(
clazz, args, **kwargs
)
result.deprecate(
"The crypto.module_backends.common module utils is deprecated and will be removed from community.crypto 3.0.0."
" Use the argspec module utils from community.crypto instead.",
version='3.0.0',
collection_name='community.crypto',
version="3.0.0",
collection_name="community.crypto",
)
return result
__all__ = ('AnsibleModule', 'ArgumentSpec')
__all__ = ("AnsibleModule", "ArgumentSpec")

View File

@@ -32,13 +32,14 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
# crypto_utils
MINIMAL_CRYPTOGRAPHY_VERSION = '1.2'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.2"
CRYPTOGRAPHY_IMP_ERR = None
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.backends import default_backend
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -53,7 +54,7 @@ class CRLInfoRetrieval(object):
self.module = module
self.content = content
self.list_revoked_certificates = list_revoked_certificates
self.name_encoding = module.params.get('name_encoding', 'ignore')
self.name_encoding = module.params.get("name_encoding", "ignore")
def get_info(self):
self.crl_pem = identify_pem_format(self.content)
@@ -63,41 +64,51 @@ class CRLInfoRetrieval(object):
else:
self.crl = x509.load_der_x509_crl(self.content, default_backend())
except ValueError as e:
self.module.fail_json(msg='Error while decoding CRL: {0}'.format(e))
self.module.fail_json(msg="Error while decoding CRL: {0}".format(e))
result = {
'changed': False,
'format': 'pem' if self.crl_pem else 'der',
'last_update': None,
'next_update': None,
'digest': None,
'issuer_ordered': None,
'issuer': None,
"changed": False,
"format": "pem" if self.crl_pem else "der",
"last_update": None,
"next_update": None,
"digest": None,
"issuer_ordered": None,
"issuer": None,
}
result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
result['digest'] = cryptography_oid_to_name(cryptography_get_signature_algorithm_oid_from_crl(self.crl))
result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = self.crl.next_update.strftime(TIMESTAMP_FORMAT)
result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
)
issuer = []
for attribute in self.crl.issuer:
issuer.append([cryptography_oid_to_name(attribute.oid), attribute.value])
result['issuer_ordered'] = issuer
result['issuer'] = {}
result["issuer_ordered"] = issuer
result["issuer"] = {}
for k, v in issuer:
result['issuer'][k] = v
result["issuer"][k] = v
if self.list_revoked_certificates:
result['revoked_certificates'] = []
result["revoked_certificates"] = []
for cert in self.crl:
entry = cryptography_decode_revoked_certificate(cert)
result['revoked_certificates'].append(cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding))
result["revoked_certificates"].append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
)
return result
def get_crl_info(module, content, list_revoked_certificates=True):
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
info = CRLInfoRetrieval(module, content, list_revoked_certificates=list_revoked_certificates)
info = CRLInfoRetrieval(
module, content, list_revoked_certificates=list_revoked_certificates
)
return info.get_info()

View File

@@ -51,7 +51,7 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.3'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.3"
CRYPTOGRAPHY_IMP_ERR = None
try:
@@ -62,13 +62,16 @@ try:
import cryptography.hazmat.primitives.serialization
import cryptography.x509
import cryptography.x509.oid
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
CRYPTOGRAPHY_FOUND = False
else:
CRYPTOGRAPHY_FOUND = True
CRYPTOGRAPHY_MUST_STAPLE_NAME = cryptography.x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.1.24")
CRYPTOGRAPHY_MUST_STAPLE_NAME = cryptography.x509.oid.ObjectIdentifier(
"1.3.6.1.5.5.7.1.24"
)
CRYPTOGRAPHY_MUST_STAPLE_VALUE = b"\x30\x03\x02\x01\x05"
@@ -88,80 +91,107 @@ class CertificateSigningRequestBackend(object):
def __init__(self, module, backend):
self.module = module
self.backend = backend
self.digest = module.params['digest']
self.privatekey_path = module.params['privatekey_path']
self.privatekey_content = module.params['privatekey_content']
self.digest = module.params["digest"]
self.privatekey_path = module.params["privatekey_path"]
self.privatekey_content = module.params["privatekey_content"]
if self.privatekey_content is not None:
self.privatekey_content = self.privatekey_content.encode('utf-8')
self.privatekey_passphrase = module.params['privatekey_passphrase']
self.version = module.params['version']
self.subjectAltName = module.params['subject_alt_name']
self.subjectAltName_critical = module.params['subject_alt_name_critical']
self.keyUsage = module.params['key_usage']
self.keyUsage_critical = module.params['key_usage_critical']
self.extendedKeyUsage = module.params['extended_key_usage']
self.extendedKeyUsage_critical = module.params['extended_key_usage_critical']
self.basicConstraints = module.params['basic_constraints']
self.basicConstraints_critical = module.params['basic_constraints_critical']
self.ocspMustStaple = module.params['ocsp_must_staple']
self.ocspMustStaple_critical = module.params['ocsp_must_staple_critical']
self.name_constraints_permitted = module.params['name_constraints_permitted'] or []
self.name_constraints_excluded = module.params['name_constraints_excluded'] or []
self.name_constraints_critical = module.params['name_constraints_critical']
self.create_subject_key_identifier = module.params['create_subject_key_identifier']
self.subject_key_identifier = module.params['subject_key_identifier']
self.authority_key_identifier = module.params['authority_key_identifier']
self.authority_cert_issuer = module.params['authority_cert_issuer']
self.authority_cert_serial_number = module.params['authority_cert_serial_number']
self.crl_distribution_points = module.params['crl_distribution_points']
self.privatekey_content = self.privatekey_content.encode("utf-8")
self.privatekey_passphrase = module.params["privatekey_passphrase"]
self.version = module.params["version"]
self.subjectAltName = module.params["subject_alt_name"]
self.subjectAltName_critical = module.params["subject_alt_name_critical"]
self.keyUsage = module.params["key_usage"]
self.keyUsage_critical = module.params["key_usage_critical"]
self.extendedKeyUsage = module.params["extended_key_usage"]
self.extendedKeyUsage_critical = module.params["extended_key_usage_critical"]
self.basicConstraints = module.params["basic_constraints"]
self.basicConstraints_critical = module.params["basic_constraints_critical"]
self.ocspMustStaple = module.params["ocsp_must_staple"]
self.ocspMustStaple_critical = module.params["ocsp_must_staple_critical"]
self.name_constraints_permitted = (
module.params["name_constraints_permitted"] or []
)
self.name_constraints_excluded = (
module.params["name_constraints_excluded"] or []
)
self.name_constraints_critical = module.params["name_constraints_critical"]
self.create_subject_key_identifier = module.params[
"create_subject_key_identifier"
]
self.subject_key_identifier = module.params["subject_key_identifier"]
self.authority_key_identifier = module.params["authority_key_identifier"]
self.authority_cert_issuer = module.params["authority_cert_issuer"]
self.authority_cert_serial_number = module.params[
"authority_cert_serial_number"
]
self.crl_distribution_points = module.params["crl_distribution_points"]
self.csr = None
self.privatekey = None
if self.create_subject_key_identifier and self.subject_key_identifier is not None:
module.fail_json(msg='subject_key_identifier cannot be specified if create_subject_key_identifier is true')
if (
self.create_subject_key_identifier
and self.subject_key_identifier is not None
):
module.fail_json(
msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true"
)
self.ordered_subject = False
self.subject = [
('C', module.params['country_name']),
('ST', module.params['state_or_province_name']),
('L', module.params['locality_name']),
('O', module.params['organization_name']),
('OU', module.params['organizational_unit_name']),
('CN', module.params['common_name']),
('emailAddress', module.params['email_address']),
("C", module.params["country_name"]),
("ST", module.params["state_or_province_name"]),
("L", module.params["locality_name"]),
("O", module.params["organization_name"]),
("OU", module.params["organizational_unit_name"]),
("CN", module.params["common_name"]),
("emailAddress", module.params["email_address"]),
]
self.subject = [(entry[0], entry[1]) for entry in self.subject if entry[1]]
try:
if module.params['subject']:
self.subject = self.subject + parse_name_field(module.params['subject'], 'subject')
if module.params['subject_ordered']:
if module.params["subject"]:
self.subject = self.subject + parse_name_field(
module.params["subject"], "subject"
)
if module.params["subject_ordered"]:
if self.subject:
raise CertificateSigningRequestError('subject_ordered cannot be combined with any other subject field')
self.subject = parse_ordered_name_field(module.params['subject_ordered'], 'subject_ordered')
raise CertificateSigningRequestError(
"subject_ordered cannot be combined with any other subject field"
)
self.subject = parse_ordered_name_field(
module.params["subject_ordered"], "subject_ordered"
)
self.ordered_subject = True
except ValueError as exc:
raise CertificateSigningRequestError(to_native(exc))
self.using_common_name_for_san = False
if not self.subjectAltName and module.params['use_common_name_for_san']:
if not self.subjectAltName and module.params["use_common_name_for_san"]:
for sub in self.subject:
if sub[0] in ('commonName', 'CN'):
self.subjectAltName = ['DNS:%s' % sub[1]]
if sub[0] in ("commonName", "CN"):
self.subjectAltName = ["DNS:%s" % sub[1]]
self.using_common_name_for_san = True
break
if self.subject_key_identifier is not None:
try:
self.subject_key_identifier = binascii.unhexlify(self.subject_key_identifier.replace(':', ''))
self.subject_key_identifier = binascii.unhexlify(
self.subject_key_identifier.replace(":", "")
)
except Exception as e:
raise CertificateSigningRequestError('Cannot parse subject_key_identifier: {0}'.format(e))
raise CertificateSigningRequestError(
"Cannot parse subject_key_identifier: {0}".format(e)
)
if self.authority_key_identifier is not None:
try:
self.authority_key_identifier = binascii.unhexlify(self.authority_key_identifier.replace(':', ''))
self.authority_key_identifier = binascii.unhexlify(
self.authority_key_identifier.replace(":", "")
)
except Exception as e:
raise CertificateSigningRequestError('Cannot parse authority_key_identifier: {0}'.format(e))
raise CertificateSigningRequestError(
"Cannot parse authority_key_identifier: {0}".format(e)
)
self.existing_csr = None
self.existing_csr_bytes = None
@@ -174,8 +204,13 @@ class CertificateSigningRequestBackend(object):
return dict()
try:
result = get_csr_info(
self.module, self.backend, data, validate_signature=False, prefer_one_fingerprint=True)
result['can_parse_csr'] = True
self.module,
self.backend,
data,
validate_signature=False,
prefer_one_fingerprint=True,
)
result["can_parse_csr"] = True
return result
except Exception:
return dict(can_parse_csr=False)
@@ -223,7 +258,9 @@ class CertificateSigningRequestBackend(object):
if self.existing_csr_bytes is None:
return True
try:
self.existing_csr = load_certificate_request(None, content=self.existing_csr_bytes, backend=self.backend)
self.existing_csr = load_certificate_request(
None, content=self.existing_csr_bytes, backend=self.backend
)
except Exception:
return True
self._ensure_private_key_loaded()
@@ -232,15 +269,15 @@ class CertificateSigningRequestBackend(object):
def dump(self, include_csr):
"""Serialize the object into a dictionary."""
result = {
'privatekey': self.privatekey_path,
'subject': self.subject,
'subjectAltName': self.subjectAltName,
'keyUsage': self.keyUsage,
'extendedKeyUsage': self.extendedKeyUsage,
'basicConstraints': self.basicConstraints,
'ocspMustStaple': self.ocspMustStaple,
'name_constraints_permitted': self.name_constraints_permitted,
'name_constraints_excluded': self.name_constraints_excluded,
"privatekey": self.privatekey_path,
"subject": self.subject,
"subjectAltName": self.subjectAltName,
"keyUsage": self.keyUsage,
"extendedKeyUsage": self.extendedKeyUsage,
"basicConstraints": self.basicConstraints,
"ocspMustStaple": self.ocspMustStaple,
"name_constraints_permitted": self.name_constraints_permitted,
"name_constraints_excluded": self.name_constraints_excluded,
}
# Get hold of CSR bytes
csr_bytes = self.existing_csr_bytes
@@ -249,9 +286,9 @@ class CertificateSigningRequestBackend(object):
self.diff_after = self._get_info(csr_bytes)
if include_csr:
# Store result
result['csr'] = csr_bytes.decode('utf-8') if csr_bytes else None
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
result['diff'] = dict(
result["diff"] = dict(
before=self.diff_before,
after=self.diff_after,
)
@@ -268,45 +305,67 @@ def parse_crl_distribution_points(module, crl_distribution_points):
crl_issuer=None,
reasons=None,
)
if parse_crl_distribution_point['full_name'] is not None:
if not parse_crl_distribution_point['full_name']:
raise OpenSSLObjectError('full_name must not be empty')
params['full_name'] = [cryptography_get_name(name, 'full name') for name in parse_crl_distribution_point['full_name']]
if parse_crl_distribution_point['relative_name'] is not None:
if not parse_crl_distribution_point['relative_name']:
raise OpenSSLObjectError('relative_name must not be empty')
if parse_crl_distribution_point["full_name"] is not None:
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
params["full_name"] = [
cryptography_get_name(name, "full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
if not parse_crl_distribution_point["relative_name"]:
raise OpenSSLObjectError("relative_name must not be empty")
try:
params['relative_name'] = cryptography_parse_relative_distinguished_name(parse_crl_distribution_point['relative_name'])
params["relative_name"] = (
cryptography_parse_relative_distinguished_name(
parse_crl_distribution_point["relative_name"]
)
)
except Exception:
# If cryptography's version is < 1.6, the error is probably caused by that
if CRYPTOGRAPHY_VERSION < LooseVersion('1.6'):
raise OpenSSLObjectError('Cannot specify relative_name for cryptography < 1.6')
if CRYPTOGRAPHY_VERSION < LooseVersion("1.6"):
raise OpenSSLObjectError(
"Cannot specify relative_name for cryptography < 1.6"
)
raise
if parse_crl_distribution_point['crl_issuer'] is not None:
if not parse_crl_distribution_point['crl_issuer']:
raise OpenSSLObjectError('crl_issuer must not be empty')
params['crl_issuer'] = [cryptography_get_name(name, 'CRL issuer') for name in parse_crl_distribution_point['crl_issuer']]
if parse_crl_distribution_point['reasons'] is not None:
if parse_crl_distribution_point["crl_issuer"] is not None:
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
params["crl_issuer"] = [
cryptography_get_name(name, "CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
reasons = []
for reason in parse_crl_distribution_point['reasons']:
for reason in parse_crl_distribution_point["reasons"]:
reasons.append(REVOCATION_REASON_MAP[reason])
params['reasons'] = frozenset(reasons)
params["reasons"] = frozenset(reasons)
result.append(cryptography.x509.DistributionPoint(**params))
except (OpenSSLObjectError, ValueError) as e:
raise OpenSSLObjectError('Error while parsing CRL distribution point #{index}: {error}'.format(index=index, error=e))
raise OpenSSLObjectError(
"Error while parsing CRL distribution point #{index}: {error}".format(
index=index, error=e
)
)
return result
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module):
super(CertificateSigningRequestCryptographyBackend, self).__init__(module, 'cryptography')
super(CertificateSigningRequestCryptographyBackend, self).__init__(
module, "cryptography"
)
self.cryptography_backend = cryptography.hazmat.backends.default_backend()
if self.version != 1:
module.warn('The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)')
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
)
if self.crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(module, self.crl_distribution_points)
self.crl_distribution_points = parse_crl_distribution_points(
module, self.crl_distribution_points
)
def generate_csr(self):
"""(Re-)Generate CSR."""
@@ -314,82 +373,145 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
csr = cryptography.x509.CertificateSigningRequestBuilder()
try:
csr = csr.subject_name(cryptography.x509.Name([
cryptography.x509.NameAttribute(cryptography_name_to_oid(entry[0]), to_text(entry[1])) for entry in self.subject
]))
csr = csr.subject_name(
cryptography.x509.Name(
[
cryptography.x509.NameAttribute(
cryptography_name_to_oid(entry[0]), to_text(entry[1])
)
for entry in self.subject
]
)
)
except ValueError as e:
raise CertificateSigningRequestError(e)
if self.subjectAltName:
csr = csr.add_extension(cryptography.x509.SubjectAlternativeName([
cryptography_get_name(name) for name in self.subjectAltName
]), critical=self.subjectAltName_critical)
csr = csr.add_extension(
cryptography.x509.SubjectAlternativeName(
[cryptography_get_name(name) for name in self.subjectAltName]
),
critical=self.subjectAltName_critical,
)
if self.keyUsage:
params = cryptography_parse_key_usage_params(self.keyUsage)
csr = csr.add_extension(cryptography.x509.KeyUsage(**params), critical=self.keyUsage_critical)
csr = csr.add_extension(
cryptography.x509.KeyUsage(**params), critical=self.keyUsage_critical
)
if self.extendedKeyUsage:
usages = [cryptography_name_to_oid(usage) for usage in self.extendedKeyUsage]
csr = csr.add_extension(cryptography.x509.ExtendedKeyUsage(usages), critical=self.extendedKeyUsage_critical)
usages = [
cryptography_name_to_oid(usage) for usage in self.extendedKeyUsage
]
csr = csr.add_extension(
cryptography.x509.ExtendedKeyUsage(usages),
critical=self.extendedKeyUsage_critical,
)
if self.basicConstraints:
params = {}
ca, path_length = cryptography_get_basic_constraints(self.basicConstraints)
csr = csr.add_extension(cryptography.x509.BasicConstraints(ca, path_length), critical=self.basicConstraints_critical)
csr = csr.add_extension(
cryptography.x509.BasicConstraints(ca, path_length),
critical=self.basicConstraints_critical,
)
if self.ocspMustStaple:
try:
# This only works with cryptography >= 2.1
csr = csr.add_extension(cryptography.x509.TLSFeature([cryptography.x509.TLSFeatureType.status_request]), critical=self.ocspMustStaple_critical)
csr = csr.add_extension(
cryptography.x509.TLSFeature(
[cryptography.x509.TLSFeatureType.status_request]
),
critical=self.ocspMustStaple_critical,
)
except AttributeError:
csr = csr.add_extension(
cryptography.x509.UnrecognizedExtension(CRYPTOGRAPHY_MUST_STAPLE_NAME, CRYPTOGRAPHY_MUST_STAPLE_VALUE),
critical=self.ocspMustStaple_critical
cryptography.x509.UnrecognizedExtension(
CRYPTOGRAPHY_MUST_STAPLE_NAME, CRYPTOGRAPHY_MUST_STAPLE_VALUE
),
critical=self.ocspMustStaple_critical,
)
if self.name_constraints_permitted or self.name_constraints_excluded:
try:
csr = csr.add_extension(cryptography.x509.NameConstraints(
[cryptography_get_name(name, 'name constraints permitted') for name in self.name_constraints_permitted] or None,
[cryptography_get_name(name, 'name constraints excluded') for name in self.name_constraints_excluded] or None,
), critical=self.name_constraints_critical)
csr = csr.add_extension(
cryptography.x509.NameConstraints(
[
cryptography_get_name(name, "name constraints permitted")
for name in self.name_constraints_permitted
]
or None,
[
cryptography_get_name(name, "name constraints excluded")
for name in self.name_constraints_excluded
]
or None,
),
critical=self.name_constraints_critical,
)
except TypeError as e:
raise OpenSSLObjectError('Error while parsing name constraint: {0}'.format(e))
raise OpenSSLObjectError(
"Error while parsing name constraint: {0}".format(e)
)
if self.create_subject_key_identifier:
csr = csr.add_extension(
cryptography.x509.SubjectKeyIdentifier.from_public_key(self.privatekey.public_key()),
critical=False
cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
),
critical=False,
)
elif self.subject_key_identifier is not None:
csr = csr.add_extension(cryptography.x509.SubjectKeyIdentifier(self.subject_key_identifier), critical=False)
csr = csr.add_extension(
cryptography.x509.SubjectKeyIdentifier(self.subject_key_identifier),
critical=False,
)
if self.authority_key_identifier is not None or self.authority_cert_issuer is not None or self.authority_cert_serial_number is not None:
if (
self.authority_key_identifier is not None
or self.authority_cert_issuer is not None
or self.authority_cert_serial_number is not None
):
issuers = None
if self.authority_cert_issuer is not None:
issuers = [cryptography_get_name(n, 'authority cert issuer') for n in self.authority_cert_issuer]
issuers = [
cryptography_get_name(n, "authority cert issuer")
for n in self.authority_cert_issuer
]
csr = csr.add_extension(
cryptography.x509.AuthorityKeyIdentifier(self.authority_key_identifier, issuers, self.authority_cert_serial_number),
critical=False
cryptography.x509.AuthorityKeyIdentifier(
self.authority_key_identifier,
issuers,
self.authority_cert_serial_number,
),
critical=False,
)
if self.crl_distribution_points:
csr = csr.add_extension(
cryptography.x509.CRLDistributionPoints(self.crl_distribution_points),
critical=False
critical=False,
)
digest = None
if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = select_message_digest(self.digest)
if digest is None:
raise CertificateSigningRequestError('Unsupported digest "{0}"'.format(self.digest))
raise CertificateSigningRequestError(
'Unsupported digest "{0}"'.format(self.digest)
)
try:
self.csr = csr.sign(self.privatekey, digest, self.cryptography_backend)
except TypeError as e:
if str(e) == 'Algorithm must be a registered hash algorithm.' and digest is None:
self.module.fail_json(msg='Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer.')
if (
str(e) == "Algorithm must be a registered hash algorithm."
and digest is None
):
self.module.fail_json(
msg="Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer."
)
raise
except UnicodeError as e:
# This catches IDNAErrors, which happens when a bad name is passed as a SAN
@@ -402,20 +524,32 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
# https://github.com/kjd/idna/commit/ebefacd3134d0f5da4745878620a6a1cba86d130
# and then
# https://github.com/kjd/idna/commit/ea03c7b5db7d2a99af082e0239da2b68aeea702a).
msg = 'Error while creating CSR: {0}\n'.format(e)
msg = "Error while creating CSR: {0}\n".format(e)
if self.using_common_name_for_san:
self.module.fail_json(msg=msg + 'This is probably caused because the Common Name is used as a SAN.'
' Specifying use_common_name_for_san=false might fix this.')
self.module.fail_json(msg=msg + 'This is probably caused by an invalid Subject Alternative DNS Name.')
self.module.fail_json(
msg=msg
+ "This is probably caused because the Common Name is used as a SAN."
" Specifying use_common_name_for_san=false might fix this."
)
self.module.fail_json(
msg=msg
+ "This is probably caused by an invalid Subject Alternative DNS Name."
)
def get_csr_data(self):
"""Return bytes for self.csr."""
return self.csr.public_bytes(cryptography.hazmat.primitives.serialization.Encoding.PEM)
return self.csr.public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM
)
def _check_csr(self):
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
def _check_subject(csr):
subject = [(cryptography_name_to_oid(entry[0]), to_text(entry[1])) for entry in self.subject]
subject = [
(cryptography_name_to_oid(entry[0]), to_text(entry[1]))
for entry in self.subject
]
current_subject = [(sub.oid, sub.value) for sub in csr.subject]
if self.ordered_subject:
return subject == current_subject
@@ -424,14 +558,26 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
def _find_extension(extensions, exttype):
return next(
(ext for ext in extensions if isinstance(ext.value, exttype)),
None
(ext for ext in extensions if isinstance(ext.value, exttype)), None
)
def _check_subjectAltName(extensions):
current_altnames_ext = _find_extension(extensions, cryptography.x509.SubjectAlternativeName)
current_altnames = [to_text(altname) for altname in current_altnames_ext.value] if current_altnames_ext else []
altnames = [to_text(cryptography_get_name(altname)) for altname in self.subjectAltName] if self.subjectAltName else []
current_altnames_ext = _find_extension(
extensions, cryptography.x509.SubjectAlternativeName
)
current_altnames = (
[to_text(altname) for altname in current_altnames_ext.value]
if current_altnames_ext
else []
)
altnames = (
[
to_text(cryptography_get_name(altname))
for altname in self.subjectAltName
]
if self.subjectAltName
else []
)
if set(altnames) != set(current_altnames):
return False
if altnames:
@@ -440,23 +586,38 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
return True
def _check_keyUsage(extensions):
current_keyusage_ext = _find_extension(extensions, cryptography.x509.KeyUsage)
current_keyusage_ext = _find_extension(
extensions, cryptography.x509.KeyUsage
)
if not self.keyUsage:
return current_keyusage_ext is None
elif current_keyusage_ext is None:
return False
params = cryptography_parse_key_usage_params(self.keyUsage)
for param in params:
if getattr(current_keyusage_ext.value, '_' + param) != params[param]:
if getattr(current_keyusage_ext.value, "_" + param) != params[param]:
return False
if current_keyusage_ext.critical != self.keyUsage_critical:
return False
return True
def _check_extenededKeyUsage(extensions):
current_usages_ext = _find_extension(extensions, cryptography.x509.ExtendedKeyUsage)
current_usages = [str(usage) for usage in current_usages_ext.value] if current_usages_ext else []
usages = [str(cryptography_name_to_oid(usage)) for usage in self.extendedKeyUsage] if self.extendedKeyUsage else []
current_usages_ext = _find_extension(
extensions, cryptography.x509.ExtendedKeyUsage
)
current_usages = (
[str(usage) for usage in current_usages_ext.value]
if current_usages_ext
else []
)
usages = (
[
str(cryptography_name_to_oid(usage))
for usage in self.extendedKeyUsage
]
if self.extendedKeyUsage
else []
)
if set(current_usages) != set(usages):
return False
if usages:
@@ -477,38 +638,77 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
return False
# Check criticality
if self.basicConstraints:
return bc_ext is not None and bc_ext.critical == self.basicConstraints_critical
return (
bc_ext is not None
and bc_ext.critical == self.basicConstraints_critical
)
else:
return bc_ext is None
def _check_ocspMustStaple(extensions):
try:
# This only works with cryptography >= 2.1
tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
tlsfeature_ext = _find_extension(
extensions, cryptography.x509.TLSFeature
)
has_tlsfeature = True
except AttributeError:
tlsfeature_ext = next(
(ext for ext in extensions if ext.value.oid == CRYPTOGRAPHY_MUST_STAPLE_NAME),
None
(
ext
for ext in extensions
if ext.value.oid == CRYPTOGRAPHY_MUST_STAPLE_NAME
),
None,
)
has_tlsfeature = False
if self.ocspMustStaple:
if not tlsfeature_ext or tlsfeature_ext.critical != self.ocspMustStaple_critical:
if (
not tlsfeature_ext
or tlsfeature_ext.critical != self.ocspMustStaple_critical
):
return False
if has_tlsfeature:
return cryptography.x509.TLSFeatureType.status_request in tlsfeature_ext.value
return (
cryptography.x509.TLSFeatureType.status_request
in tlsfeature_ext.value
)
else:
return tlsfeature_ext.value.value == CRYPTOGRAPHY_MUST_STAPLE_VALUE
else:
return tlsfeature_ext is None
def _check_nameConstraints(extensions):
current_nc_ext = _find_extension(extensions, cryptography.x509.NameConstraints)
current_nc_perm = [to_text(altname) for altname in current_nc_ext.value.permitted_subtrees or []] if current_nc_ext else []
current_nc_excl = [to_text(altname) for altname in current_nc_ext.value.excluded_subtrees or []] if current_nc_ext else []
nc_perm = [to_text(cryptography_get_name(altname, 'name constraints permitted')) for altname in self.name_constraints_permitted]
nc_excl = [to_text(cryptography_get_name(altname, 'name constraints excluded')) for altname in self.name_constraints_excluded]
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(current_nc_excl):
current_nc_ext = _find_extension(
extensions, cryptography.x509.NameConstraints
)
current_nc_perm = (
[
to_text(altname)
for altname in current_nc_ext.value.permitted_subtrees or []
]
if current_nc_ext
else []
)
current_nc_excl = (
[
to_text(altname)
for altname in current_nc_ext.value.excluded_subtrees or []
]
if current_nc_ext
else []
)
nc_perm = [
to_text(cryptography_get_name(altname, "name constraints permitted"))
for altname in self.name_constraints_permitted
]
nc_excl = [
to_text(cryptography_get_name(altname, "name constraints excluded"))
for altname in self.name_constraints_excluded
]
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(
current_nc_excl
):
return False
if nc_perm or nc_excl:
if current_nc_ext.critical != self.name_constraints_critical:
@@ -517,11 +717,16 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
def _check_subject_key_identifier(extensions):
ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
if self.create_subject_key_identifier or self.subject_key_identifier is not None:
if (
self.create_subject_key_identifier
or self.subject_key_identifier is not None
):
if not ext or ext.critical:
return False
if self.create_subject_key_identifier:
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(self.privatekey.public_key()).digest
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
).digest
return ext.value.digest == digest
else:
return ext.value.digest == self.subject_key_identifier
@@ -530,18 +735,28 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
def _check_authority_key_identifier(extensions):
ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
if self.authority_key_identifier is not None or self.authority_cert_issuer is not None or self.authority_cert_serial_number is not None:
if (
self.authority_key_identifier is not None
or self.authority_cert_issuer is not None
or self.authority_cert_serial_number is not None
):
if not ext or ext.critical:
return False
aci = None
csr_aci = None
if self.authority_cert_issuer is not None:
aci = [to_text(cryptography_get_name(n, 'authority cert issuer')) for n in self.authority_cert_issuer]
aci = [
to_text(cryptography_get_name(n, "authority cert issuer"))
for n in self.authority_cert_issuer
]
if ext.value.authority_cert_issuer is not None:
csr_aci = [to_text(n) for n in ext.value.authority_cert_issuer]
return (ext.value.key_identifier == self.authority_key_identifier
and csr_aci == aci
and ext.value.authority_cert_serial_number == self.authority_cert_serial_number)
return (
ext.value.key_identifier == self.authority_key_identifier
and csr_aci == aci
and ext.value.authority_cert_serial_number
== self.authority_cert_serial_number
)
else:
return ext is None
@@ -555,11 +770,17 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
def _check_extensions(csr):
extensions = csr.extensions
return (_check_subjectAltName(extensions) and _check_keyUsage(extensions) and
_check_extenededKeyUsage(extensions) and _check_basicConstraints(extensions) and
_check_ocspMustStaple(extensions) and _check_subject_key_identifier(extensions) and
_check_authority_key_identifier(extensions) and _check_nameConstraints(extensions) and
_check_crl_distribution_points(extensions))
return (
_check_subjectAltName(extensions)
and _check_keyUsage(extensions)
and _check_extenededKeyUsage(extensions)
and _check_basicConstraints(extensions)
and _check_ocspMustStaple(extensions)
and _check_subject_key_identifier(extensions)
and _check_authority_key_identifier(extensions)
and _check_nameConstraints(extensions)
and _check_crl_distribution_points(extensions)
)
def _check_signature(csr):
if not csr.is_signature_valid:
@@ -568,107 +789,154 @@ class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBack
# encode both public keys and compare PEMs.
key_a = csr.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
)
key_b = self.privatekey.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
)
return key_a == key_b
return _check_subject(self.existing_csr) and _check_extensions(self.existing_csr) and _check_signature(self.existing_csr)
return (
_check_subject(self.existing_csr)
and _check_extensions(self.existing_csr)
and _check_signature(self.existing_csr)
)
def select_backend(module, backend):
if backend == 'auto':
if backend == "auto":
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# Try cryptography
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Success?
if backend == 'auto':
module.fail_json(msg=("Cannot detect any of the required Python libraries "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect any of the required Python libraries "
"cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == 'cryptography':
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return backend, CertificateSigningRequestCryptographyBackend(module)
else:
raise Exception('Unsupported value for backend: {0}'.format(backend))
raise Exception("Unsupported value for backend: {0}".format(backend))
def get_csr_argument_spec():
return ArgumentSpec(
argument_spec=dict(
digest=dict(type='str', default='sha256'),
privatekey_path=dict(type='path'),
privatekey_content=dict(type='str', no_log=True),
privatekey_passphrase=dict(type='str', no_log=True),
version=dict(type='int', default=1, choices=[1]),
subject=dict(type='dict'),
subject_ordered=dict(type='list', elements='dict'),
country_name=dict(type='str', aliases=['C', 'countryName']),
state_or_province_name=dict(type='str', aliases=['ST', 'stateOrProvinceName']),
locality_name=dict(type='str', aliases=['L', 'localityName']),
organization_name=dict(type='str', aliases=['O', 'organizationName']),
organizational_unit_name=dict(type='str', aliases=['OU', 'organizationalUnitName']),
common_name=dict(type='str', aliases=['CN', 'commonName']),
email_address=dict(type='str', aliases=['E', 'emailAddress']),
subject_alt_name=dict(type='list', elements='str', aliases=['subjectAltName']),
subject_alt_name_critical=dict(type='bool', default=False, aliases=['subjectAltName_critical']),
use_common_name_for_san=dict(type='bool', default=True, aliases=['useCommonNameForSAN']),
key_usage=dict(type='list', elements='str', aliases=['keyUsage']),
key_usage_critical=dict(type='bool', default=False, aliases=['keyUsage_critical']),
extended_key_usage=dict(type='list', elements='str', aliases=['extKeyUsage', 'extendedKeyUsage']),
extended_key_usage_critical=dict(type='bool', default=False, aliases=['extKeyUsage_critical', 'extendedKeyUsage_critical']),
basic_constraints=dict(type='list', elements='str', aliases=['basicConstraints']),
basic_constraints_critical=dict(type='bool', default=False, aliases=['basicConstraints_critical']),
ocsp_must_staple=dict(type='bool', default=False, aliases=['ocspMustStaple']),
ocsp_must_staple_critical=dict(type='bool', default=False, aliases=['ocspMustStaple_critical']),
name_constraints_permitted=dict(type='list', elements='str'),
name_constraints_excluded=dict(type='list', elements='str'),
name_constraints_critical=dict(type='bool', default=False),
create_subject_key_identifier=dict(type='bool', default=False),
subject_key_identifier=dict(type='str'),
authority_key_identifier=dict(type='str'),
authority_cert_issuer=dict(type='list', elements='str'),
authority_cert_serial_number=dict(type='int'),
crl_distribution_points=dict(
type='list',
elements='dict',
options=dict(
full_name=dict(type='list', elements='str'),
relative_name=dict(type='list', elements='str'),
crl_issuer=dict(type='list', elements='str'),
reasons=dict(type='list', elements='str', choices=[
'key_compromise',
'ca_compromise',
'affiliation_changed',
'superseded',
'cessation_of_operation',
'certificate_hold',
'privilege_withdrawn',
'aa_compromise',
]),
),
mutually_exclusive=[('full_name', 'relative_name')],
required_one_of=[('full_name', 'relative_name', 'crl_issuer')],
digest=dict(type="str", default="sha256"),
privatekey_path=dict(type="path"),
privatekey_content=dict(type="str", no_log=True),
privatekey_passphrase=dict(type="str", no_log=True),
version=dict(type="int", default=1, choices=[1]),
subject=dict(type="dict"),
subject_ordered=dict(type="list", elements="dict"),
country_name=dict(type="str", aliases=["C", "countryName"]),
state_or_province_name=dict(
type="str", aliases=["ST", "stateOrProvinceName"]
),
locality_name=dict(type="str", aliases=["L", "localityName"]),
organization_name=dict(type="str", aliases=["O", "organizationName"]),
organizational_unit_name=dict(
type="str", aliases=["OU", "organizationalUnitName"]
),
common_name=dict(type="str", aliases=["CN", "commonName"]),
email_address=dict(type="str", aliases=["E", "emailAddress"]),
subject_alt_name=dict(
type="list", elements="str", aliases=["subjectAltName"]
),
subject_alt_name_critical=dict(
type="bool", default=False, aliases=["subjectAltName_critical"]
),
use_common_name_for_san=dict(
type="bool", default=True, aliases=["useCommonNameForSAN"]
),
key_usage=dict(type="list", elements="str", aliases=["keyUsage"]),
key_usage_critical=dict(
type="bool", default=False, aliases=["keyUsage_critical"]
),
extended_key_usage=dict(
type="list", elements="str", aliases=["extKeyUsage", "extendedKeyUsage"]
),
extended_key_usage_critical=dict(
type="bool",
default=False,
aliases=["extKeyUsage_critical", "extendedKeyUsage_critical"],
),
basic_constraints=dict(
type="list", elements="str", aliases=["basicConstraints"]
),
basic_constraints_critical=dict(
type="bool", default=False, aliases=["basicConstraints_critical"]
),
ocsp_must_staple=dict(
type="bool", default=False, aliases=["ocspMustStaple"]
),
ocsp_must_staple_critical=dict(
type="bool", default=False, aliases=["ocspMustStaple_critical"]
),
name_constraints_permitted=dict(type="list", elements="str"),
name_constraints_excluded=dict(type="list", elements="str"),
name_constraints_critical=dict(type="bool", default=False),
create_subject_key_identifier=dict(type="bool", default=False),
subject_key_identifier=dict(type="str"),
authority_key_identifier=dict(type="str"),
authority_cert_issuer=dict(type="list", elements="str"),
authority_cert_serial_number=dict(type="int"),
crl_distribution_points=dict(
type="list",
elements="dict",
options=dict(
full_name=dict(type="list", elements="str"),
relative_name=dict(type="list", elements="str"),
crl_issuer=dict(type="list", elements="str"),
reasons=dict(
type="list",
elements="str",
choices=[
"key_compromise",
"ca_compromise",
"affiliation_changed",
"superseded",
"cessation_of_operation",
"certificate_hold",
"privilege_withdrawn",
"aa_compromise",
],
),
),
mutually_exclusive=[("full_name", "relative_name")],
required_one_of=[("full_name", "relative_name", "crl_issuer")],
),
select_crypto_backend=dict(
type="str", default="auto", choices=["auto", "cryptography"]
),
select_crypto_backend=dict(type='str', default='auto', choices=['auto', 'cryptography']),
),
required_together=[
['authority_cert_issuer', 'authority_cert_serial_number'],
["authority_cert_issuer", "authority_cert_serial_number"],
],
mutually_exclusive=[
['privatekey_path', 'privatekey_content'],
['subject', 'subject_ordered'],
["privatekey_path", "privatekey_content"],
["subject", "subject_ordered"],
],
required_one_of=[
['privatekey_path', 'privatekey_content'],
["privatekey_path", "privatekey_content"],
],
)

View File

@@ -35,13 +35,14 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.3'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.3"
CRYPTOGRAPHY_IMP_ERR = None
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives import serialization
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -116,67 +117,80 @@ class CSRInfoRetrieval(object):
def get_info(self, prefer_one_fingerprint=False):
result = dict()
self.csr = load_certificate_request(None, content=self.content, backend=self.backend)
self.csr = load_certificate_request(
None, content=self.content, backend=self.backend
)
subject = self._get_subject_ordered()
result['subject'] = dict()
result["subject"] = dict()
for k, v in subject:
result['subject'][k] = v
result['subject_ordered'] = subject
result['key_usage'], result['key_usage_critical'] = self._get_key_usage()
result['extended_key_usage'], result['extended_key_usage_critical'] = self._get_extended_key_usage()
result['basic_constraints'], result['basic_constraints_critical'] = self._get_basic_constraints()
result['ocsp_must_staple'], result['ocsp_must_staple_critical'] = self._get_ocsp_must_staple()
result['subject_alt_name'], result['subject_alt_name_critical'] = self._get_subject_alt_name()
result["subject"][k] = v
result["subject_ordered"] = subject
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
(
result['name_constraints_permitted'],
result['name_constraints_excluded'],
result['name_constraints_critical'],
result["name_constraints_permitted"],
result["name_constraints_excluded"],
result["name_constraints_critical"],
) = self._get_name_constraints()
result['public_key'] = to_native(self._get_public_key_pem())
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
self.backend,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint)
result.update({
'public_key_type': public_key_info['type'],
'public_key_data': public_key_info['public_data'],
'public_key_fingerprints': public_key_info['fingerprints'],
})
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
ski = self._get_subject_key_identifier()
if ski is not None:
ski = to_native(binascii.hexlify(ski))
ski = ':'.join([ski[i:i + 2] for i in range(0, len(ski), 2)])
result['subject_key_identifier'] = ski
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski
aki, aci, acsn = self._get_authority_key_identifier()
if aki is not None:
aki = to_native(binascii.hexlify(aki))
aki = ':'.join([aki[i:i + 2] for i in range(0, len(aki), 2)])
result['authority_key_identifier'] = aki
result['authority_cert_issuer'] = aci
result['authority_cert_serial_number'] = acsn
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result['extensions_by_oid'] = self._get_all_extensions()
result["extensions_by_oid"] = self._get_all_extensions()
result['signature_valid'] = self._is_signature_valid()
if self.validate_signature and not result['signature_valid']:
self.module.fail_json(
msg='CSR signature is invalid!',
**result
)
result["signature_valid"] = self._is_signature_valid()
if self.validate_signature and not result["signature_valid"]:
self.module.fail_json(msg="CSR signature is invalid!", **result)
return result
class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(self, module, content, validate_signature):
super(CSRInfoRetrievalCryptography, self).__init__(module, 'cryptography', content, validate_signature)
self.name_encoding = module.params.get('name_encoding', 'ignore')
super(CSRInfoRetrievalCryptography, self).__init__(
module, "cryptography", content, validate_signature
)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_subject_ordered(self):
result = []
@@ -199,44 +213,60 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
encipher_only=False,
decipher_only=False,
)
if key_usage['key_agreement']:
key_usage.update(dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only
))
if key_usage["key_agreement"]:
key_usage.update(
dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only,
)
)
key_usage_names = dict(
digital_signature='Digital Signature',
content_commitment='Non Repudiation',
key_encipherment='Key Encipherment',
data_encipherment='Data Encipherment',
key_agreement='Key Agreement',
key_cert_sign='Certificate Sign',
crl_sign='CRL Sign',
encipher_only='Encipher Only',
decipher_only='Decipher Only',
digital_signature="Digital Signature",
content_commitment="Non Repudiation",
key_encipherment="Key Encipherment",
data_encipherment="Data Encipherment",
key_agreement="Key Agreement",
key_cert_sign="Certificate Sign",
crl_sign="CRL Sign",
encipher_only="Encipher Only",
decipher_only="Decipher Only",
)
return (
sorted(
[
key_usage_names[name]
for name, value in key_usage.items()
if value
]
),
current_key_ext.critical,
)
return sorted([
key_usage_names[name] for name, value in key_usage.items() if value
]), current_key_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_extended_key_usage(self):
try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(x509.ExtendedKeyUsage)
return sorted([
cryptography_oid_to_name(eku) for eku in ext_keyusage_ext.value
]), ext_keyusage_ext.critical
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.ExtendedKeyUsage
)
return (
sorted(
[cryptography_oid_to_name(eku) for eku in ext_keyusage_ext.value]
),
ext_keyusage_ext.critical,
)
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_basic_constraints(self):
try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(x509.BasicConstraints)
result = ['CA:{0}'.format('TRUE' if ext_keyusage_ext.value.ca else 'FALSE')]
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.BasicConstraints
)
result = ["CA:{0}".format("TRUE" if ext_keyusage_ext.value.ca else "FALSE")]
if ext_keyusage_ext.value.path_length is not None:
result.append('pathlen:{0}'.format(ext_keyusage_ext.value.path_length))
result.append("pathlen:{0}".format(ext_keyusage_ext.value.path_length))
return sorted(result), ext_keyusage_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
@@ -245,8 +275,13 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
try:
try:
# This only works with cryptography >= 2.1
tlsfeature_ext = self.csr.extensions.get_extension_for_class(x509.TLSFeature)
value = cryptography.x509.TLSFeatureType.status_request in tlsfeature_ext.value
tlsfeature_ext = self.csr.extensions.get_extension_for_class(
x509.TLSFeature
)
value = (
cryptography.x509.TLSFeatureType.status_request
in tlsfeature_ext.value
)
except AttributeError:
# Fallback for cryptography < 2.1
oid = x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.1.24")
@@ -258,8 +293,13 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def _get_subject_alt_name(self):
try:
san_ext = self.csr.extensions.get_extension_for_class(x509.SubjectAlternativeName)
result = [cryptography_decode_name(san, idn_rewrite=self.name_encoding) for san in san_ext.value]
san_ext = self.csr.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
result = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in san_ext.value
]
return result, san_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
@@ -267,8 +307,14 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def _get_name_constraints(self):
try:
nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints)
permitted = [cryptography_decode_name(san, idn_rewrite=self.name_encoding) for san in nc_ext.value.permitted_subtrees or []]
excluded = [cryptography_decode_name(san, idn_rewrite=self.name_encoding) for san in nc_ext.value.excluded_subtrees or []]
permitted = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in nc_ext.value.permitted_subtrees or []
]
excluded = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in nc_ext.value.excluded_subtrees or []
]
return permitted, excluded, nc_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, None, False
@@ -291,11 +337,20 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
def _get_authority_key_identifier(self):
try:
ext = self.csr.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier)
ext = self.csr.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
issuer = None
if ext.value.authority_cert_issuer is not None:
issuer = [cryptography_decode_name(san, idn_rewrite=self.name_encoding) for san in ext.value.authority_cert_issuer]
return ext.value.key_identifier, issuer, ext.value.authority_cert_serial_number
issuer = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in ext.value.authority_cert_issuer
]
return (
ext.value.key_identifier,
issuer,
ext.value.authority_cert_serial_number,
)
except cryptography.x509.ExtensionNotFound:
return None, None, None
@@ -306,30 +361,46 @@ class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
return self.csr.is_signature_valid
def get_csr_info(module, backend, content, validate_signature=True, prefer_one_fingerprint=False):
if backend == 'cryptography':
info = CSRInfoRetrievalCryptography(module, content, validate_signature=validate_signature)
def get_csr_info(
module, backend, content, validate_signature=True, prefer_one_fingerprint=False
):
if backend == "cryptography":
info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, backend, content, validate_signature=True):
if backend == 'auto':
if backend == "auto":
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# Try cryptography
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Success?
if backend == 'auto':
module.fail_json(msg=("Cannot detect the required Python library "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect the required Python library " "cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == 'cryptography':
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
return backend, CSRInfoRetrievalCryptography(module, content, validate_signature=validate_signature)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return backend, CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
)
else:
raise ValueError('Unsupported value for backend: {0}'.format(backend))
raise ValueError("Unsupported value for backend: {0}".format(backend))

View File

@@ -45,7 +45,7 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.2.3'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.2.3"
CRYPTOGRAPHY_IMP_ERR = None
try:
@@ -57,6 +57,7 @@ try:
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.serialization
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -80,14 +81,14 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyBackend:
def __init__(self, module, backend):
self.module = module
self.type = module.params['type']
self.size = module.params['size']
self.curve = module.params['curve']
self.passphrase = module.params['passphrase']
self.cipher = module.params['cipher']
self.format = module.params['format']
self.format_mismatch = module.params.get('format_mismatch', 'regenerate')
self.regenerate = module.params.get('regenerate', 'full_idempotence')
self.type = module.params["type"]
self.size = module.params["size"]
self.curve = module.params["curve"]
self.passphrase = module.params["passphrase"]
self.cipher = module.params["cipher"]
self.format = module.params["format"]
self.format_mismatch = module.params.get("format_mismatch", "regenerate")
self.regenerate = module.params.get("regenerate", "full_idempotence")
self.backend = backend
self.private_key = None
@@ -103,9 +104,16 @@ class PrivateKeyBackend:
return dict()
result = dict(can_parse_key=False)
try:
result.update(get_privatekey_info(
self.module, self.backend, data, passphrase=self.passphrase,
return_private_key_data=False, prefer_one_fingerprint=True))
result.update(
get_privatekey_info(
self.module,
self.backend,
data,
passphrase=self.passphrase,
return_private_key_data=False,
prefer_one_fingerprint=True,
)
)
except PrivateKeyConsistencyError as exc:
result.update(exc.result)
except PrivateKeyParseError as exc:
@@ -137,7 +145,9 @@ class PrivateKeyBackend:
def set_existing(self, privatekey_bytes):
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_private_key_bytes)
self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes
)
def has_existing(self):
"""Query whether an existing private key is/has been there."""
@@ -165,41 +175,51 @@ class PrivateKeyBackend:
def needs_regeneration(self):
"""Check whether a regeneration is necessary."""
if self.regenerate == 'always':
if self.regenerate == "always":
return True
if not self.has_existing():
# key does not exist
return True
if not self._check_passphrase():
if self.regenerate == 'full_idempotence':
if self.regenerate == "full_idempotence":
return True
self.module.fail_json(msg='Unable to read the key. The key is protected with a another passphrase / no passphrase or broken.'
' Will not proceed. To force regeneration, call the module with `generate`'
' set to `full_idempotence` or `always`, or with `force=true`.')
self.module.fail_json(
msg="Unable to read the key. The key is protected with a another passphrase / no passphrase or broken."
" Will not proceed. To force regeneration, call the module with `generate`"
" set to `full_idempotence` or `always`, or with `force=true`."
)
self._ensure_existing_private_key_loaded()
if self.regenerate != 'never':
if self.regenerate != "never":
if not self._check_size_and_type():
if self.regenerate in ('partial_idempotence', 'full_idempotence'):
if self.regenerate in ("partial_idempotence", "full_idempotence"):
return True
self.module.fail_json(msg='Key has wrong type and/or size.'
' Will not proceed. To force regeneration, call the module with `generate`'
' set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=true`.')
self.module.fail_json(
msg="Key has wrong type and/or size."
" Will not proceed. To force regeneration, call the module with `generate`"
" set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
)
# During generation step, regenerate if format does not match and format_mismatch == 'regenerate'
if self.format_mismatch == 'regenerate' and self.regenerate != 'never':
if self.format_mismatch == "regenerate" and self.regenerate != "never":
if not self._check_format():
if self.regenerate in ('partial_idempotence', 'full_idempotence'):
if self.regenerate in ("partial_idempotence", "full_idempotence"):
return True
self.module.fail_json(msg='Key has wrong format.'
' Will not proceed. To force regeneration, call the module with `generate`'
' set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=true`.'
' To convert the key, set `format_mismatch` to `convert`.')
self.module.fail_json(
msg="Key has wrong format."
" Will not proceed. To force regeneration, call the module with `generate`"
" set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
" To convert the key, set `format_mismatch` to `convert`."
)
return False
def needs_conversion(self):
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
# During conversion step, convert if format does not match and format_mismatch == 'convert'
self._ensure_existing_private_key_loaded()
return self.has_existing() and self.format_mismatch == 'convert' and not self._check_format()
return (
self.has_existing()
and self.format_mismatch == "convert"
and not self._check_format()
)
def _get_fingerprint(self):
if self.private_key:
@@ -210,7 +230,9 @@ class PrivateKeyBackend:
# Ignore errors
pass
if self.existing_private_key:
return get_fingerprint_of_privatekey(self.existing_private_key, backend=self.backend)
return get_fingerprint_of_privatekey(
self.existing_private_key, backend=self.backend
)
def dump(self, include_key):
"""Serialize the object into a dictionary."""
@@ -222,12 +244,12 @@ class PrivateKeyBackend:
# Ignore errors
pass
result = {
'type': self.type,
'size': self.size,
'fingerprint': self._get_fingerprint(),
"type": self.type,
"size": self.size,
"fingerprint": self._get_fingerprint(),
}
if self.type == 'ECC':
result['curve'] = self.curve
if self.type == "ECC":
result["curve"] = self.curve
# Get hold of private key bytes
pk_bytes = self.existing_private_key_bytes
if self.private_key is not None:
@@ -236,14 +258,14 @@ class PrivateKeyBackend:
if include_key:
# Store result
if pk_bytes:
if identify_private_key_format(pk_bytes) == 'raw':
result['privatekey'] = base64.b64encode(pk_bytes)
if identify_private_key_format(pk_bytes) == "raw":
result["privatekey"] = base64.b64encode(pk_bytes)
else:
result['privatekey'] = pk_bytes.decode('utf-8')
result["privatekey"] = pk_bytes.decode("utf-8")
else:
result['privatekey'] = None
result["privatekey"] = None
result['diff'] = dict(
result["diff"] = dict(
before=self.diff_before,
after=self.diff_after,
)
@@ -256,7 +278,9 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _get_ec_class(self, ectype):
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(ectype)
if ecclass is None:
self.module.fail_json(msg='Your cryptography version does not support {0}'.format(ectype))
self.module.fail_json(
msg="Your cryptography version does not support {0}".format(ectype)
)
return ecclass
def _add_curve(self, name, ectype, deprecated=False):
@@ -266,90 +290,123 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def verify(privatekey):
ecclass = self._get_ec_class(ectype)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
return isinstance(
privatekey.private_numbers().public_numbers.curve, ecclass
)
self.curves[name] = {
'create': create,
'verify': verify,
'deprecated': deprecated,
"create": create,
"verify": verify,
"deprecated": deprecated,
}
def __init__(self, module):
super(PrivateKeyCryptographyBackend, self).__init__(module=module, backend='cryptography')
super(PrivateKeyCryptographyBackend, self).__init__(
module=module, backend="cryptography"
)
self.curves = dict()
self._add_curve('secp224r1', 'SECP224R1')
self._add_curve('secp256k1', 'SECP256K1')
self._add_curve('secp256r1', 'SECP256R1')
self._add_curve('secp384r1', 'SECP384R1')
self._add_curve('secp521r1', 'SECP521R1')
self._add_curve('secp192r1', 'SECP192R1', deprecated=True)
self._add_curve('sect163k1', 'SECT163K1', deprecated=True)
self._add_curve('sect163r2', 'SECT163R2', deprecated=True)
self._add_curve('sect233k1', 'SECT233K1', deprecated=True)
self._add_curve('sect233r1', 'SECT233R1', deprecated=True)
self._add_curve('sect283k1', 'SECT283K1', deprecated=True)
self._add_curve('sect283r1', 'SECT283R1', deprecated=True)
self._add_curve('sect409k1', 'SECT409K1', deprecated=True)
self._add_curve('sect409r1', 'SECT409R1', deprecated=True)
self._add_curve('sect571k1', 'SECT571K1', deprecated=True)
self._add_curve('sect571r1', 'SECT571R1', deprecated=True)
self._add_curve('brainpoolP256r1', 'BrainpoolP256R1', deprecated=True)
self._add_curve('brainpoolP384r1', 'BrainpoolP384R1', deprecated=True)
self._add_curve('brainpoolP512r1', 'BrainpoolP512R1', deprecated=True)
self._add_curve("secp224r1", "SECP224R1")
self._add_curve("secp256k1", "SECP256K1")
self._add_curve("secp256r1", "SECP256R1")
self._add_curve("secp384r1", "SECP384R1")
self._add_curve("secp521r1", "SECP521R1")
self._add_curve("secp192r1", "SECP192R1", deprecated=True)
self._add_curve("sect163k1", "SECT163K1", deprecated=True)
self._add_curve("sect163r2", "SECT163R2", deprecated=True)
self._add_curve("sect233k1", "SECT233K1", deprecated=True)
self._add_curve("sect233r1", "SECT233R1", deprecated=True)
self._add_curve("sect283k1", "SECT283K1", deprecated=True)
self._add_curve("sect283r1", "SECT283R1", deprecated=True)
self._add_curve("sect409k1", "SECT409K1", deprecated=True)
self._add_curve("sect409r1", "SECT409R1", deprecated=True)
self._add_curve("sect571k1", "SECT571K1", deprecated=True)
self._add_curve("sect571r1", "SECT571R1", deprecated=True)
self._add_curve("brainpoolP256r1", "BrainpoolP256R1", deprecated=True)
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
self.cryptography_backend = cryptography.hazmat.backends.default_backend()
if not CRYPTOGRAPHY_HAS_X25519 and self.type == 'X25519':
self.module.fail_json(msg='Your cryptography version does not support X25519')
if not CRYPTOGRAPHY_HAS_X25519_FULL and self.type == 'X25519':
self.module.fail_json(msg='Your cryptography version does not support X25519 serialization')
if not CRYPTOGRAPHY_HAS_X448 and self.type == 'X448':
self.module.fail_json(msg='Your cryptography version does not support X448')
if not CRYPTOGRAPHY_HAS_ED25519 and self.type == 'Ed25519':
self.module.fail_json(msg='Your cryptography version does not support Ed25519')
if not CRYPTOGRAPHY_HAS_ED448 and self.type == 'Ed448':
self.module.fail_json(msg='Your cryptography version does not support Ed448')
if not CRYPTOGRAPHY_HAS_X25519 and self.type == "X25519":
self.module.fail_json(
msg="Your cryptography version does not support X25519"
)
if not CRYPTOGRAPHY_HAS_X25519_FULL and self.type == "X25519":
self.module.fail_json(
msg="Your cryptography version does not support X25519 serialization"
)
if not CRYPTOGRAPHY_HAS_X448 and self.type == "X448":
self.module.fail_json(msg="Your cryptography version does not support X448")
if not CRYPTOGRAPHY_HAS_ED25519 and self.type == "Ed25519":
self.module.fail_json(
msg="Your cryptography version does not support Ed25519"
)
if not CRYPTOGRAPHY_HAS_ED448 and self.type == "Ed448":
self.module.fail_json(
msg="Your cryptography version does not support Ed448"
)
def _get_wanted_format(self):
if self.format not in ('auto', 'auto_ignore'):
if self.format not in ("auto", "auto_ignore"):
return self.format
if self.type in ('X25519', 'X448', 'Ed25519', 'Ed448'):
return 'pkcs8'
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
return "pkcs8"
else:
return 'pkcs1'
return "pkcs1"
def generate_private_key(self):
"""(Re-)Generate private key."""
try:
if self.type == 'RSA':
self.private_key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, # OpenSSL always uses this
key_size=self.size,
backend=self.cryptography_backend
if self.type == "RSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, # OpenSSL always uses this
key_size=self.size,
backend=self.cryptography_backend,
)
)
if self.type == 'DSA':
self.private_key = cryptography.hazmat.primitives.asymmetric.dsa.generate_private_key(
key_size=self.size,
backend=self.cryptography_backend
if self.type == "DSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.dsa.generate_private_key(
key_size=self.size, backend=self.cryptography_backend
)
)
if CRYPTOGRAPHY_HAS_X25519_FULL and self.type == 'X25519':
self.private_key = cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.generate()
if CRYPTOGRAPHY_HAS_X448 and self.type == 'X448':
self.private_key = cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.generate()
if CRYPTOGRAPHY_HAS_ED25519 and self.type == 'Ed25519':
self.private_key = cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.generate()
if CRYPTOGRAPHY_HAS_ED448 and self.type == 'Ed448':
self.private_key = cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
if self.type == 'ECC' and self.curve in self.curves:
if self.curves[self.curve]['deprecated']:
self.module.warn('Elliptic curves of type {0} should not be used for new keys!'.format(self.curve))
self.private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve]['create'](self.size),
backend=self.cryptography_backend
if CRYPTOGRAPHY_HAS_X25519_FULL and self.type == "X25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.generate()
)
if CRYPTOGRAPHY_HAS_X448 and self.type == "X448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.generate()
)
if CRYPTOGRAPHY_HAS_ED25519 and self.type == "Ed25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.generate()
)
if CRYPTOGRAPHY_HAS_ED448 and self.type == "Ed448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
)
if self.type == "ECC" and self.curve in self.curves:
if self.curves[self.curve]["deprecated"]:
self.module.warn(
"Elliptic curves of type {0} should not be used for new keys!".format(
self.curve
)
)
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve]["create"](self.size),
backend=self.cryptography_backend,
)
)
except cryptography.exceptions.UnsupportedAlgorithm:
self.module.fail_json(msg='Cryptography backend does not support the algorithm required for {0}'.format(self.type))
self.module.fail_json(
msg="Cryptography backend does not support the algorithm required for {0}".format(
self.type
)
)
def get_private_key_data(self):
"""Return bytes for self.private_key"""
@@ -357,40 +414,62 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
try:
export_format = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format == 'pkcs1':
if export_format == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
elif export_format == 'pkcs8':
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
elif export_format == 'raw':
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.Raw
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif export_format == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif export_format == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
export_encoding = (
cryptography.hazmat.primitives.serialization.Encoding.Raw
)
except AttributeError:
self.module.fail_json(msg='Cryptography backend does not support the selected output format "{0}"'.format(self.format))
self.module.fail_json(
msg='Cryptography backend does not support the selected output format "{0}"'.format(
self.format
)
)
# Select key encryption
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.NoEncryption()
)
if self.cipher and self.passphrase:
if self.cipher == 'auto':
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(to_bytes(self.passphrase))
if self.cipher == "auto":
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
to_bytes(self.passphrase)
)
else:
self.module.fail_json(msg='Cryptography backend can only use "auto" for cipher option.')
self.module.fail_json(
msg='Cryptography backend can only use "auto" for cipher option.'
)
# Serialize key
try:
return self.private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm
encryption_algorithm=encryption_algorithm,
)
except ValueError:
self.module.fail_json(
msg='Cryptography backend cannot serialize the private key in the required format "{0}"'.format(self.format)
msg='Cryptography backend cannot serialize the private key in the required format "{0}"'.format(
self.format
)
)
except Exception:
self.module.fail_json(
msg='Error while serializing the private key in the required format "{0}"'.format(self.format),
exception=traceback.format_exc()
msg='Error while serializing the private key in the required format "{0}"'.format(
self.format
),
exception=traceback.format_exc(),
)
def _load_privatekey(self):
@@ -398,27 +477,45 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
if format == 'raw':
if format == "raw":
if len(data) == 56 and CRYPTOGRAPHY_HAS_X448:
return cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(data)
return cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(
data
)
if len(data) == 57 and CRYPTOGRAPHY_HAS_ED448:
return cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(data)
return cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(
data
)
if len(data) == 32:
if CRYPTOGRAPHY_HAS_X25519 and (self.type == 'X25519' or not CRYPTOGRAPHY_HAS_ED25519):
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
if CRYPTOGRAPHY_HAS_ED25519 and (self.type == 'Ed25519' or not CRYPTOGRAPHY_HAS_X25519):
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
if CRYPTOGRAPHY_HAS_X25519 and (
self.type == "X25519" or not CRYPTOGRAPHY_HAS_ED25519
):
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
if CRYPTOGRAPHY_HAS_ED25519 and (
self.type == "Ed25519" or not CRYPTOGRAPHY_HAS_X25519
):
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
if CRYPTOGRAPHY_HAS_X25519 and CRYPTOGRAPHY_HAS_ED25519:
try:
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
except Exception:
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
raise PrivateKeyError('Cannot load raw key')
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
raise PrivateKeyError("Cannot load raw key")
else:
return cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if self.passphrase is None else to_bytes(self.passphrase),
backend=self.cryptography_backend
return (
cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if self.passphrase is None else to_bytes(self.passphrase),
backend=self.cryptography_backend,
)
)
except Exception as e:
raise PrivateKeyError(e)
@@ -430,7 +527,7 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _check_passphrase(self):
try:
format = identify_private_key_format(self.existing_private_key_bytes)
if format == 'raw':
if format == "raw":
# Raw keys cannot be encrypted. To avoid incompatibilities, we try to
# actually load the key (and return False when this fails).
self._load_privatekey()
@@ -438,38 +535,65 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
# provided.
return self.passphrase is None
else:
return cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
backend=self.cryptography_backend
return (
cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
backend=self.cryptography_backend,
)
)
except Exception:
return False
def _check_size_and_type(self):
if isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
return self.type == 'RSA' and self.size == self.existing_private_key.key_size
if isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
return self.type == 'DSA' and self.size == self.existing_private_key.key_size
if CRYPTOGRAPHY_HAS_X25519 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey):
return self.type == 'X25519'
if CRYPTOGRAPHY_HAS_X448 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey):
return self.type == 'X448'
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey):
return self.type == 'Ed25519'
if CRYPTOGRAPHY_HAS_ED448 and isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
return self.type == 'Ed448'
if isinstance(self.existing_private_key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
if self.type != 'ECC':
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
):
return (
self.type == "RSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey,
):
return (
self.type == "DSA" and self.size == self.existing_private_key.key_size
)
if CRYPTOGRAPHY_HAS_X25519 and isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
):
return self.type == "X25519"
if CRYPTOGRAPHY_HAS_X448 and isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
):
return self.type == "X448"
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
):
return self.type == "Ed25519"
if CRYPTOGRAPHY_HAS_ED448 and isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
):
return self.type == "Ed448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
):
if self.type != "ECC":
return False
if self.curve not in self.curves:
return False
return self.curves[self.curve]['verify'](self.existing_private_key)
return self.curves[self.curve]["verify"](self.existing_private_key)
return False
def _check_format(self):
if self.format == 'auto_ignore':
if self.format == "auto_ignore":
return True
try:
format = identify_private_key_format(self.existing_private_key_bytes)
@@ -479,52 +603,96 @@ class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def select_backend(module, backend):
if backend == 'auto':
if backend == "auto":
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# Decision
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Success?
if backend == 'auto':
module.fail_json(msg=("Cannot detect the required Python library "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == 'cryptography':
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect the required Python library " "cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return backend, PrivateKeyCryptographyBackend(module)
else:
raise Exception('Unsupported value for backend: {0}'.format(backend))
raise Exception("Unsupported value for backend: {0}".format(backend))
def get_privatekey_argument_spec():
return ArgumentSpec(
argument_spec=dict(
size=dict(type='int', default=4096),
type=dict(type='str', default='RSA', choices=[
'DSA', 'ECC', 'Ed25519', 'Ed448', 'RSA', 'X25519', 'X448'
]),
curve=dict(type='str', choices=[
'secp224r1', 'secp256k1', 'secp256r1', 'secp384r1', 'secp521r1',
'secp192r1', 'brainpoolP256r1', 'brainpoolP384r1', 'brainpoolP512r1',
'sect163k1', 'sect163r2', 'sect233k1', 'sect233r1', 'sect283k1',
'sect283r1', 'sect409k1', 'sect409r1', 'sect571k1', 'sect571r1',
]),
passphrase=dict(type='str', no_log=True),
cipher=dict(type='str', default='auto'),
format=dict(type='str', default='auto_ignore', choices=['pkcs1', 'pkcs8', 'raw', 'auto', 'auto_ignore']),
format_mismatch=dict(type='str', default='regenerate', choices=['regenerate', 'convert']),
select_crypto_backend=dict(type='str', choices=['auto', 'cryptography'], default='auto'),
size=dict(type="int", default=4096),
type=dict(
type="str",
default="RSA",
choices=["DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"],
),
curve=dict(
type="str",
choices=[
"secp224r1",
"secp256k1",
"secp256r1",
"secp384r1",
"secp521r1",
"secp192r1",
"brainpoolP256r1",
"brainpoolP384r1",
"brainpoolP512r1",
"sect163k1",
"sect163r2",
"sect233k1",
"sect233r1",
"sect283k1",
"sect283r1",
"sect409k1",
"sect409r1",
"sect571k1",
"sect571r1",
],
),
passphrase=dict(type="str", no_log=True),
cipher=dict(type="str", default="auto"),
format=dict(
type="str",
default="auto_ignore",
choices=["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"],
),
format_mismatch=dict(
type="str", default="regenerate", choices=["regenerate", "convert"]
),
select_crypto_backend=dict(
type="str", choices=["auto", "cryptography"], default="auto"
),
regenerate=dict(
type='str',
default='full_idempotence',
choices=['never', 'fail', 'partial_idempotence', 'full_idempotence', 'always']
type="str",
default="full_idempotence",
choices=[
"never",
"fail",
"partial_idempotence",
"full_idempotence",
"always",
],
),
),
required_if=[
['type', 'ECC', ['curve']],
["type", "ECC", ["curve"]],
],
)

View File

@@ -38,7 +38,7 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.2.3'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.2.3"
CRYPTOGRAPHY_IMP_ERR = None
try:
@@ -50,6 +50,7 @@ try:
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.serialization
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -73,18 +74,18 @@ class PrivateKeyError(OpenSSLObjectError):
class PrivateKeyConvertBackend:
def __init__(self, module, backend):
self.module = module
self.src_path = module.params['src_path']
self.src_content = module.params['src_content']
self.src_passphrase = module.params['src_passphrase']
self.format = module.params['format']
self.dest_passphrase = module.params['dest_passphrase']
self.src_path = module.params["src_path"]
self.src_content = module.params["src_content"]
self.src_passphrase = module.params["src_passphrase"]
self.format = module.params["format"]
self.dest_passphrase = module.params["dest_passphrase"]
self.backend = backend
self.src_private_key = None
if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module)
else:
self.src_private_key_bytes = self.src_content.encode('utf-8')
self.src_private_key_bytes = self.src_content.encode("utf-8")
self.dest_private_key = None
self.dest_private_key_bytes = None
@@ -109,17 +110,25 @@ class PrivateKeyConvertBackend:
def needs_conversion(self):
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(self.src_private_key_bytes, self.src_passphrase)
dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase
)
if not self.has_existing_destination():
return True
try:
format, self.dest_private_key = self._load_private_key(self.dest_private_key_bytes, self.dest_passphrase, current_hint=self.src_private_key)
format, self.dest_private_key = self._load_private_key(
self.dest_private_key_bytes,
self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
return True
return format != self.format or not cryptography_compare_private_keys(self.dest_private_key, self.src_private_key)
return format != self.format or not cryptography_compare_private_keys(
self.dest_private_key, self.src_private_key
)
def dump(self):
"""Serialize the object into a dictionary."""
@@ -129,7 +138,9 @@ class PrivateKeyConvertBackend:
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module):
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module, backend='cryptography')
super(PrivateKeyConvertCryptographyBackend, self).__init__(
module=module, backend="cryptography"
)
self.cryptography_backend = cryptography.hazmat.backends.default_backend()
@@ -138,72 +149,140 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
# Select export format and encoding
try:
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if self.format == 'pkcs1':
if self.format == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
elif self.format == 'pkcs8':
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
elif self.format == 'raw':
export_format = cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.Raw
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif self.format == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif self.format == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
export_encoding = (
cryptography.hazmat.primitives.serialization.Encoding.Raw
)
except AttributeError:
self.module.fail_json(msg='Cryptography backend does not support the selected output format "{0}"'.format(self.format))
self.module.fail_json(
msg='Cryptography backend does not support the selected output format "{0}"'.format(
self.format
)
)
# Select key encryption
encryption_algorithm = cryptography.hazmat.primitives.serialization.NoEncryption()
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.NoEncryption()
)
if self.dest_passphrase:
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(to_bytes(self.dest_passphrase))
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
to_bytes(self.dest_passphrase)
)
)
# Serialize key
try:
return self.src_private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm
encryption_algorithm=encryption_algorithm,
)
except ValueError:
self.module.fail_json(
msg='Cryptography backend cannot serialize the private key in the required format "{0}"'.format(self.format)
msg='Cryptography backend cannot serialize the private key in the required format "{0}"'.format(
self.format
)
)
except Exception:
self.module.fail_json(
msg='Error while serializing the private key in the required format "{0}"'.format(self.format),
exception=traceback.format_exc()
msg='Error while serializing the private key in the required format "{0}"'.format(
self.format
),
exception=traceback.format_exc(),
)
def _load_private_key(self, data, passphrase, current_hint=None):
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
if format == 'raw':
if format == "raw":
if passphrase is not None:
raise PrivateKeyError('Cannot load raw key with passphrase')
raise PrivateKeyError("Cannot load raw key with passphrase")
if len(data) == 56 and CRYPTOGRAPHY_HAS_X448:
return format, cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(
data
),
)
if len(data) == 57 and CRYPTOGRAPHY_HAS_ED448:
return format, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(
data
),
)
if len(data) == 32:
if CRYPTOGRAPHY_HAS_X25519 and not CRYPTOGRAPHY_HAS_ED25519:
return format, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
),
)
if CRYPTOGRAPHY_HAS_ED25519 and not CRYPTOGRAPHY_HAS_X25519:
return format, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
),
)
if CRYPTOGRAPHY_HAS_X25519 and CRYPTOGRAPHY_HAS_ED25519:
if isinstance(current_hint, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey):
if isinstance(
current_hint,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
):
try:
return format, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
),
)
except Exception:
return format, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
),
)
else:
try:
return format, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(data)
return (
format,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
),
)
except Exception:
return format, cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(data)
raise PrivateKeyError('Cannot load raw key')
return (
format,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
),
)
raise PrivateKeyError("Cannot load raw key")
else:
return format, cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if passphrase is None else to_bytes(passphrase),
backend=self.cryptography_backend
return (
format,
cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if passphrase is None else to_bytes(passphrase),
backend=self.cryptography_backend,
),
)
except Exception as e:
raise PrivateKeyError(e)
@@ -211,24 +290,28 @@ class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def select_backend(module):
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return PrivateKeyConvertCryptographyBackend(module)
def get_privatekey_argument_spec():
return ArgumentSpec(
argument_spec=dict(
src_path=dict(type='path'),
src_content=dict(type='str'),
src_passphrase=dict(type='str', no_log=True),
dest_passphrase=dict(type='str', no_log=True),
format=dict(type='str', required=True, choices=['pkcs1', 'pkcs8', 'raw']),
src_path=dict(type="path"),
src_content=dict(type="str"),
src_passphrase=dict(type="str", no_log=True),
dest_passphrase=dict(type="str", no_log=True),
format=dict(type="str", required=True, choices=["pkcs1", "pkcs8", "raw"]),
),
mutually_exclusive=[
['src_path', 'src_content'],
["src_path", "src_content"],
],
required_one_of=[
['src_path', 'src_content'],
["src_path", "src_content"],
],
)

View File

@@ -39,12 +39,13 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.2.3'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.2.3"
CRYPTOGRAPHY_IMP_ERR = None
try:
import cryptography
from cryptography.hazmat.primitives import serialization
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -52,7 +53,7 @@ except ImportError:
else:
CRYPTOGRAPHY_FOUND = True
SIGNATURE_TEST_DATA = b'1234'
SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(key, need_private_key_data=False):
@@ -61,25 +62,29 @@ def _get_cryptography_private_key_info(key, need_private_key_data=False):
if need_private_key_data:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
private_numbers = key.private_numbers()
key_private_data['p'] = private_numbers.p
key_private_data['q'] = private_numbers.q
key_private_data['exponent'] = private_numbers.d
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
key_private_data["p"] = private_numbers.p
key_private_data["q"] = private_numbers.q
key_private_data["exponent"] = private_numbers.d
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey
):
private_numbers = key.private_numbers()
key_private_data['x'] = private_numbers.x
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
key_private_data["x"] = private_numbers.x
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
private_numbers = key.private_numbers()
key_private_data['multiplier'] = private_numbers.private_value
key_private_data["multiplier"] = private_numbers.private_value
return key_type, key_public_data, key_private_data
def _check_dsa_consistency(key_public_data, key_private_data):
# Get parameters
p = key_public_data.get('p')
q = key_public_data.get('q')
g = key_public_data.get('g')
y = key_public_data.get('y')
x = key_private_data.get('x')
p = key_public_data.get("p")
q = key_public_data.get("q")
g = key_public_data.get("g")
y = key_public_data.get("y")
x = key_private_data.get("x")
for v in (p, q, g, y, x):
if v is None:
return None
@@ -104,10 +109,12 @@ def _check_dsa_consistency(key_public_data, key_private_data):
return True
def _is_cryptography_key_consistent(key, key_public_data, key_private_data, warn_func=None):
def _is_cryptography_key_consistent(
key, key_public_data, key_private_data, warn_func=None
):
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
# key._backend was removed in cryptography 42.0.0
backend = getattr(key, '_backend', None)
backend = getattr(key, "_backend", None)
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata))
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
@@ -115,7 +122,9 @@ def _is_cryptography_key_consistent(key, key_public_data, key_private_data, warn
if result is not None:
return result
try:
signature = key.sign(SIGNATURE_TEST_DATA, cryptography.hazmat.primitives.hashes.SHA256())
signature = key.sign(
SIGNATURE_TEST_DATA, cryptography.hazmat.primitives.hashes.SHA256()
)
except AttributeError:
# sign() was added in cryptography 1.5, but we support older versions
return None
@@ -123,16 +132,20 @@ def _is_cryptography_key_consistent(key, key_public_data, key_private_data, warn
key.public_key().verify(
signature,
SIGNATURE_TEST_DATA,
cryptography.hazmat.primitives.hashes.SHA256()
cryptography.hazmat.primitives.hashes.SHA256(),
)
return True
except cryptography.exceptions.InvalidSignature:
return False
if isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
if isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
try:
signature = key.sign(
SIGNATURE_TEST_DATA,
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(cryptography.hazmat.primitives.hashes.SHA256())
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(
cryptography.hazmat.primitives.hashes.SHA256()
),
)
except AttributeError:
# sign() was added in cryptography 1.5, but we support older versions
@@ -141,15 +154,21 @@ def _is_cryptography_key_consistent(key, key_public_data, key_private_data, warn
key.public_key().verify(
signature,
SIGNATURE_TEST_DATA,
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(cryptography.hazmat.primitives.hashes.SHA256())
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(
cryptography.hazmat.primitives.hashes.SHA256()
),
)
return True
except cryptography.exceptions.InvalidSignature:
return False
has_simple_sign_function = False
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey):
if CRYPTOGRAPHY_HAS_ED25519 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey
):
has_simple_sign_function = True
if CRYPTOGRAPHY_HAS_ED448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
if CRYPTOGRAPHY_HAS_ED448 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey
):
has_simple_sign_function = True
if has_simple_sign_function:
signature = key.sign(SIGNATURE_TEST_DATA)
@@ -160,7 +179,7 @@ def _is_cryptography_key_consistent(key, key_public_data, key_private_data, warn
return False
# For X25519 and X448, there's no test yet.
if warn_func is not None:
warn_func('Cannot determine consistency for key of type %s' % type(key))
warn_func("Cannot determine consistency for key of type %s" % type(key))
return None
@@ -180,7 +199,15 @@ class PrivateKeyParseError(OpenSSLObjectError):
@six.add_metaclass(abc.ABCMeta)
class PrivateKeyInfoRetrieval(object):
def __init__(self, module, backend, content, passphrase=None, return_private_key_data=False, check_consistency=False):
def __init__(
self,
module,
backend,
content,
passphrase=None,
return_private_key_data=False,
check_consistency=False,
):
# content must be a bytes string
self.module = module
self.backend = backend
@@ -211,28 +238,38 @@ class PrivateKeyInfoRetrieval(object):
self.key = load_privatekey(
path=None,
content=priv_key_detail,
passphrase=to_bytes(self.passphrase) if self.passphrase is not None else self.passphrase,
backend=self.backend
passphrase=(
to_bytes(self.passphrase)
if self.passphrase is not None
else self.passphrase
),
backend=self.backend,
)
result['can_parse_key'] = True
result["can_parse_key"] = True
except OpenSSLObjectError as exc:
raise PrivateKeyParseError(to_native(exc), result)
result['public_key'] = to_native(self._get_public_key(binary=False))
result["public_key"] = to_native(self._get_public_key(binary=False))
pk = self._get_public_key(binary=True)
result['public_key_fingerprints'] = get_fingerprint_of_bytes(
pk, prefer_one=prefer_one_fingerprint) if pk is not None else dict()
result["public_key_fingerprints"] = (
get_fingerprint_of_bytes(pk, prefer_one=prefer_one_fingerprint)
if pk is not None
else dict()
)
key_type, key_public_data, key_private_data = self._get_key_info(
need_private_key_data=self.return_private_key_data or self.check_consistency)
result['type'] = key_type
result['public_data'] = key_public_data
need_private_key_data=self.return_private_key_data or self.check_consistency
)
result["type"] = key_type
result["public_data"] = key_public_data
if self.return_private_key_data:
result['private_data'] = key_private_data
result["private_data"] = key_private_data
if self.check_consistency:
result['key_is_consistent'] = self._is_key_consistent(key_public_data, key_private_data)
if result['key_is_consistent'] is False:
result["key_is_consistent"] = self._is_key_consistent(
key_public_data, key_private_data
)
if result["key_is_consistent"] is False:
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
msg = (
"Private key is not consistent! (See "
@@ -244,48 +281,88 @@ class PrivateKeyInfoRetrieval(object):
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module, content, **kwargs):
super(PrivateKeyInfoRetrievalCryptography, self).__init__(module, 'cryptography', content, **kwargs)
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, "cryptography", content, **kwargs
)
def _get_public_key(self, binary):
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(self, need_private_key_data=False):
return _get_cryptography_private_key_info(self.key, need_private_key_data=need_private_key_data)
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(self, key_public_data, key_private_data):
return _is_cryptography_key_consistent(self.key, key_public_data, key_private_data, warn_func=self.module.warn)
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
)
def get_privatekey_info(module, backend, content, passphrase=None, return_private_key_data=False, prefer_one_fingerprint=False):
if backend == 'cryptography':
def get_privatekey_info(
module,
backend,
content,
passphrase=None,
return_private_key_data=False,
prefer_one_fingerprint=False,
):
if backend == "cryptography":
info = PrivateKeyInfoRetrievalCryptography(
module, content, passphrase=passphrase, return_private_key_data=return_private_key_data)
module,
content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, backend, content, passphrase=None, return_private_key_data=False, check_consistency=False):
if backend == 'auto':
def select_backend(
module,
backend,
content,
passphrase=None,
return_private_key_data=False,
check_consistency=False,
):
if backend == "auto":
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# Try cryptography
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Success?
if backend == 'auto':
module.fail_json(msg=("Cannot detect the required Python library "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect the required Python library " "cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == 'cryptography':
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return backend, PrivateKeyInfoRetrievalCryptography(
module, content, passphrase=passphrase, return_private_key_data=return_private_key_data, check_consistency=check_consistency)
module,
content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
check_consistency=check_consistency,
)
else:
raise ValueError('Unsupported value for backend: {0}'.format(backend))
raise ValueError("Unsupported value for backend: {0}".format(backend))

View File

@@ -32,12 +32,13 @@ from ansible_collections.community.crypto.plugins.module_utils.version import (
)
MINIMAL_CRYPTOGRAPHY_VERSION = '1.2.3'
MINIMAL_CRYPTOGRAPHY_VERSION = "1.2.3"
CRYPTOGRAPHY_IMP_ERR = None
try:
import cryptography
from cryptography.hazmat.primitives import serialization
CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
@@ -49,37 +50,47 @@ else:
def _get_cryptography_public_key_info(key):
key_public_data = dict()
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
key_type = 'RSA'
key_type = "RSA"
public_numbers = key.public_numbers()
key_public_data['size'] = key.key_size
key_public_data['modulus'] = public_numbers.n
key_public_data['exponent'] = public_numbers.e
key_public_data["size"] = key.key_size
key_public_data["modulus"] = public_numbers.n
key_public_data["exponent"] = public_numbers.e
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
key_type = 'DSA'
key_type = "DSA"
parameter_numbers = key.parameters().parameter_numbers()
public_numbers = key.public_numbers()
key_public_data['size'] = key.key_size
key_public_data['p'] = parameter_numbers.p
key_public_data['q'] = parameter_numbers.q
key_public_data['g'] = parameter_numbers.g
key_public_data['y'] = public_numbers.y
elif CRYPTOGRAPHY_HAS_X25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey):
key_type = 'X25519'
elif CRYPTOGRAPHY_HAS_X448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey):
key_type = 'X448'
elif CRYPTOGRAPHY_HAS_ED25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey):
key_type = 'Ed25519'
elif CRYPTOGRAPHY_HAS_ED448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey):
key_type = 'Ed448'
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey):
key_type = 'ECC'
key_public_data["size"] = key.key_size
key_public_data["p"] = parameter_numbers.p
key_public_data["q"] = parameter_numbers.q
key_public_data["g"] = parameter_numbers.g
key_public_data["y"] = public_numbers.y
elif CRYPTOGRAPHY_HAS_X25519 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey
):
key_type = "X25519"
elif CRYPTOGRAPHY_HAS_X448 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey
):
key_type = "X448"
elif CRYPTOGRAPHY_HAS_ED25519 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey
):
key_type = "Ed25519"
elif CRYPTOGRAPHY_HAS_ED448 and isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
):
key_type = "Ed448"
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey
):
key_type = "ECC"
public_numbers = key.public_numbers()
key_public_data['curve'] = key.curve.name
key_public_data['x'] = public_numbers.x
key_public_data['y'] = public_numbers.y
key_public_data['exponent_size'] = key.curve.key_size
key_public_data["curve"] = key.curve.name
key_public_data["x"] = public_numbers.x
key_public_data["y"] = public_numbers.y
key_public_data["exponent_size"] = key.curve.key_size
else:
key_type = 'unknown ({0})'.format(type(key))
key_type = "unknown ({0})".format(type(key))
return key_type, key_public_data
@@ -116,54 +127,75 @@ class PublicKeyInfoRetrieval(object):
raise PublicKeyParseError(to_native(e), {})
pk = self._get_public_key(binary=True)
result['fingerprints'] = get_fingerprint_of_bytes(
pk, prefer_one=prefer_one_fingerprint) if pk is not None else dict()
result["fingerprints"] = (
get_fingerprint_of_bytes(pk, prefer_one=prefer_one_fingerprint)
if pk is not None
else dict()
)
key_type, key_public_data = self._get_key_info()
result['type'] = key_type
result['public_data'] = key_public_data
result["type"] = key_type
result["public_data"] = key_public_data
return result
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
"""Validate the supplied public key, using the cryptography backend"""
def __init__(self, module, content=None, key=None):
super(PublicKeyInfoRetrievalCryptography, self).__init__(module, 'cryptography', content=content, key=key)
super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, "cryptography", content=content, key=key
)
def _get_public_key(self, binary):
return self.key.public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(self):
return _get_cryptography_public_key_info(self.key)
def get_publickey_info(module, backend, content=None, key=None, prefer_one_fingerprint=False):
if backend == 'cryptography':
def get_publickey_info(
module, backend, content=None, key=None, prefer_one_fingerprint=False
):
if backend == "cryptography":
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(module, backend, content=None, key=None):
if backend == 'auto':
if backend == "auto":
# Detection what is possible
can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
can_use_cryptography = (
CRYPTOGRAPHY_FOUND
and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
)
# Try cryptography
if can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
# Success?
if backend == 'auto':
module.fail_json(msg=("Cannot detect any of the required Python libraries "
"cryptography (>= {0})").format(MINIMAL_CRYPTOGRAPHY_VERSION))
if backend == "auto":
module.fail_json(
msg=(
"Cannot detect any of the required Python libraries "
"cryptography (>= {0})"
).format(MINIMAL_CRYPTOGRAPHY_VERSION)
)
if backend == 'cryptography':
if backend == "cryptography":
if not CRYPTOGRAPHY_FOUND:
module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
exception=CRYPTOGRAPHY_IMP_ERR)
return backend, PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
module.fail_json(
msg=missing_required_lib(
"cryptography >= {0}".format(MINIMAL_CRYPTOGRAPHY_VERSION)
),
exception=CRYPTOGRAPHY_IMP_ERR,
)
return backend, PublicKeyInfoRetrievalCryptography(
module, content=content, key=key
)
else:
raise ValueError('Unsupported value for backend: {0}'.format(backend))
raise ValueError("Unsupported value for backend: {0}".format(backend))

View File

@@ -10,29 +10,33 @@ from __future__ import absolute_import, division, print_function
__metaclass__ = type
PEM_START = '-----BEGIN '
PEM_END_START = '-----END '
PEM_END = '-----'
PKCS8_PRIVATEKEY_NAMES = ('PRIVATE KEY', 'ENCRYPTED PRIVATE KEY')
PKCS1_PRIVATEKEY_SUFFIX = ' PRIVATE KEY'
PEM_START = "-----BEGIN "
PEM_END_START = "-----END "
PEM_END = "-----"
PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
def identify_pem_format(content, encoding='utf-8'):
'''Given the contents of a binary file, tests whether this could be a PEM file.'''
def identify_pem_format(content, encoding="utf-8"):
"""Given the contents of a binary file, tests whether this could be a PEM file."""
try:
first_pem = extract_first_pem(content.decode(encoding))
if first_pem is None:
return False
lines = first_pem.splitlines(False)
if lines[0].startswith(PEM_START) and lines[0].endswith(PEM_END) and len(lines[0]) > len(PEM_START) + len(PEM_END):
if (
lines[0].startswith(PEM_START)
and lines[0].endswith(PEM_END)
and len(lines[0]) > len(PEM_START) + len(PEM_END)
):
return True
except UnicodeDecodeError:
pass
return False
def identify_private_key_format(content, encoding='utf-8'):
'''Given the contents of a private key file, identifies its format.'''
def identify_private_key_format(content, encoding="utf-8"):
"""Given the contents of a private key file, identifies its format."""
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
# (PEM_read_bio_PrivateKey)
# and https://github.com/openssl/openssl/blob/master/include/openssl/pem.h#L46-L47
@@ -40,42 +44,48 @@ def identify_private_key_format(content, encoding='utf-8'):
try:
first_pem = extract_first_pem(content.decode(encoding))
if first_pem is None:
return 'raw'
return "raw"
lines = first_pem.splitlines(False)
if lines[0].startswith(PEM_START) and lines[0].endswith(PEM_END) and len(lines[0]) > len(PEM_START) + len(PEM_END):
name = lines[0][len(PEM_START):-len(PEM_END)]
if (
lines[0].startswith(PEM_START)
and lines[0].endswith(PEM_END)
and len(lines[0]) > len(PEM_START) + len(PEM_END)
):
name = lines[0][len(PEM_START) : -len(PEM_END)]
if name in PKCS8_PRIVATEKEY_NAMES:
return 'pkcs8'
if len(name) > len(PKCS1_PRIVATEKEY_SUFFIX) and name.endswith(PKCS1_PRIVATEKEY_SUFFIX):
return 'pkcs1'
return 'unknown-pem'
return "pkcs8"
if len(name) > len(PKCS1_PRIVATEKEY_SUFFIX) and name.endswith(
PKCS1_PRIVATEKEY_SUFFIX
):
return "pkcs1"
return "unknown-pem"
except UnicodeDecodeError:
pass
return 'raw'
return "raw"
def split_pem_list(text, keep_inbetween=False):
'''
"""
Split concatenated PEM objects into a list of strings, where each is one PEM object.
'''
"""
result = []
current = [] if keep_inbetween else None
for line in text.splitlines(True):
if line.strip():
if not keep_inbetween and line.startswith('-----BEGIN '):
if not keep_inbetween and line.startswith("-----BEGIN "):
current = []
if current is not None:
current.append(line)
if line.startswith('-----END '):
result.append(''.join(current))
if line.startswith("-----END "):
result.append("".join(current))
current = [] if keep_inbetween else None
return result
def extract_first_pem(text):
'''
"""
Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none.
'''
"""
all_pems = split_pem_list(text)
if not all_pems:
return None
@@ -87,24 +97,42 @@ def _extract_type(line, start=PEM_START):
return None
if not line.endswith(PEM_END):
return None
return line[len(start):-len(PEM_END)]
return line[len(start) : -len(PEM_END)]
def extract_pem(content, strict=False):
lines = content.splitlines()
if len(lines) < 3:
raise ValueError('PEM must have at least 3 lines, have only {count}'.format(count=len(lines)))
raise ValueError(
"PEM must have at least 3 lines, have only {count}".format(count=len(lines))
)
header_type = _extract_type(lines[0])
if header_type is None:
raise ValueError('First line is not of format {start}...{end}: {line!r}'.format(start=PEM_START, end=PEM_END, line=lines[0]))
raise ValueError(
"First line is not of format {start}...{end}: {line!r}".format(
start=PEM_START, end=PEM_END, line=lines[0]
)
)
footer_type = _extract_type(lines[-1], start=PEM_END_START)
if strict:
if header_type != footer_type:
raise ValueError('Header type ({header}) is different from footer type ({footer})'.format(header=header_type, footer=footer_type))
raise ValueError(
"Header type ({header}) is different from footer type ({footer})".format(
header=header_type, footer=footer_type
)
)
for idx, line in enumerate(lines[1:-2]):
if len(line) != 64:
raise ValueError('Line {idx} has length {len} instead of 64'.format(idx=idx, len=len(line)))
raise ValueError(
"Line {idx} has length {len} instead of 64".format(
idx=idx, len=len(line)
)
)
if not (0 < len(lines[-2]) <= 64):
raise ValueError('Last line has length {len}, should be in (0, 64]'.format(len=len(lines[-2])))
raise ValueError(
"Last line has length {len}, should be in (0, 64]".format(
len=len(lines[-2])
)
)
content = lines[1:-1]
return header_type, ''.join(content)
return header_type, "".join(content)

View File

@@ -32,6 +32,7 @@ from ansible_collections.community.crypto.plugins.module_utils.time import ( #
try:
from OpenSSL import crypto
HAS_PYOPENSSL = True
except (ImportError, AttributeError):
# Error handled in the calling module.
@@ -52,7 +53,14 @@ from .basic import OpenSSLBadPassphraseError, OpenSSLObjectError
# This list of preferred fingerprints is used when prefer_one=True is supplied to the
# fingerprinting methods.
PREFERRED_FINGERPRINTS = (
'sha256', 'sha3_256', 'sha512', 'sha3_512', 'sha384', 'sha3_384', 'sha1', 'md5'
"sha256",
"sha3_256",
"sha512",
"sha3_512",
"sha384",
"sha3_384",
"sha1",
"md5",
)
@@ -71,8 +79,16 @@ def get_fingerprint_of_bytes(source, prefer_one=False):
if prefer_one:
# Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning
prefered_algorithms = [algorithm for algorithm in PREFERRED_FINGERPRINTS if algorithm in algorithms]
prefered_algorithms += sorted([algorithm for algorithm in algorithms if algorithm not in PREFERRED_FINGERPRINTS])
prefered_algorithms = [
algorithm for algorithm in PREFERRED_FINGERPRINTS if algorithm in algorithms
]
prefered_algorithms += sorted(
[
algorithm
for algorithm in algorithms
if algorithm not in PREFERRED_FINGERPRINTS
]
)
algorithms = prefered_algorithms
for algo in algorithms:
@@ -88,34 +104,47 @@ def get_fingerprint_of_bytes(source, prefer_one=False):
pubkey_digest = h.hexdigest()
except TypeError:
pubkey_digest = h.hexdigest(32)
fingerprint[algo] = ':'.join(pubkey_digest[i:i + 2] for i in range(0, len(pubkey_digest), 2))
fingerprint[algo] = ":".join(
pubkey_digest[i : i + 2] for i in range(0, len(pubkey_digest), 2)
)
if prefer_one:
break
return fingerprint
def get_fingerprint_of_privatekey(privatekey, backend='cryptography', prefer_one=False):
"""Generate the fingerprint of the public key. """
def get_fingerprint_of_privatekey(privatekey, backend="cryptography", prefer_one=False):
"""Generate the fingerprint of the public key."""
if backend == 'cryptography':
if backend == "cryptography":
publickey = privatekey.public_key().public_bytes(
serialization.Encoding.DER,
serialization.PublicFormat.SubjectPublicKeyInfo
serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
)
return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
def get_fingerprint(path, passphrase=None, content=None, backend='cryptography', prefer_one=False):
"""Generate the fingerprint of the public key. """
def get_fingerprint(
path, passphrase=None, content=None, backend="cryptography", prefer_one=False
):
"""Generate the fingerprint of the public key."""
privatekey = load_privatekey(path, passphrase=passphrase, content=content, check_passphrase=False, backend=backend)
privatekey = load_privatekey(
path,
passphrase=passphrase,
content=content,
check_passphrase=False,
backend=backend,
)
return get_fingerprint_of_privatekey(privatekey, backend=backend, prefer_one=prefer_one)
return get_fingerprint_of_privatekey(
privatekey, backend=backend, prefer_one=prefer_one
)
def load_privatekey(path, passphrase=None, check_passphrase=True, content=None, backend='cryptography'):
def load_privatekey(
path, passphrase=None, check_passphrase=True, content=None, backend="cryptography"
):
"""Load the specified OpenSSL private key.
The content can also be specified via content; in that case,
@@ -124,58 +153,72 @@ def load_privatekey(path, passphrase=None, check_passphrase=True, content=None,
try:
if content is None:
with open(path, 'rb') as b_priv_key_fh:
with open(path, "rb") as b_priv_key_fh:
priv_key_detail = b_priv_key_fh.read()
else:
priv_key_detail = content
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
if backend == 'pyopenssl':
if backend == "pyopenssl":
# First try: try to load with real passphrase (resp. empty string)
# Will work if this is the correct passphrase, or the key is not
# password-protected.
try:
result = crypto.load_privatekey(crypto.FILETYPE_PEM,
priv_key_detail,
to_bytes(passphrase or ''))
result = crypto.load_privatekey(
crypto.FILETYPE_PEM, priv_key_detail, to_bytes(passphrase or "")
)
except crypto.Error as e:
if len(e.args) > 0 and len(e.args[0]) > 0:
if e.args[0][0][2] in ('bad decrypt', 'bad password read'):
if e.args[0][0][2] in ("bad decrypt", "bad password read"):
# This happens in case we have the wrong passphrase.
if passphrase is not None:
raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key!')
raise OpenSSLBadPassphraseError(
"Wrong passphrase provided for private key!"
)
else:
raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!')
raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e))
raise OpenSSLBadPassphraseError(
"No passphrase provided, but private key is password-protected!"
)
raise OpenSSLObjectError("Error while deserializing key: {0}".format(e))
if check_passphrase:
# Next we want to make sure that the key is actually protected by
# a passphrase (in case we did try the empty string before, make
# sure that the key is not protected by the empty string)
try:
crypto.load_privatekey(crypto.FILETYPE_PEM,
priv_key_detail,
to_bytes('y' if passphrase == 'x' else 'x'))
crypto.load_privatekey(
crypto.FILETYPE_PEM,
priv_key_detail,
to_bytes("y" if passphrase == "x" else "x"),
)
if passphrase is not None:
# Since we can load the key without an exception, the
# key is not password-protected
raise OpenSSLBadPassphraseError('Passphrase provided, but private key is not password-protected!')
raise OpenSSLBadPassphraseError(
"Passphrase provided, but private key is not password-protected!"
)
except crypto.Error as e:
if passphrase is None and len(e.args) > 0 and len(e.args[0]) > 0:
if e.args[0][0][2] in ('bad decrypt', 'bad password read'):
if e.args[0][0][2] in ("bad decrypt", "bad password read"):
# The key is obviously protected by the empty string.
# Do not do this at home (if it is possible at all)...
raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!')
elif backend == 'cryptography':
raise OpenSSLBadPassphraseError(
"No passphrase provided, but private key is password-protected!"
)
elif backend == "cryptography":
try:
result = load_pem_private_key(priv_key_detail,
None if passphrase is None else to_bytes(passphrase),
cryptography_backend())
result = load_pem_private_key(
priv_key_detail,
None if passphrase is None else to_bytes(passphrase),
cryptography_backend(),
)
except TypeError:
raise OpenSSLBadPassphraseError('Wrong or empty passphrase provided for private key')
raise OpenSSLBadPassphraseError(
"Wrong or empty passphrase provided for private key"
)
except ValueError:
raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key')
raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key")
return result
@@ -183,60 +226,72 @@ def load_privatekey(path, passphrase=None, check_passphrase=True, content=None,
def load_publickey(path=None, content=None, backend=None):
if content is None:
if path is None:
raise OpenSSLObjectError('Must provide either path or content')
raise OpenSSLObjectError("Must provide either path or content")
try:
with open(path, 'rb') as b_priv_key_fh:
with open(path, "rb") as b_priv_key_fh:
content = b_priv_key_fh.read()
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
if backend == 'cryptography':
if backend == "cryptography":
try:
return serialization.load_pem_public_key(content, backend=cryptography_backend())
return serialization.load_pem_public_key(
content, backend=cryptography_backend()
)
except Exception as e:
raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e))
raise OpenSSLObjectError("Error while deserializing key: {0}".format(e))
def load_certificate(path, content=None, backend='cryptography', der_support_enabled=False):
def load_certificate(
path, content=None, backend="cryptography", der_support_enabled=False
):
"""Load the specified certificate."""
try:
if content is None:
with open(path, 'rb') as cert_fh:
with open(path, "rb") as cert_fh:
cert_content = cert_fh.read()
else:
cert_content = content
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
if backend == 'pyopenssl':
if backend == "pyopenssl":
if der_support_enabled is False or identify_pem_format(cert_content):
return crypto.load_certificate(crypto.FILETYPE_PEM, cert_content)
elif der_support_enabled:
raise OpenSSLObjectError('Certificate in DER format is not supported by the pyopenssl backend.')
elif backend == 'cryptography':
raise OpenSSLObjectError(
"Certificate in DER format is not supported by the pyopenssl backend."
)
elif backend == "cryptography":
if der_support_enabled is False or identify_pem_format(cert_content):
try:
return x509.load_pem_x509_certificate(cert_content, cryptography_backend())
return x509.load_pem_x509_certificate(
cert_content, cryptography_backend()
)
except ValueError as exc:
raise OpenSSLObjectError(exc)
elif der_support_enabled:
try:
return x509.load_der_x509_certificate(cert_content, cryptography_backend())
return x509.load_der_x509_certificate(
cert_content, cryptography_backend()
)
except ValueError as exc:
raise OpenSSLObjectError('Cannot parse DER certificate: {0}'.format(exc))
raise OpenSSLObjectError(
"Cannot parse DER certificate: {0}".format(exc)
)
def load_certificate_request(path, content=None, backend='cryptography'):
def load_certificate_request(path, content=None, backend="cryptography"):
"""Load the specified certificate signing request."""
try:
if content is None:
with open(path, 'rb') as csr_fh:
with open(path, "rb") as csr_fh:
csr_content = csr_fh.read()
else:
csr_content = content
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
if backend == 'cryptography':
if backend == "cryptography":
try:
return x509.load_pem_x509_csr(csr_content, cryptography_backend())
except ValueError as exc:
@@ -245,23 +300,40 @@ def load_certificate_request(path, content=None, backend='cryptography'):
def parse_name_field(input_dict, name_field_name=None):
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
error_str = '{key}' if name_field_name is None else '{key} in {name}'
error_str = "{key}" if name_field_name is None else "{key} in {name}"
result = []
for key, value in input_dict.items():
if isinstance(value, list):
for entry in value:
if not isinstance(entry, six.string_types):
raise TypeError(('Values %s must be strings' % error_str).format(key=key, name=name_field_name))
raise TypeError(
("Values %s must be strings" % error_str).format(
key=key, name=name_field_name
)
)
if not entry:
raise ValueError(('Values for %s must not be empty strings' % error_str).format(key=key))
raise ValueError(
("Values for %s must not be empty strings" % error_str).format(
key=key
)
)
result.append((key, entry))
elif isinstance(value, six.string_types):
if not value:
raise ValueError(('Value for %s must not be an empty string' % error_str).format(key=key))
raise ValueError(
("Value for %s must not be an empty string" % error_str).format(
key=key
)
)
result.append((key, value))
else:
raise TypeError(('Value for %s must be either a string or a list of strings' % error_str).format(key=key))
raise TypeError(
(
"Value for %s must be either a string or a list of strings"
% error_str
).format(key=key)
)
return result
@@ -272,28 +344,32 @@ def parse_ordered_name_field(input_list, name_field_name):
for index, entry in enumerate(input_list):
if len(entry) != 1:
raise ValueError(
'Entry #{index} in {name} must be a dictionary with exactly one key-value pair'.format(
name=name_field_name, index=index + 1))
"Entry #{index} in {name} must be a dictionary with exactly one key-value pair".format(
name=name_field_name, index=index + 1
)
)
try:
result.extend(parse_name_field(entry, name_field_name=name_field_name))
except (TypeError, ValueError) as exc:
raise ValueError(
'Error while processing entry #{index} in {name}: {error}'.format(
name=name_field_name, index=index + 1, error=exc))
"Error while processing entry #{index} in {name}: {error}".format(
name=name_field_name, index=index + 1, error=exc
)
)
return result
def select_message_digest(digest_string):
digest = None
if digest_string == 'sha256':
if digest_string == "sha256":
digest = hashes.SHA256()
elif digest_string == 'sha384':
elif digest_string == "sha384":
digest = hashes.SHA384()
elif digest_string == 'sha512':
elif digest_string == "sha512":
digest = hashes.SHA512()
elif digest_string == 'sha1':
elif digest_string == "sha1":
digest = hashes.SHA1()
elif digest_string == 'md5':
elif digest_string == "md5":
digest = hashes.MD5()
return digest
@@ -317,7 +393,7 @@ class OpenSSLObject(object):
def _check_perms(module):
file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args['path']):
if module.check_file_absent_if_check_mode(file_args["path"]):
return False
return not module.set_fs_attributes_if_different(file_args, False)

View File

@@ -41,22 +41,25 @@ valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$")
def ecs_client_argument_spec():
return dict(
entrust_api_user=dict(type='str', required=True),
entrust_api_key=dict(type='str', required=True, no_log=True),
entrust_api_client_cert_path=dict(type='path', required=True),
entrust_api_client_cert_key_path=dict(type='path', required=True, no_log=True),
entrust_api_specification_path=dict(type='path', default='https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml'),
entrust_api_user=dict(type="str", required=True),
entrust_api_key=dict(type="str", required=True, no_log=True),
entrust_api_client_cert_path=dict(type="path", required=True),
entrust_api_client_cert_key_path=dict(type="path", required=True, no_log=True),
entrust_api_specification_path=dict(
type="path",
default="https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml",
),
)
class SessionConfigurationException(Exception):
""" Raised if we cannot configure a session with the API """
"""Raised if we cannot configure a session with the API"""
pass
class RestOperationException(Exception):
""" Encapsulate a REST API error """
"""Encapsulate a REST API error"""
def __init__(self, error):
self.status = to_native(error.get("status", None))
@@ -106,7 +109,12 @@ class RestOperation(object):
self.parameters = {}
else:
self.parameters = parameters
self.url = "{scheme}://{host}{base_path}{uri}".format(scheme="https", host=session._spec.get("host"), base_path=session._spec.get("basePath"), uri=uri)
self.url = "{scheme}://{host}{base_path}{uri}".format(
scheme="https",
host=session._spec.get("host"),
base_path=session._spec.get("basePath"),
uri=uri,
)
def restmethod(self, *args, **kwargs):
"""Do the hard work of making the request here"""
@@ -145,7 +153,9 @@ class RestOperation(object):
try:
if body_parameters:
body_parameters_json = json.dumps(body_parameters)
response = self.session.request.open(method=self.method, url=url, data=body_parameters_json)
response = self.session.request.open(
method=self.method, url=url, data=body_parameters_json
)
else:
response = self.session.request.open(method=self.method, url=url)
except HTTPError as e:
@@ -167,11 +177,13 @@ class RestOperation(object):
raise RestOperationException(result)
# Raise a generic RestOperationException if this fails
raise RestOperationException({"status": result_code, "errors": [{"message": "REST Operation Failed"}]})
raise RestOperationException(
{"status": result_code, "errors": [{"message": "REST Operation Failed"}]}
)
class Resource(object):
""" Implement basic CRUD operations against a path. """
"""Implement basic CRUD operations against a path."""
def __init__(self, session):
self.session = session
@@ -196,13 +208,20 @@ class Resource(object):
elif method.lower() == "patch":
operation_name = "Patch"
else:
raise SessionConfigurationException(to_native("Invalid REST method type {0}".format(method)))
raise SessionConfigurationException(
to_native("Invalid REST method type {0}".format(method))
)
# Get the non-parameter parts of the URL and append to the operation name
# e.g /application/version -> GetApplicationVersion
# e.g. /application/{id} -> GetApplication
# This may lead to duplicates, which we must prevent.
operation_name += re.sub(r"{(.*)}", "", url).replace("/", " ").title().replace(" ", "")
operation_name += (
re.sub(r"{(.*)}", "", url)
.replace("/", " ")
.title()
.replace(" ", "")
)
operation_spec["operationId"] = operation_name
op = RestOperation(session, url, method, parameters)
@@ -244,7 +263,9 @@ class ECSSession(object):
self.request.url_username = entrust_api_user
self.request.url_password = entrust_api_key
else:
raise SessionConfigurationException(to_native("User and key must be provided."))
raise SessionConfigurationException(
to_native("User and key must be provided.")
)
# set up client certificate if passed (support all-in one or cert + key)
entrust_api_cert = self.get_config("entrust_api_cert")
@@ -254,45 +275,78 @@ class ECSSession(object):
if entrust_api_cert_key:
self.request.client_key = entrust_api_cert_key
else:
raise SessionConfigurationException(to_native("Client certificate for authentication to the API must be provided."))
raise SessionConfigurationException(
to_native(
"Client certificate for authentication to the API must be provided."
)
)
# set up the spec
entrust_api_specification_path = self.get_config("entrust_api_specification_path")
entrust_api_specification_path = self.get_config(
"entrust_api_specification_path"
)
if not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path):
raise SessionConfigurationException(to_native("OpenAPI specification was not found at location {0}.".format(entrust_api_specification_path)))
if not entrust_api_specification_path.startswith("http") and not os.path.isfile(
entrust_api_specification_path
):
raise SessionConfigurationException(
to_native(
"OpenAPI specification was not found at location {0}.".format(
entrust_api_specification_path
)
)
)
if not valid_file_format.match(entrust_api_specification_path):
raise SessionConfigurationException(to_native("OpenAPI specification filename must end in .json, .yml or .yaml"))
raise SessionConfigurationException(
to_native(
"OpenAPI specification filename must end in .json, .yml or .yaml"
)
)
self.verify = True
if entrust_api_specification_path.startswith("http"):
try:
http_response = Request().open(method="GET", url=entrust_api_specification_path)
http_response = Request().open(
method="GET", url=entrust_api_specification_path
)
http_response_contents = http_response.read()
if entrust_api_specification_path.endswith(".json"):
self._spec = json.load(http_response_contents)
elif entrust_api_specification_path.endswith(".yml") or entrust_api_specification_path.endswith(".yaml"):
elif entrust_api_specification_path.endswith(
".yml"
) or entrust_api_specification_path.endswith(".yaml"):
self._spec = yaml.safe_load(http_response_contents)
except HTTPError as e:
raise SessionConfigurationException(to_native("Error downloading specification from address '{0}', received error code '{1}'".format(
entrust_api_specification_path, e.getcode())))
raise SessionConfigurationException(
to_native(
"Error downloading specification from address '{0}', received error code '{1}'".format(
entrust_api_specification_path, e.getcode()
)
)
)
else:
with open(entrust_api_specification_path) as f:
if ".json" in entrust_api_specification_path:
self._spec = json.load(f)
elif ".yml" in entrust_api_specification_path or ".yaml" in entrust_api_specification_path:
elif (
".yml" in entrust_api_specification_path
or ".yaml" in entrust_api_specification_path
):
self._spec = yaml.safe_load(f)
def get_config(self, item):
return self._config.get(item, None)
def _read_config_vars(self, name, **kwargs):
""" Read configuration from variables passed to the module. """
"""Read configuration from variables passed to the module."""
config = {}
entrust_api_specification_path = kwargs.get("entrust_api_specification_path")
if not entrust_api_specification_path or (not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path)):
if not entrust_api_specification_path or (
not entrust_api_specification_path.startswith("http")
and not os.path.isfile(entrust_api_specification_path)
):
raise SessionConfigurationException(
to_native(
"Parameter provided for entrust_api_specification_path of value '{0}' was not a valid file path or HTTPS address.".format(
@@ -305,30 +359,50 @@ class ECSSession(object):
file_path = kwargs.get(required_file)
if not file_path or not os.path.isfile(file_path):
raise SessionConfigurationException(
to_native("Parameter provided for {0} of value '{1}' was not a valid file path.".format(required_file, file_path))
to_native(
"Parameter provided for {0} of value '{1}' was not a valid file path.".format(
required_file, file_path
)
)
)
for required_var in ["entrust_api_user", "entrust_api_key"]:
if not kwargs.get(required_var):
raise SessionConfigurationException(to_native("Parameter provided for {0} was missing.".format(required_var)))
raise SessionConfigurationException(
to_native(
"Parameter provided for {0} was missing.".format(required_var)
)
)
config["entrust_api_cert"] = kwargs.get("entrust_api_cert")
config["entrust_api_cert_key"] = kwargs.get("entrust_api_cert_key")
config["entrust_api_specification_path"] = kwargs.get("entrust_api_specification_path")
config["entrust_api_specification_path"] = kwargs.get(
"entrust_api_specification_path"
)
config["entrust_api_user"] = kwargs.get("entrust_api_user")
config["entrust_api_key"] = kwargs.get("entrust_api_key")
return config
def ECSClient(entrust_api_user=None, entrust_api_key=None, entrust_api_cert=None, entrust_api_cert_key=None, entrust_api_specification_path=None):
def ECSClient(
entrust_api_user=None,
entrust_api_key=None,
entrust_api_cert=None,
entrust_api_cert_key=None,
entrust_api_specification_path=None,
):
"""Create an ECS client"""
if not YAML_FOUND:
raise SessionConfigurationException(missing_required_lib("PyYAML"), exception=YAML_IMP_ERR)
raise SessionConfigurationException(
missing_required_lib("PyYAML"), exception=YAML_IMP_ERR
)
if entrust_api_specification_path is None:
entrust_api_specification_path = "https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml"
entrust_api_specification_path = (
"https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml"
)
# Not functionally necessary with current uses of this module_util, but better to be explicit for future use cases
entrust_api_user = to_text(entrust_api_user)

View File

@@ -39,19 +39,32 @@ class GPGRunner(object):
def get_fingerprint_from_stdout(stdout):
lines = stdout.splitlines(False)
for line in lines:
if line.startswith('fpr:'):
parts = line.split(':')
if line.startswith("fpr:"):
parts = line.split(":")
if len(parts) <= 9 or not parts[9]:
raise GPGError('Result line "{line}" does not have fingerprint as 10th component'.format(line=line))
raise GPGError(
'Result line "{line}" does not have fingerprint as 10th component'.format(
line=line
)
)
return parts[9]
raise GPGError('Cannot extract fingerprint from stdout "{stdout}"'.format(stdout=stdout))
raise GPGError(
'Cannot extract fingerprint from stdout "{stdout}"'.format(stdout=stdout)
)
def get_fingerprint_from_file(gpg_runner, path):
if not os.path.exists(path):
raise GPGError('{path} does not exist'.format(path=path))
raise GPGError("{path} does not exist".format(path=path))
stdout = gpg_runner.run_command(
['--no-keyring', '--with-colons', '--import-options', 'show-only', '--import', path],
[
"--no-keyring",
"--with-colons",
"--import-options",
"show-only",
"--import",
path,
],
check_rc=True,
)[1]
return get_fingerprint_from_stdout(stdout)
@@ -59,7 +72,14 @@ def get_fingerprint_from_file(gpg_runner, path):
def get_fingerprint_from_bytes(gpg_runner, content):
stdout = gpg_runner.run_command(
['--no-keyring', '--with-colons', '--import-options', 'show-only', '--import', '/dev/stdin'],
[
"--no-keyring",
"--with-colons",
"--import-options",
"show-only",
"--import",
"/dev/stdin",
],
data=content,
check_rc=True,
)[1]

View File

@@ -16,28 +16,28 @@ import tempfile
def load_file(path, module=None):
'''
"""
Load the file as a bytes string.
'''
"""
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
return f.read()
except Exception as exc:
if module is None:
raise
module.fail_json('Error while loading {0} - {1}'.format(path, str(exc)))
module.fail_json("Error while loading {0} - {1}".format(path, str(exc)))
def load_file_if_exists(path, module=None, ignore_errors=False):
'''
"""
Load the file as a bytes string. If the file does not exist, ``None`` is returned.
If ``ignore_errors`` is ``True``, will ignore errors. Otherwise, errors are
raised as exceptions if ``module`` is not specified, and result in ``module.fail_json``
being called when ``module`` is specified.
'''
"""
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
return f.read()
except EnvironmentError as exc:
if exc.errno == errno.ENOENT:
@@ -46,20 +46,20 @@ def load_file_if_exists(path, module=None, ignore_errors=False):
return None
if module is None:
raise
module.fail_json('Error while loading {0} - {1}'.format(path, str(exc)))
module.fail_json("Error while loading {0} - {1}".format(path, str(exc)))
except Exception as exc:
if ignore_errors:
return None
if module is None:
raise
module.fail_json('Error while loading {0} - {1}'.format(path, str(exc)))
module.fail_json("Error while loading {0} - {1}".format(path, str(exc)))
def write_file(module, content, default_mode=None, path=None):
'''
"""
Writes content into destination file as securely as possible.
Uses file arguments from module.
'''
"""
# Find out parameters for file
try:
file_args = module.load_file_common_arguments(module.params, path=path)
@@ -68,11 +68,11 @@ def write_file(module, content, default_mode=None, path=None):
# pre-2.10 behavior of module_utils/crypto.py for older Ansible versions.
file_args = module.load_file_common_arguments(module.params)
if path is not None:
file_args['path'] = path
if file_args['mode'] is None:
file_args['mode'] = default_mode
file_args["path"] = path
if file_args["mode"] is None:
file_args["mode"] = default_mode
# Create tempfile name
tmp_fd, tmp_name = tempfile.mkstemp(prefix=b'.ansible_tmp')
tmp_fd, tmp_name = tempfile.mkstemp(prefix=b".ansible_tmp")
try:
os.close(tmp_fd)
except Exception:
@@ -89,18 +89,22 @@ def write_file(module, content, default_mode=None, path=None):
os.remove(tmp_name)
except Exception:
pass
module.fail_json(msg='Error while writing result into temporary file: {0}'.format(e))
module.fail_json(
msg="Error while writing result into temporary file: {0}".format(e)
)
# Update destination to wanted permissions
if os.path.exists(file_args['path']):
if os.path.exists(file_args["path"]):
module.set_fs_attributes_if_different(file_args, False)
# Move tempfile to final destination
module.atomic_move(os.path.abspath(tmp_name), os.path.abspath(file_args['path']))
module.atomic_move(
os.path.abspath(tmp_name), os.path.abspath(file_args["path"])
)
# Try to update permissions again
if not module.check_file_absent_if_check_mode(file_args['path']):
if not module.check_file_absent_if_check_mode(file_args["path"]):
module.set_fs_attributes_if_different(file_args, False)
except Exception as e:
try:
os.remove(tmp_name)
except Exception:
pass
module.fail_json(msg='Error while writing result: {0}'.format(e))
module.fail_json(msg="Error while writing result: {0}".format(e))

View File

@@ -44,17 +44,24 @@ def safe_atomic_move(module, path, destination):
def _restore_all_on_failure(f):
def backup_and_restore(self, sources_and_destinations, *args, **kwargs):
backups = [(d, self.module.backup_local(d)) for s, d in sources_and_destinations if os.path.exists(d)]
backups = [
(d, self.module.backup_local(d))
for s, d in sources_and_destinations
if os.path.exists(d)
]
try:
f(self, sources_and_destinations, *args, **kwargs)
except Exception:
for destination, backup in backups:
self.module.atomic_move(os.path.abspath(backup), os.path.abspath(destination))
self.module.atomic_move(
os.path.abspath(backup), os.path.abspath(destination)
)
raise
else:
for destination, backup in backups:
self.module.add_cleanup_file(backup)
return backup_and_restore
@@ -85,10 +92,10 @@ class OpensshModule(object):
def result(self):
result = self._result
result['changed'] = self.changed
result["changed"] = self.changed
if self.module._diff:
result['diff'] = self.diff
result["diff"] = self.diff
return result
@@ -107,6 +114,7 @@ class OpensshModule(object):
def wrapper(self, *args, **kwargs):
if not self.check_mode:
f(self, *args, **kwargs)
return wrapper
@staticmethod
@@ -114,72 +122,92 @@ class OpensshModule(object):
def wrapper(self, *args, **kwargs):
f(self, *args, **kwargs)
self.changed = True
return wrapper
def _check_if_base_dir(self, path):
base_dir = os.path.dirname(path) or '.'
base_dir = os.path.dirname(path) or "."
if not os.path.isdir(base_dir):
self.module.fail_json(
name=base_dir,
msg='The directory %s does not exist or the file is not a directory' % base_dir
msg="The directory %s does not exist or the file is not a directory"
% base_dir,
)
def _get_ssh_version(self):
ssh_bin = self.module.get_bin_path('ssh')
ssh_bin = self.module.get_bin_path("ssh")
if not ssh_bin:
return ""
return parse_openssh_version(self.module.run_command([ssh_bin, '-V', '-q'], check_rc=True)[2].strip())
return parse_openssh_version(
self.module.run_command([ssh_bin, "-V", "-q"], check_rc=True)[2].strip()
)
@_restore_all_on_failure
def _safe_secure_move(self, sources_and_destinations):
"""Moves a list of files from 'source' to 'destination' and restores 'destination' from backup upon failure.
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
permissions if 'destination' does not already exists).
If 'destination' does not already exist, then 'source' permissions are preserved to prevent
exposing protected data ('atomic_move' uses the 'destination' base directory mask for
permissions if 'destination' does not already exists).
"""
for source, destination in sources_and_destinations:
if os.path.exists(destination):
self.module.atomic_move(os.path.abspath(source), os.path.abspath(destination))
self.module.atomic_move(
os.path.abspath(source), os.path.abspath(destination)
)
else:
self.module.preserved_copy(source, destination)
def _update_permissions(self, path):
file_args = self.module.load_file_common_arguments(self.module.params)
file_args['path'] = path
file_args["path"] = path
if not self.module.check_file_absent_if_check_mode(path):
self.changed = self.module.set_fs_attributes_if_different(file_args, self.changed)
self.changed = self.module.set_fs_attributes_if_different(
file_args, self.changed
)
else:
self.changed = True
class KeygenCommand(object):
def __init__(self, module):
self._bin_path = module.get_bin_path('ssh-keygen', True)
self._bin_path = module.get_bin_path("ssh-keygen", True)
self._run_command = module.run_command
def generate_certificate(self, certificate_path, identifier, options, pkcs11_provider, principals,
serial_number, signature_algorithm, signing_key_path, type,
time_parameters, use_agent, **kwargs):
args = [self._bin_path, '-s', signing_key_path, '-P', '', '-I', identifier]
def generate_certificate(
self,
certificate_path,
identifier,
options,
pkcs11_provider,
principals,
serial_number,
signature_algorithm,
signing_key_path,
type,
time_parameters,
use_agent,
**kwargs
):
args = [self._bin_path, "-s", signing_key_path, "-P", "", "-I", identifier]
if options:
for option in options:
args.extend(['-O', option])
args.extend(["-O", option])
if pkcs11_provider:
args.extend(['-D', pkcs11_provider])
args.extend(["-D", pkcs11_provider])
if principals:
args.extend(['-n', ','.join(principals)])
args.extend(["-n", ",".join(principals)])
if serial_number is not None:
args.extend(['-z', str(serial_number)])
if type == 'host':
args.extend(['-h'])
args.extend(["-z", str(serial_number)])
if type == "host":
args.extend(["-h"])
if use_agent:
args.extend(['-U'])
args.extend(["-U"])
if time_parameters.validity_string:
args.extend(['-V', time_parameters.validity_string])
args.extend(["-V", time_parameters.validity_string])
if signature_algorithm:
args.extend(['-t', signature_algorithm])
args.extend(["-t", signature_algorithm])
args.append(certificate_path)
return self._run_command(args, **kwargs)
@@ -187,44 +215,62 @@ class KeygenCommand(object):
def generate_keypair(self, private_key_path, size, type, comment, **kwargs):
args = [
self._bin_path,
'-q',
'-N', '',
'-b', str(size),
'-t', type,
'-f', private_key_path,
'-C', comment or ''
"-q",
"-N",
"",
"-b",
str(size),
"-t",
type,
"-f",
private_key_path,
"-C",
comment or "",
]
# "y" must be entered in response to the "overwrite" prompt
data = 'y' if os.path.exists(private_key_path) else None
data = "y" if os.path.exists(private_key_path) else None
return self._run_command(args, data=data, **kwargs)
def get_certificate_info(self, certificate_path, **kwargs):
return self._run_command([self._bin_path, '-L', '-f', certificate_path], **kwargs)
return self._run_command(
[self._bin_path, "-L", "-f", certificate_path], **kwargs
)
def get_matching_public_key(self, private_key_path, **kwargs):
return self._run_command([self._bin_path, '-P', '', '-y', '-f', private_key_path], **kwargs)
return self._run_command(
[self._bin_path, "-P", "", "-y", "-f", private_key_path], **kwargs
)
def get_private_key(self, private_key_path, **kwargs):
return self._run_command([self._bin_path, '-l', '-f', private_key_path], **kwargs)
return self._run_command(
[self._bin_path, "-l", "-f", private_key_path], **kwargs
)
def update_comment(self, private_key_path, comment, force_new_format=True, **kwargs):
if os.path.exists(private_key_path) and not os.access(private_key_path, os.W_OK):
def update_comment(
self, private_key_path, comment, force_new_format=True, **kwargs
):
if os.path.exists(private_key_path) and not os.access(
private_key_path, os.W_OK
):
try:
os.chmod(private_key_path, stat.S_IWUSR + stat.S_IRUSR)
except (IOError, OSError) as e:
raise e("The private key at %s is not writeable preventing a comment update" % private_key_path)
raise e(
"The private key at %s is not writeable preventing a comment update"
% private_key_path
)
command = [self._bin_path, '-q']
command = [self._bin_path, "-q"]
if force_new_format:
command.append('-o')
command.extend(['-c', '-C', comment, '-f', private_key_path])
command.append("-o")
command.extend(["-c", "-C", comment, "-f", private_key_path])
return self._run_command(command, **kwargs)
class PrivateKey(object):
def __init__(self, size, key_type, fingerprint, format=''):
def __init__(self, size, key_type, fingerprint, format=""):
self._size = size
self._type = key_type
self._fingerprint = fingerprint
@@ -258,10 +304,10 @@ class PrivateKey(object):
def to_dict(self):
return {
'size': self._size,
'type': self._type,
'fingerprint': self._fingerprint,
'format': self._format,
"size": self._size,
"type": self._type,
"fingerprint": self._fingerprint,
"format": self._format,
}
@@ -275,11 +321,17 @@ class PublicKey(object):
if not isinstance(other, type(self)):
return NotImplemented
return all([
self._type_string == other._type_string,
self._data == other._data,
(self._comment == other._comment) if self._comment is not None and other._comment is not None else True
])
return all(
[
self._type_string == other._type_string,
self._data == other._data,
(
(self._comment == other._comment)
if self._comment is not None and other._comment is not None
else True
),
]
)
def __ne__(self, other):
return not self == other
@@ -305,19 +357,19 @@ class PublicKey(object):
@classmethod
def from_string(cls, string):
properties = string.strip('\n').split(' ', 2)
properties = string.strip("\n").split(" ", 2)
return cls(
type_string=properties[0],
data=properties[1],
comment=properties[2] if len(properties) > 2 else ""
comment=properties[2] if len(properties) > 2 else "",
)
@classmethod
def load(cls, path):
try:
with open(path, 'r') as f:
properties = f.read().strip(' \n').split(' ', 2)
with open(path, "r") as f:
properties = f.read().strip(" \n").split(" ", 2)
except (IOError, OSError):
raise
@@ -327,25 +379,25 @@ class PublicKey(object):
return cls(
type_string=properties[0],
data=properties[1],
comment='' if len(properties) <= 2 else properties[2],
comment="" if len(properties) <= 2 else properties[2],
)
def to_dict(self):
return {
'comment': self._comment,
'public_key': self._data,
"comment": self._comment,
"public_key": self._data,
}
def parse_private_key_format(path):
with open(path, 'r') as file:
with open(path, "r") as file:
header = file.readline().strip()
if header == '-----BEGIN OPENSSH PRIVATE KEY-----':
return 'SSH'
elif header == '-----BEGIN PRIVATE KEY-----':
return 'PKCS8'
elif header == '-----BEGIN RSA PRIVATE KEY-----':
return 'PKCS1'
if header == "-----BEGIN OPENSSH PRIVATE KEY-----":
return "SSH"
elif header == "-----BEGIN PRIVATE KEY-----":
return "PKCS8"
elif header == "-----BEGIN RSA PRIVATE KEY-----":
return "PKCS1"
return ''
return ""

View File

@@ -48,14 +48,18 @@ class KeypairBackend(OpensshModule):
def __init__(self, module):
super(KeypairBackend, self).__init__(module)
self.comment = self.module.params['comment']
self.private_key_path = self.module.params['path']
self.public_key_path = self.private_key_path + '.pub'
self.regenerate = self.module.params['regenerate'] if not self.module.params['force'] else 'always'
self.state = self.module.params['state']
self.type = self.module.params['type']
self.comment = self.module.params["comment"]
self.private_key_path = self.module.params["path"]
self.public_key_path = self.private_key_path + ".pub"
self.regenerate = (
self.module.params["regenerate"]
if not self.module.params["force"]
else "always"
)
self.state = self.module.params["state"]
self.type = self.module.params["type"]
self.size = self._get_size(self.module.params['size'])
self.size = self._get_size(self.module.params["size"])
self._validate_path()
self.original_private_key = None
@@ -64,31 +68,35 @@ class KeypairBackend(OpensshModule):
self.public_key = None
def _get_size(self, size):
if self.type in ('rsa', 'rsa1'):
if self.type in ("rsa", "rsa1"):
result = 4096 if size is None else size
if result < 1024:
return self.module.fail_json(
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. " +
"Attempting to use bit lengths under 1024 will cause the module to fail."
msg="For RSA keys, the minimum size is 1024 bits and the default is 4096 bits. "
+ "Attempting to use bit lengths under 1024 will cause the module to fail."
)
elif self.type == 'dsa':
elif self.type == "dsa":
result = 1024 if size is None else size
if result != 1024:
return self.module.fail_json(msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2.")
elif self.type == 'ecdsa':
return self.module.fail_json(
msg="DSA keys must be exactly 1024 bits as specified by FIPS 186-2."
)
elif self.type == "ecdsa":
result = 256 if size is None else size
if result not in (256, 384, 521):
return self.module.fail_json(
msg="For ECDSA keys, size determines the key length by selecting from one of " +
"three elliptic curve sizes: 256, 384 or 521 bits. " +
"Attempting to use bit lengths other than these three values for ECDSA keys will " +
"cause this module to fail."
msg="For ECDSA keys, size determines the key length by selecting from one of "
+ "three elliptic curve sizes: 256, 384 or 521 bits. "
+ "Attempting to use bit lengths other than these three values for ECDSA keys will "
+ "cause this module to fail."
)
elif self.type == 'ed25519':
elif self.type == "ed25519":
# User input is ignored for `key size` when `key type` is ed25519
result = 256
else:
return self.module.fail_json(msg="%s is not a valid value for key type" % self.type)
return self.module.fail_json(
msg="%s is not a valid value for key type" % self.type
)
return result
@@ -96,13 +104,16 @@ class KeypairBackend(OpensshModule):
self._check_if_base_dir(self.private_key_path)
if os.path.isdir(self.private_key_path):
self.module.fail_json(msg='%s is a directory. Please specify a path to a file.' % self.private_key_path)
self.module.fail_json(
msg="%s is a directory. Please specify a path to a file."
% self.private_key_path
)
def _execute(self):
self.original_private_key = self._load_private_key()
self.original_public_key = self._load_public_key()
if self.state == 'present':
if self.state == "present":
self._validate_key_load()
if self._should_generate():
@@ -149,13 +160,15 @@ class KeypairBackend(OpensshModule):
return os.path.exists(self.public_key_path)
def _validate_key_load(self):
if (self._private_key_exists()
and self.regenerate in ('never', 'fail', 'partial_idempotence')
and (self.original_private_key is None or not self._private_key_readable())):
if (
self._private_key_exists()
and self.regenerate in ("never", "fail", "partial_idempotence")
and (self.original_private_key is None or not self._private_key_readable())
):
self.module.fail_json(
msg="Unable to read the key. The key is protected with a passphrase or broken. " +
"Will not proceed. To force regeneration, call the module with `generate` " +
"set to `full_idempotence` or `always`, or with `force=true`."
msg="Unable to read the key. The key is protected with a passphrase or broken. "
+ "Will not proceed. To force regeneration, call the module with `generate` "
+ "set to `full_idempotence` or `always`, or with `force=true`."
)
@abc.abstractmethod
@@ -165,17 +178,17 @@ class KeypairBackend(OpensshModule):
def _should_generate(self):
if self.original_private_key is None:
return True
elif self.regenerate == 'never':
elif self.regenerate == "never":
return False
elif self.regenerate == 'fail':
elif self.regenerate == "fail":
if not self._private_key_valid():
self.module.fail_json(
msg="Key has wrong type and/or size. Will not proceed. " +
"To force regeneration, call the module with `generate` set to " +
"`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
msg="Key has wrong type and/or size. Will not proceed. "
+ "To force regeneration, call the module with `generate` set to "
+ "`partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
)
return False
elif self.regenerate in ('partial_idempotence', 'full_idempotence'):
elif self.regenerate in ("partial_idempotence", "full_idempotence"):
return not self._private_key_valid()
else:
return True
@@ -184,11 +197,13 @@ class KeypairBackend(OpensshModule):
if self.original_private_key is None:
return False
return all([
self.size == self.original_private_key.size,
self.type == self.original_private_key.type,
self._private_key_valid_backend(),
])
return all(
[
self.size == self.original_private_key.size,
self.type == self.original_private_key.type,
self._private_key_valid_backend(),
]
)
@abc.abstractmethod
def _private_key_valid_backend(self):
@@ -200,13 +215,20 @@ class KeypairBackend(OpensshModule):
temp_private_key, temp_public_key = self._generate_temp_keypair()
try:
self._safe_secure_move([(temp_private_key, self.private_key_path), (temp_public_key, self.public_key_path)])
self._safe_secure_move(
[
(temp_private_key, self.private_key_path),
(temp_public_key, self.public_key_path),
]
)
except OSError as e:
self.module.fail_json(msg=to_native(e))
def _generate_temp_keypair(self):
temp_private_key = os.path.join(self.module.tmpdir, os.path.basename(self.private_key_path))
temp_public_key = temp_private_key + '.pub'
temp_private_key = os.path.join(
self.module.tmpdir, os.path.basename(self.private_key_path)
)
temp_public_key = temp_private_key + ".pub"
try:
self._generate_keypair(temp_private_key)
@@ -239,27 +261,33 @@ class KeypairBackend(OpensshModule):
@OpensshModule.skip_if_check_mode
def _restore_public_key(self):
try:
temp_public_key = self._create_temp_public_key(str(self._get_public_key()) + '\n')
self._safe_secure_move([
(temp_public_key, self.public_key_path)
])
temp_public_key = self._create_temp_public_key(
str(self._get_public_key()) + "\n"
)
self._safe_secure_move([(temp_public_key, self.public_key_path)])
except (IOError, OSError):
self.module.fail_json(
msg="The public key is missing or does not match the private key. " +
"Unable to regenerate the public key."
msg="The public key is missing or does not match the private key. "
+ "Unable to regenerate the public key."
)
if self.comment:
self._update_comment()
def _create_temp_public_key(self, content):
temp_public_key = os.path.join(self.module.tmpdir, os.path.basename(self.public_key_path))
temp_public_key = os.path.join(
self.module.tmpdir, os.path.basename(self.public_key_path)
)
default_permissions = 0o644
existing_permissions = file_mode(self.public_key_path)
try:
secure_write(temp_public_key, existing_permissions or default_permissions, to_bytes(content))
secure_write(
temp_public_key,
existing_permissions or default_permissions,
to_bytes(content),
)
except (IOError, OSError) as e:
self.module.fail_json(msg=to_native(e))
self.module.add_cleanup_file(temp_public_key)
@@ -290,25 +318,29 @@ class KeypairBackend(OpensshModule):
public_key = self.public_key or self.original_public_key
return {
'size': self.size,
'type': self.type,
'filename': self.private_key_path,
'fingerprint': private_key.fingerprint if private_key else '',
'public_key': str(public_key) if public_key else '',
'comment': public_key.comment if public_key else '',
"size": self.size,
"type": self.type,
"filename": self.private_key_path,
"fingerprint": private_key.fingerprint if private_key else "",
"public_key": str(public_key) if public_key else "",
"comment": public_key.comment if public_key else "",
}
@property
def diff(self):
before = self.original_private_key.to_dict() if self.original_private_key else {}
before.update(self.original_public_key.to_dict() if self.original_public_key else {})
before = (
self.original_private_key.to_dict() if self.original_private_key else {}
)
before.update(
self.original_public_key.to_dict() if self.original_public_key else {}
)
after = self.private_key.to_dict() if self.private_key else {}
after.update(self.public_key.to_dict() if self.public_key else {})
return {
'before': before,
'after': after,
"before": before,
"after": after,
}
@@ -316,36 +348,59 @@ class KeypairBackendOpensshBin(KeypairBackend):
def __init__(self, module):
super(KeypairBackendOpensshBin, self).__init__(module)
if self.module.params['private_key_format'] != 'auto':
if self.module.params["private_key_format"] != "auto":
self.module.fail_json(
msg="'auto' is the only valid option for " +
"'private_key_format' when 'backend' is not 'cryptography'"
msg="'auto' is the only valid option for "
+ "'private_key_format' when 'backend' is not 'cryptography'"
)
self.ssh_keygen = KeygenCommand(self.module)
def _generate_keypair(self, private_key_path):
self.ssh_keygen.generate_keypair(private_key_path, self.size, self.type, self.comment, check_rc=True)
self.ssh_keygen.generate_keypair(
private_key_path, self.size, self.type, self.comment, check_rc=True
)
def _get_private_key(self):
rc, private_key_content, err = self.ssh_keygen.get_private_key(self.private_key_path, check_rc=False)
rc, private_key_content, err = self.ssh_keygen.get_private_key(
self.private_key_path, check_rc=False
)
if rc != 0:
raise ValueError(err)
return PrivateKey.from_string(private_key_content)
def _get_public_key(self):
public_key_content = self.ssh_keygen.get_matching_public_key(self.private_key_path, check_rc=True)[1]
public_key_content = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=True
)[1]
return PublicKey.from_string(public_key_content)
def _private_key_readable(self):
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(self.private_key_path, check_rc=False)
return not (rc == 255 or any_in(stderr, 'is not a public key file', 'incorrect passphrase', 'load failed'))
rc, stdout, stderr = self.ssh_keygen.get_matching_public_key(
self.private_key_path, check_rc=False
)
return not (
rc == 255
or any_in(
stderr,
"is not a public key file",
"incorrect passphrase",
"load failed",
)
)
def _update_comment(self):
try:
ssh_version = self._get_ssh_version() or "7.8"
force_new_format = LooseVersion('6.5') <= LooseVersion(ssh_version) < LooseVersion('7.8')
self.ssh_keygen.update_comment(self.private_key_path, self.comment, force_new_format=force_new_format, check_rc=True)
force_new_format = (
LooseVersion("6.5") <= LooseVersion(ssh_version) < LooseVersion("7.8")
)
self.ssh_keygen.update_comment(
self.private_key_path,
self.comment,
force_new_format=force_new_format,
check_rc=True,
)
except (IOError, OSError) as e:
self.module.fail_json(msg=to_native(e))
@@ -357,30 +412,41 @@ class KeypairBackendCryptography(KeypairBackend):
def __init__(self, module):
super(KeypairBackendCryptography, self).__init__(module)
if self.type == 'rsa1':
self.module.fail_json(msg="RSA1 keys are not supported by the cryptography backend")
if self.type == "rsa1":
self.module.fail_json(
msg="RSA1 keys are not supported by the cryptography backend"
)
self.passphrase = to_bytes(module.params['passphrase']) if module.params['passphrase'] else None
self.private_key_format = self._get_key_format(module.params['private_key_format'])
self.passphrase = (
to_bytes(module.params["passphrase"])
if module.params["passphrase"]
else None
)
self.private_key_format = self._get_key_format(
module.params["private_key_format"]
)
def _get_key_format(self, key_format):
result = 'SSH'
result = "SSH"
if key_format == 'auto':
if key_format == "auto":
# Default to OpenSSH 7.8 compatibility when OpenSSH is not installed
ssh_version = self._get_ssh_version() or "7.8"
if LooseVersion(ssh_version) < LooseVersion("7.8") and self.type != 'ed25519':
if (
LooseVersion(ssh_version) < LooseVersion("7.8")
and self.type != "ed25519"
):
# OpenSSH made SSH formatted private keys available in version 6.5,
# but still defaulted to PKCS1 format with the exception of ed25519 keys
result = 'PKCS1'
result = "PKCS1"
if result == 'SSH' and not HAS_OPENSSH_PRIVATE_FORMAT:
if result == "SSH" and not HAS_OPENSSH_PRIVATE_FORMAT:
self.module.fail_json(
msg=missing_required_lib(
'cryptography >= 3.0',
reason="to load/dump private keys in the default OpenSSH format for OpenSSH >= 7.8 " +
"or for ed25519 keys"
"cryptography >= 3.0",
reason="to load/dump private keys in the default OpenSSH format for OpenSSH >= 7.8 "
+ "or for ed25519 keys",
)
)
else:
@@ -393,7 +459,7 @@ class KeypairBackendCryptography(KeypairBackend):
keytype=self.type,
size=self.size,
passphrase=self.passphrase,
comment=self.comment or '',
comment=self.comment or "",
)
encoded_private_key = OpensshKeypair.encode_openssh_privatekey(
@@ -401,22 +467,28 @@ class KeypairBackendCryptography(KeypairBackend):
)
secure_write(private_key_path, 0o600, encoded_private_key)
public_key_path = private_key_path + '.pub'
public_key_path = private_key_path + ".pub"
secure_write(public_key_path, 0o644, keypair.public_key)
def _get_private_key(self):
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
return PrivateKey(
size=keypair.size,
key_type=keypair.key_type,
fingerprint=keypair.fingerprint,
format=parse_private_key_format(self.private_key_path)
format=parse_private_key_format(self.private_key_path),
)
def _get_public_key(self):
try:
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
keypair = OpensshKeypair.load(
path=self.private_key_path,
passphrase=self.passphrase,
no_public_key=True,
)
except OpenSSHError:
# Simulates the null output of ssh-keygen
return ""
@@ -425,7 +497,11 @@ class KeypairBackendCryptography(KeypairBackend):
def _private_key_readable(self):
try:
OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
OpensshKeypair.load(
path=self.private_key_path,
passphrase=self.passphrase,
no_public_key=True,
)
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
return False
@@ -433,7 +509,9 @@ class KeypairBackendCryptography(KeypairBackend):
# when loading an unencrypted key
if self.passphrase:
try:
OpensshKeypair.load(path=self.private_key_path, passphrase=None, no_public_key=True)
OpensshKeypair.load(
path=self.private_key_path, passphrase=None, no_public_key=True
)
except (InvalidPrivateKeyFileError, InvalidPassphraseError):
return True
else:
@@ -442,14 +520,16 @@ class KeypairBackendCryptography(KeypairBackend):
return True
def _update_comment(self):
keypair = OpensshKeypair.load(path=self.private_key_path, passphrase=self.passphrase, no_public_key=True)
keypair = OpensshKeypair.load(
path=self.private_key_path, passphrase=self.passphrase, no_public_key=True
)
try:
keypair.comment = self.comment
except InvalidCommentError as e:
self.module.fail_json(msg=to_native(e))
try:
temp_public_key = self._create_temp_public_key(keypair.public_key + b'\n')
temp_public_key = self._create_temp_public_key(keypair.public_key + b"\n")
self._safe_secure_move([(temp_public_key, self.public_key_path)])
except (IOError, OSError) as e:
self.module.fail_json(msg=to_native(e))
@@ -457,7 +537,7 @@ class KeypairBackendCryptography(KeypairBackend):
def _private_key_valid_backend(self):
# avoids breaking behavior and prevents
# automatic conversions with OpenSSH upgrades
if self.module.params['private_key_format'] == 'auto':
if self.module.params["private_key_format"] == "auto":
return True
return self.private_key_format == self.original_private_key.format
@@ -465,24 +545,26 @@ class KeypairBackendCryptography(KeypairBackend):
def select_backend(module, backend):
can_use_cryptography = HAS_OPENSSH_SUPPORT
can_use_opensshbin = bool(module.get_bin_path('ssh-keygen'))
can_use_opensshbin = bool(module.get_bin_path("ssh-keygen"))
if backend == 'auto':
if can_use_opensshbin and not module.params['passphrase']:
backend = 'opensshbin'
if backend == "auto":
if can_use_opensshbin and not module.params["passphrase"]:
backend = "opensshbin"
elif can_use_cryptography:
backend = 'cryptography'
backend = "cryptography"
else:
module.fail_json(msg="Cannot find either the OpenSSH binary in the PATH " +
"or cryptography >= 2.6 installed on this system")
module.fail_json(
msg="Cannot find either the OpenSSH binary in the PATH "
+ "or cryptography >= 2.6 installed on this system"
)
if backend == 'opensshbin':
if backend == "opensshbin":
if not can_use_opensshbin:
module.fail_json(msg="Cannot find the OpenSSH binary in the PATH")
return backend, KeypairBackendOpensshBin(module)
elif backend == 'cryptography':
elif backend == "cryptography":
if not can_use_cryptography:
module.fail_json(msg=missing_required_lib("cryptography >= 2.6"))
return backend, KeypairBackendCryptography(module)
else:
raise ValueError('Unsupported value for backend: {0}'.format(backend))
raise ValueError("Unsupported value for backend: {0}".format(backend))

View File

@@ -51,54 +51,56 @@ _USER_TYPE = 1
_HOST_TYPE = 2
_SSH_TYPE_STRINGS = {
'rsa': b"ssh-rsa",
'dsa': b"ssh-dss",
'ecdsa-nistp256': b"ecdsa-sha2-nistp256",
'ecdsa-nistp384': b"ecdsa-sha2-nistp384",
'ecdsa-nistp521': b"ecdsa-sha2-nistp521",
'ed25519': b"ssh-ed25519",
"rsa": b"ssh-rsa",
"dsa": b"ssh-dss",
"ecdsa-nistp256": b"ecdsa-sha2-nistp256",
"ecdsa-nistp384": b"ecdsa-sha2-nistp384",
"ecdsa-nistp521": b"ecdsa-sha2-nistp521",
"ed25519": b"ssh-ed25519",
}
_CERT_SUFFIX_V01 = b"-cert-v01@openssh.com"
# See https://datatracker.ietf.org/doc/html/rfc5656#section-6.1
_ECDSA_CURVE_IDENTIFIERS = {
'ecdsa-nistp256': b'nistp256',
'ecdsa-nistp384': b'nistp384',
'ecdsa-nistp521': b'nistp521',
"ecdsa-nistp256": b"nistp256",
"ecdsa-nistp384": b"nistp384",
"ecdsa-nistp521": b"nistp521",
}
_ECDSA_CURVE_IDENTIFIERS_LOOKUP = {
b'nistp256': 'ecdsa-nistp256',
b'nistp384': 'ecdsa-nistp384',
b'nistp521': 'ecdsa-nistp521',
b"nistp256": "ecdsa-nistp256",
b"nistp384": "ecdsa-nistp384",
b"nistp521": "ecdsa-nistp521",
}
_USE_TIMEZONE = sys.version_info >= (3, 6)
_ALWAYS = _add_or_remove_timezone(datetime(1970, 1, 1), with_timezone=_USE_TIMEZONE)
_FOREVER = datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
_FOREVER = (
datetime(9999, 12, 31, 23, 59, 59, 999999, _UTC) if _USE_TIMEZONE else datetime.max
)
_CRITICAL_OPTIONS = (
'force-command',
'source-address',
'verify-required',
"force-command",
"source-address",
"verify-required",
)
_DIRECTIVES = (
'clear',
'no-x11-forwarding',
'no-agent-forwarding',
'no-port-forwarding',
'no-pty',
'no-user-rc',
"clear",
"no-x11-forwarding",
"no-agent-forwarding",
"no-port-forwarding",
"no-pty",
"no-user-rc",
)
_EXTENSIONS = (
'permit-x11-forwarding',
'permit-agent-forwarding',
'permit-port-forwarding',
'permit-pty',
'permit-user-rc'
"permit-x11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc",
)
if six.PY3:
@@ -111,13 +113,19 @@ class OpensshCertificateTimeParameters(object):
self._valid_to = self.to_datetime(valid_to)
if self._valid_from > self._valid_to:
raise ValueError("Valid from: %s must not be greater than Valid to: %s" % (valid_from, valid_to))
raise ValueError(
"Valid from: %s must not be greater than Valid to: %s"
% (valid_from, valid_to)
)
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
else:
return self._valid_from == other._valid_from and self._valid_to == other._valid_to
return (
self._valid_from == other._valid_from
and self._valid_to == other._valid_to
)
def __ne__(self, other):
return not self == other
@@ -126,7 +134,8 @@ class OpensshCertificateTimeParameters(object):
def validity_string(self):
if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
return "%s:%s" % (
self.valid_from(date_format='openssh'), self.valid_to(date_format='openssh')
self.valid_from(date_format="openssh"),
self.valid_to(date_format="openssh"),
)
return ""
@@ -144,16 +153,22 @@ class OpensshCertificateTimeParameters(object):
@staticmethod
def format_datetime(dt, date_format):
if date_format in ('human_readable', 'openssh'):
if date_format in ("human_readable", "openssh"):
if dt == _ALWAYS:
result = 'always'
result = "always"
elif dt == _FOREVER:
result = 'forever'
result = "forever"
else:
result = dt.isoformat().replace('+00:00', '') if date_format == 'human_readable' else dt.strftime("%Y%m%d%H%M%S")
elif date_format == 'timestamp':
result = (
dt.isoformat().replace("+00:00", "")
if date_format == "human_readable"
else dt.strftime("%Y%m%d%H%M%S")
)
elif date_format == "timestamp":
td = dt - _ALWAYS
result = int((td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / 10 ** 6)
result = int(
(td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
)
else:
raise ValueError("%s is not a valid format" % date_format)
return result
@@ -162,12 +177,17 @@ class OpensshCertificateTimeParameters(object):
def to_datetime(time_string_or_timestamp):
try:
if isinstance(time_string_or_timestamp, six.string_types):
result = OpensshCertificateTimeParameters._time_string_to_datetime(time_string_or_timestamp.strip())
result = OpensshCertificateTimeParameters._time_string_to_datetime(
time_string_or_timestamp.strip()
)
elif isinstance(time_string_or_timestamp, (long, int)):
result = OpensshCertificateTimeParameters._timestamp_to_datetime(time_string_or_timestamp)
result = OpensshCertificateTimeParameters._timestamp_to_datetime(
time_string_or_timestamp
)
else:
raise ValueError(
"Value must be of type (str, unicode, int, long) not %s" % type(time_string_or_timestamp)
"Value must be of type (str, unicode, int, long) not %s"
% type(time_string_or_timestamp)
)
except ValueError:
raise
@@ -182,7 +202,9 @@ class OpensshCertificateTimeParameters(object):
else:
try:
if _USE_TIMEZONE:
result = datetime.fromtimestamp(timestamp, tz=_datetime.timezone.utc)
result = datetime.fromtimestamp(
timestamp, tz=_datetime.timezone.utc
)
else:
result = datetime.utcfromtimestamp(timestamp)
except OverflowError:
@@ -192,16 +214,21 @@ class OpensshCertificateTimeParameters(object):
@staticmethod
def _time_string_to_datetime(time_string):
result = None
if time_string == 'always':
if time_string == "always":
result = _ALWAYS
elif time_string == 'forever':
elif time_string == "forever":
result = _FOREVER
elif is_relative_time_string(time_string):
result = convert_relative_to_datetime(time_string, with_timezone=_USE_TIMEZONE)
result = convert_relative_to_datetime(
time_string, with_timezone=_USE_TIMEZONE
)
else:
for time_format in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
try:
result = _add_or_remove_timezone(datetime.strptime(time_string, time_format), with_timezone=_USE_TIMEZONE)
result = _add_or_remove_timezone(
datetime.strptime(time_string, time_format),
with_timezone=_USE_TIMEZONE,
)
except ValueError:
pass
if result is None:
@@ -211,7 +238,7 @@ class OpensshCertificateTimeParameters(object):
class OpensshCertificateOption(object):
def __init__(self, option_type, name, data):
if option_type not in ('critical', 'extension'):
if option_type not in ("critical", "extension"):
raise ValueError("type must be either 'critical' or 'extension'")
if not isinstance(name, six.string_types):
@@ -228,11 +255,13 @@ class OpensshCertificateOption(object):
if not isinstance(other, type(self)):
return NotImplemented
return all([
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
])
return all(
[
self._option_type == other._option_type,
self._name == other._name,
self._data == other._data,
]
)
def __hash__(self):
return hash((self._option_type, self._name, self._data))
@@ -260,42 +289,47 @@ class OpensshCertificateOption(object):
@classmethod
def from_string(cls, option_string):
if not isinstance(option_string, six.string_types):
raise ValueError("option_string must be a string not %s" % type(option_string))
raise ValueError(
"option_string must be a string not %s" % type(option_string)
)
option_type = None
if ':' in option_string:
option_type, value = option_string.strip().split(':', 1)
if '=' in value:
name, data = value.split('=', 1)
if ":" in option_string:
option_type, value = option_string.strip().split(":", 1)
if "=" in value:
name, data = value.split("=", 1)
else:
name, data = value, ''
elif '=' in option_string:
name, data = option_string.strip().split('=', 1)
name, data = value, ""
elif "=" in option_string:
name, data = option_string.strip().split("=", 1)
else:
name, data = option_string.strip(), ''
name, data = option_string.strip(), ""
return cls(
option_type=option_type or get_option_type(name.lower()),
name=name,
data=data
data=data,
)
@six.add_metaclass(abc.ABCMeta)
class OpensshCertificateInfo:
"""Encapsulates all certificate information which is signed by a CA key"""
def __init__(self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None):
def __init__(
self,
nonce=None,
serial=None,
cert_type=None,
key_id=None,
principals=None,
valid_after=None,
valid_before=None,
critical_options=None,
extensions=None,
reserved=None,
signing_key=None,
):
self.nonce = nonce
self.serial = serial
self._cert_type = cert_type
@@ -313,17 +347,17 @@ class OpensshCertificateInfo:
@property
def cert_type(self):
if self._cert_type == _USER_TYPE:
return 'user'
return "user"
elif self._cert_type == _HOST_TYPE:
return 'host'
return "host"
else:
return ''
return ""
@cert_type.setter
def cert_type(self, cert_type):
if cert_type == 'user' or cert_type == _USER_TYPE:
if cert_type == "user" or cert_type == _USER_TYPE:
self._cert_type = _USER_TYPE
elif cert_type == 'host' or cert_type == _HOST_TYPE:
elif cert_type == "host" or cert_type == _HOST_TYPE:
self._cert_type = _HOST_TYPE
else:
raise ValueError("%s is not a valid certificate type" % cert_type)
@@ -343,17 +377,17 @@ class OpensshCertificateInfo:
class OpensshRSACertificateInfo(OpensshCertificateInfo):
def __init__(self, e=None, n=None, **kwargs):
super(OpensshRSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['rsa'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["rsa"] + _CERT_SUFFIX_V01
self.e = e
self.n = n
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.e is None, self.n is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['rsa'])
writer.string(_SSH_TYPE_STRINGS["rsa"])
writer.mpint(self.e)
writer.mpint(self.n)
@@ -367,7 +401,7 @@ class OpensshRSACertificateInfo(OpensshCertificateInfo):
class OpensshDSACertificateInfo(OpensshCertificateInfo):
def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
super(OpensshDSACertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['dsa'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["dsa"] + _CERT_SUFFIX_V01
self.p = p
self.q = q
self.g = g
@@ -376,10 +410,10 @@ class OpensshDSACertificateInfo(OpensshCertificateInfo):
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.p is None, self.q is None, self.g is None, self.y is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['dsa'])
writer.string(_SSH_TYPE_STRINGS["dsa"])
writer.mpint(self.p)
writer.mpint(self.q)
writer.mpint(self.g)
@@ -411,16 +445,20 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
def curve(self, curve):
if curve in _ECDSA_CURVE_IDENTIFIERS.values():
self._curve = curve
self.type_string = _SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]] + _CERT_SUFFIX_V01
self.type_string = (
_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]]
+ _CERT_SUFFIX_V01
)
else:
raise ValueError(
"Curve must be one of %s" % (b','.join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode('UTF-8')
"Curve must be one of %s"
% (b",".join(list(_ECDSA_CURVE_IDENTIFIERS.values()))).decode("UTF-8")
)
# See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
def public_key_fingerprint(self):
if any([self.curve is None, self.public_key is None]):
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[self.curve]])
@@ -437,15 +475,15 @@ class OpensshECDSACertificateInfo(OpensshCertificateInfo):
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
def __init__(self, pk=None, **kwargs):
super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
self.type_string = _SSH_TYPE_STRINGS['ed25519'] + _CERT_SUFFIX_V01
self.type_string = _SSH_TYPE_STRINGS["ed25519"] + _CERT_SUFFIX_V01
self.pk = pk
def public_key_fingerprint(self):
if self.pk is None:
return b''
return b""
writer = _OpensshWriter()
writer.string(_SSH_TYPE_STRINGS['ed25519'])
writer.string(_SSH_TYPE_STRINGS["ed25519"])
writer.string(self.pk)
return fingerprint(writer.bytes())
@@ -457,6 +495,7 @@ class OpensshED25519CertificateInfo(OpensshCertificateInfo):
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate(object):
"""Encapsulates a formatted OpenSSH certificate including signature and signing key"""
def __init__(self, cert_info, signature):
self._cert_info = cert_info
@@ -468,13 +507,13 @@ class OpensshCertificate(object):
raise ValueError("%s is not a valid path." % path)
try:
with open(path, 'rb') as cert_file:
with open(path, "rb") as cert_file:
data = cert_file.read()
except (IOError, OSError) as e:
raise ValueError("%s cannot be opened for reading: %s" % (path, e))
try:
format_identifier, b64_cert = data.split(b' ')[:2]
format_identifier, b64_cert = data.split(b" ")[:2]
cert = binascii.a2b_base64(b64_cert)
except (binascii.Error, ValueError):
raise ValueError("Certificate not in OpenSSH format")
@@ -484,7 +523,9 @@ class OpensshCertificate(object):
pub_key_type = key_type
break
else:
raise ValueError("Invalid certificate format identifier: %s" % format_identifier)
raise ValueError(
"Invalid certificate format identifier: %s" % format_identifier
)
parser = OpensshParser(cert)
@@ -499,7 +540,8 @@ class OpensshCertificate(object):
if parser.remaining_bytes():
raise ValueError(
"%s bytes of additional data was not parsed while loading %s" % (parser.remaining_bytes(), path)
"%s bytes of additional data was not parsed while loading %s"
% (parser.remaining_bytes(), path)
)
return cls(
@@ -546,12 +588,16 @@ class OpensshCertificate(object):
@property
def critical_options(self):
return [
OpensshCertificateOption('critical', to_text(n), to_text(d)) for n, d in self._cert_info.critical_options
OpensshCertificateOption("critical", to_text(n), to_text(d))
for n, d in self._cert_info.critical_options
]
@property
def extensions(self):
return [OpensshCertificateOption('extension', to_text(n), to_text(d)) for n, d in self._cert_info.extensions]
return [
OpensshCertificateOption("extension", to_text(n), to_text(d))
for n, d in self._cert_info.extensions
]
@property
def reserved(self):
@@ -564,7 +610,7 @@ class OpensshCertificate(object):
@property
def signature_type(self):
signature_data = OpensshParser.signature_data(self.signature)
return to_text(signature_data['signature_type'])
return to_text(signature_data["signature_type"])
@staticmethod
def _parse_cert_info(pub_key_type, parser):
@@ -586,23 +632,24 @@ class OpensshCertificate(object):
def to_dict(self):
time_parameters = OpensshCertificateTimeParameters(
valid_from=self.valid_after,
valid_to=self.valid_before
valid_from=self.valid_after, valid_to=self.valid_before
)
return {
'type_string': self.type_string,
'nonce': self.nonce,
'serial': self.serial,
'cert_type': self.type,
'identifier': self.key_id,
'principals': self.principals,
'valid_after': time_parameters.valid_from(date_format='human_readable'),
'valid_before': time_parameters.valid_to(date_format='human_readable'),
'critical_options': [str(critical_option) for critical_option in self.critical_options],
'extensions': [str(extension) for extension in self.extensions],
'reserved': self.reserved,
'public_key': self.public_key,
'signing_key': self.signing_key,
"type_string": self.type_string,
"nonce": self.nonce,
"serial": self.serial,
"cert_type": self.type,
"identifier": self.key_id,
"principals": self.principals,
"valid_after": time_parameters.valid_from(date_format="human_readable"),
"valid_before": time_parameters.valid_to(date_format="human_readable"),
"critical_options": [
str(critical_option) for critical_option in self.critical_options
],
"extensions": [str(extension) for extension in self.extensions],
"reserved": self.reserved,
"public_key": self.public_key,
"signing_key": self.signing_key,
}
@@ -611,38 +658,46 @@ def apply_directives(directives):
raise ValueError("directives must be one of %s" % ", ".join(_DIRECTIVES))
directive_to_option = {
'no-x11-forwarding': OpensshCertificateOption('extension', 'permit-x11-forwarding', ''),
'no-agent-forwarding': OpensshCertificateOption('extension', 'permit-agent-forwarding', ''),
'no-port-forwarding': OpensshCertificateOption('extension', 'permit-port-forwarding', ''),
'no-pty': OpensshCertificateOption('extension', 'permit-pty', ''),
'no-user-rc': OpensshCertificateOption('extension', 'permit-user-rc', ''),
"no-x11-forwarding": OpensshCertificateOption(
"extension", "permit-x11-forwarding", ""
),
"no-agent-forwarding": OpensshCertificateOption(
"extension", "permit-agent-forwarding", ""
),
"no-port-forwarding": OpensshCertificateOption(
"extension", "permit-port-forwarding", ""
),
"no-pty": OpensshCertificateOption("extension", "permit-pty", ""),
"no-user-rc": OpensshCertificateOption("extension", "permit-user-rc", ""),
}
if 'clear' in directives:
if "clear" in directives:
return []
else:
return list(set(default_options()) - set(directive_to_option[d] for d in directives))
return list(
set(default_options()) - set(directive_to_option[d] for d in directives)
)
def default_options():
return [OpensshCertificateOption('extension', name, '') for name in _EXTENSIONS]
return [OpensshCertificateOption("extension", name, "") for name in _EXTENSIONS]
def fingerprint(public_key):
"""Generates a SHA256 hash and formats output to resemble ``ssh-keygen``"""
h = sha256()
h.update(public_key)
return b'SHA256:' + b64encode(h.digest()).rstrip(b'=')
return b"SHA256:" + b64encode(h.digest()).rstrip(b"=")
def get_cert_info_object(key_type):
if key_type == 'rsa':
if key_type == "rsa":
cert_info = OpensshRSACertificateInfo()
elif key_type == 'dsa':
elif key_type == "dsa":
cert_info = OpensshDSACertificateInfo()
elif key_type in ('ecdsa-nistp256', 'ecdsa-nistp384', 'ecdsa-nistp521'):
elif key_type in ("ecdsa-nistp256", "ecdsa-nistp384", "ecdsa-nistp521"):
cert_info = OpensshECDSACertificateInfo()
elif key_type == 'ed25519':
elif key_type == "ed25519":
cert_info = OpensshED25519CertificateInfo()
else:
raise ValueError("%s is not a valid key type" % key_type)
@@ -652,12 +707,14 @@ def get_cert_info_object(key_type):
def get_option_type(name):
if name in _CRITICAL_OPTIONS:
result = 'critical'
result = "critical"
elif name in _EXTENSIONS:
result = 'extension'
result = "extension"
else:
raise ValueError("%s is not a valid option. " % name +
"Custom options must start with 'critical:' or 'extension:' to indicate type")
raise ValueError(
"%s is not a valid option. " % name
+ "Custom options must start with 'critical:' or 'extension:' to indicate type"
)
return result
@@ -675,7 +732,7 @@ def parse_option_list(option_list):
directives.append(option.lower())
else:
option_object = OpensshCertificateOption.from_string(option)
if option_object.type == 'critical':
if option_object.type == "critical":
critical_options.append(option_object)
else:
extensions.append(option_object)

View File

@@ -38,41 +38,41 @@ try:
HAS_OPENSSH_SUPPORT = True
_ALGORITHM_PARAMETERS = {
'rsa': {
'default_size': 2048,
'valid_sizes': range(1024, 16384),
'signer_params': {
'padding': padding.PSS(
"rsa": {
"default_size": 2048,
"valid_sizes": range(1024, 16384),
"signer_params": {
"padding": padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
'algorithm': hashes.SHA256(),
"algorithm": hashes.SHA256(),
},
},
'dsa': {
'default_size': 1024,
'valid_sizes': [1024],
'signer_params': {
'algorithm': hashes.SHA256(),
"dsa": {
"default_size": 1024,
"valid_sizes": [1024],
"signer_params": {
"algorithm": hashes.SHA256(),
},
},
'ed25519': {
'default_size': 256,
'valid_sizes': [256],
'signer_params': {},
"ed25519": {
"default_size": 256,
"valid_sizes": [256],
"signer_params": {},
},
'ecdsa': {
'default_size': 256,
'valid_sizes': [256, 384, 521],
'signer_params': {
'signature_algorithm': ec.ECDSA(hashes.SHA256()),
"ecdsa": {
"default_size": 256,
"valid_sizes": [256, 384, 521],
"signer_params": {
"signature_algorithm": ec.ECDSA(hashes.SHA256()),
},
'curves': {
"curves": {
256: ec.SECP256R1(),
384: ec.SECP384R1(),
521: ec.SECP521R1(),
}
}
},
},
}
except ImportError:
HAS_OPENSSH_PRIVATE_FORMAT = False
@@ -80,7 +80,7 @@ except ImportError:
CRYPTOGRAPHY_VERSION = "0.0"
_ALGORITHM_PARAMETERS = {}
_TEXT_ENCODING = 'UTF-8'
_TEXT_ENCODING = "UTF-8"
class OpenSSHError(Exception):
@@ -131,26 +131,25 @@ class AsymmetricKeypair(object):
"""Container for newly generated asymmetric key pairs or those loaded from existing files"""
@classmethod
def generate(cls, keytype='rsa', size=None, passphrase=None):
def generate(cls, keytype="rsa", size=None, passphrase=None):
"""Returns an Asymmetric_Keypair object generated with the supplied parameters
or defaults to an unencrypted RSA-2048 key
or defaults to an unencrypted RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the private key being generated
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the private key being generated
"""
if keytype not in _ALGORITHM_PARAMETERS.keys():
raise InvalidKeyTypeError(
"%s is not a valid keytype. Valid keytypes are %s" % (
keytype, ", ".join(_ALGORITHM_PARAMETERS.keys())
)
"%s is not a valid keytype. Valid keytypes are %s"
% (keytype, ", ".join(_ALGORITHM_PARAMETERS.keys()))
)
if not size:
size = _ALGORITHM_PARAMETERS[keytype]['default_size']
size = _ALGORITHM_PARAMETERS[keytype]["default_size"]
else:
if size not in _ALGORITHM_PARAMETERS[keytype]['valid_sizes']:
if size not in _ALGORITHM_PARAMETERS[keytype]["valid_sizes"]:
raise InvalidKeySizeError(
"%s is not a valid key size for %s keys" % (size, keytype)
)
@@ -160,7 +159,7 @@ class AsymmetricKeypair(object):
else:
encryption_algorithm = serialization.NoEncryption()
if keytype == 'rsa':
if keytype == "rsa":
privatekey = rsa.generate_private_key(
# Public exponent should always be 65537 to prevent issues
# if improper padding is used during signing
@@ -168,16 +167,16 @@ class AsymmetricKeypair(object):
key_size=size,
backend=backend,
)
elif keytype == 'dsa':
elif keytype == "dsa":
privatekey = dsa.generate_private_key(
key_size=size,
backend=backend,
)
elif keytype == 'ed25519':
elif keytype == "ed25519":
privatekey = Ed25519PrivateKey.generate()
elif keytype == 'ecdsa':
elif keytype == "ecdsa":
privatekey = ec.generate_private_key(
_ALGORITHM_PARAMETERS['ecdsa']['curves'][size],
_ALGORITHM_PARAMETERS["ecdsa"]["curves"][size],
backend=backend,
)
@@ -188,18 +187,25 @@ class AsymmetricKeypair(object):
size=size,
privatekey=privatekey,
publickey=publickey,
encryption_algorithm=encryption_algorithm
encryption_algorithm=encryption_algorithm,
)
@classmethod
def load(cls, path, passphrase=None, private_key_format='PEM', public_key_format='PEM', no_public_key=False):
def load(
cls,
path,
passphrase=None,
private_key_format="PEM",
public_key_format="PEM",
no_public_key=False,
):
"""Returns an Asymmetric_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
:passphrase: Secret of type bytes used to decrypt the private key being loaded
:private_key_format: Format of private key to be loaded
:public_key_format: Format of public key to be loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
:path: A path to an existing private key to be loaded
:passphrase: Secret of type bytes used to decrypt the private key being loaded
:private_key_format: Format of private key to be loaded
:public_key_format: Format of public key to be loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
"""
if passphrase:
@@ -211,40 +217,42 @@ class AsymmetricKeypair(object):
if no_public_key:
publickey = privatekey.public_key()
else:
publickey = load_publickey(path + '.pub', public_key_format)
publickey = load_publickey(path + ".pub", public_key_format)
# Ed25519 keys are always of size 256 and do not have a key_size attribute
if isinstance(privatekey, Ed25519PrivateKey):
size = _ALGORITHM_PARAMETERS['ed25519']['default_size']
size = _ALGORITHM_PARAMETERS["ed25519"]["default_size"]
else:
size = privatekey.key_size
if isinstance(privatekey, rsa.RSAPrivateKey):
keytype = 'rsa'
keytype = "rsa"
elif isinstance(privatekey, dsa.DSAPrivateKey):
keytype = 'dsa'
keytype = "dsa"
elif isinstance(privatekey, ec.EllipticCurvePrivateKey):
keytype = 'ecdsa'
keytype = "ecdsa"
elif isinstance(privatekey, Ed25519PrivateKey):
keytype = 'ed25519'
keytype = "ed25519"
else:
raise InvalidKeyTypeError("Key type '%s' is not supported" % type(privatekey))
raise InvalidKeyTypeError(
"Key type '%s' is not supported" % type(privatekey)
)
return cls(
keytype=keytype,
size=size,
privatekey=privatekey,
publickey=publickey,
encryption_algorithm=encryption_algorithm
encryption_algorithm=encryption_algorithm,
)
def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm):
"""
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair
:privatekey: Private key object of this key pair
:publickey: Public key object of this key pair
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for the private key of this key pair
:privatekey: Private key object of this key pair
:publickey: Public key object of this key pair
:encryption_algorithm: Hashed secret used to encrypt the private key of this key pair
"""
self.__size = size
@@ -254,7 +262,7 @@ class AsymmetricKeypair(object):
self.__encryption_algorithm = encryption_algorithm
try:
self.verify(self.sign(b'message'), b'message')
self.verify(self.sign(b"message"), b"message")
except InvalidSignatureError:
raise InvalidPublicKeyFileError(
"The private key and public key of this keypair do not match"
@@ -264,8 +272,11 @@ class AsymmetricKeypair(object):
if not isinstance(other, AsymmetricKeypair):
return NotImplemented
return (compare_publickeys(self.public_key, other.public_key) and
compare_encryption_algorithms(self.encryption_algorithm, other.encryption_algorithm))
return compare_publickeys(
self.public_key, other.public_key
) and compare_encryption_algorithms(
self.encryption_algorithm, other.encryption_algorithm
)
def __ne__(self, other):
return not self == other
@@ -303,13 +314,12 @@ class AsymmetricKeypair(object):
def sign(self, data):
"""Returns signature of data signed with the private key of this key pair
:data: byteslike data to sign
:data: byteslike data to sign
"""
try:
signature = self.__privatekey.sign(
data,
**_ALGORITHM_PARAMETERS[self.__keytype]['signer_params']
data, **_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
)
except TypeError as e:
raise InvalidDataError(e)
@@ -318,16 +328,16 @@ class AsymmetricKeypair(object):
def verify(self, signature, data):
"""Verifies that the signature associated with the provided data was signed
by the private key of this key pair.
by the private key of this key pair.
:signature: signature to verify
:data: byteslike data signed by the provided signature
:signature: signature to verify
:data: byteslike data signed by the provided signature
"""
try:
return self.__publickey.verify(
signature,
data,
**_ALGORITHM_PARAMETERS[self.__keytype]['signer_params']
**_ALGORITHM_PARAMETERS[self.__keytype]["signer_params"]
)
except InvalidSignature:
raise InvalidSignatureError
@@ -335,7 +345,7 @@ class AsymmetricKeypair(object):
def update_passphrase(self, passphrase=None):
"""Updates the encryption algorithm of this key pair
:passphrase: Byte secret used to encrypt this key pair
:passphrase: Byte secret used to encrypt this key pair
"""
if passphrase:
@@ -348,20 +358,20 @@ class OpensshKeypair(object):
"""Container for OpenSSH encoded asymmetric key pairs"""
@classmethod
def generate(cls, keytype='rsa', size=None, passphrase=None, comment=None):
def generate(cls, keytype="rsa", size=None, passphrase=None, comment=None):
"""Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
:comment: Comment for a newly generated OpenSSH public key
:keytype: One of rsa, dsa, ecdsa, ed25519
:size: The key length for newly generated keys
:passphrase: Secret of type Bytes used to encrypt the newly generated private key
:comment: Comment for a newly generated OpenSSH public key
"""
if comment is None:
comment = "%s@%s" % (getuser(), gethostname())
asym_keypair = AsymmetricKeypair.generate(keytype, size, passphrase)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, 'SSH')
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
@@ -377,18 +387,20 @@ class OpensshKeypair(object):
def load(cls, path, passphrase=None, no_public_key=False):
"""Returns an Openssh_Keypair object loaded from the supplied file path
:path: A path to an existing private key to be loaded
:passphrase: Secret used to decrypt the private key being loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
:path: A path to an existing private key to be loaded
:passphrase: Secret used to decrypt the private key being loaded
:no_public_key: Set 'True' to only load a private key and automatically populate the matching public key
"""
if no_public_key:
comment = ""
else:
comment = extract_comment(path + '.pub')
comment = extract_comment(path + ".pub")
asym_keypair = AsymmetricKeypair.load(path, passphrase, 'SSH', 'SSH', no_public_key)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, 'SSH')
asym_keypair = AsymmetricKeypair.load(
path, passphrase, "SSH", "SSH", no_public_key
)
openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair, "SSH")
openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment)
fingerprint = calculate_fingerprint(openssh_publickey)
@@ -404,29 +416,33 @@ class OpensshKeypair(object):
def encode_openssh_privatekey(asym_keypair, key_format):
"""Returns an OpenSSH encoded private key for a given keypair
:asym_keypair: Asymmetric_Keypair from the private key is extracted
:key_format: Format of the encoded private key.
:asym_keypair: Asymmetric_Keypair from the private key is extracted
:key_format: Format of the encoded private key.
"""
if key_format == 'SSH':
if key_format == "SSH":
# Default to PEM format if SSH not available
if not HAS_OPENSSH_PRIVATE_FORMAT:
privatekey_format = serialization.PrivateFormat.PKCS8
else:
privatekey_format = serialization.PrivateFormat.OpenSSH
elif key_format == 'PKCS8':
elif key_format == "PKCS8":
privatekey_format = serialization.PrivateFormat.PKCS8
elif key_format == 'PKCS1':
if asym_keypair.key_type == 'ed25519':
raise InvalidKeyFormatError("ed25519 keys cannot be represented in PKCS1 format")
elif key_format == "PKCS1":
if asym_keypair.key_type == "ed25519":
raise InvalidKeyFormatError(
"ed25519 keys cannot be represented in PKCS1 format"
)
privatekey_format = serialization.PrivateFormat.TraditionalOpenSSL
else:
raise InvalidKeyFormatError("The accepted private key formats are SSH, PKCS8, and PKCS1")
raise InvalidKeyFormatError(
"The accepted private key formats are SSH, PKCS8, and PKCS1"
)
encoded_privatekey = asym_keypair.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=privatekey_format,
encryption_algorithm=asym_keypair.encryption_algorithm
encryption_algorithm=asym_keypair.encryption_algorithm,
)
return encoded_privatekey
@@ -435,8 +451,8 @@ class OpensshKeypair(object):
def encode_openssh_publickey(asym_keypair, comment):
"""Returns an OpenSSH encoded public key for a given keypair
:asym_keypair: Asymmetric_Keypair from the public key is extracted
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
:asym_keypair: Asymmetric_Keypair from the public key is extracted
:comment: Comment to apply to the end of the returned OpenSSH encoded public key
"""
encoded_publickey = asym_keypair.public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
@@ -445,17 +461,21 @@ class OpensshKeypair(object):
validate_comment(comment)
encoded_publickey += (" %s" % comment).encode(encoding=_TEXT_ENCODING) if comment else b''
encoded_publickey += (
(" %s" % comment).encode(encoding=_TEXT_ENCODING) if comment else b""
)
return encoded_publickey
def __init__(self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment):
def __init__(
self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment
):
"""
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key
:openssh_privatekey: An OpenSSH encoded public key
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
:comment: Comment applied to the OpenSSH public key of this keypair
:asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived
:openssh_privatekey: An OpenSSH encoded private key
:openssh_privatekey: An OpenSSH encoded public key
:fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair
:comment: Comment applied to the OpenSSH public key of this keypair
"""
self.__asym_keypair = asym_keypair
@@ -468,7 +488,10 @@ class OpensshKeypair(object):
if not isinstance(other, OpensshKeypair):
return NotImplemented
return self.asymmetric_keypair == other.asymmetric_keypair and self.comment == other.comment
return (
self.asymmetric_keypair == other.asymmetric_keypair
and self.comment == other.comment
)
@property
def asymmetric_keypair(self):
@@ -516,53 +539,59 @@ class OpensshKeypair(object):
def comment(self, comment):
"""Updates the comment applied to the OpenSSH formatted public key of this key pair
:comment: Text to update the OpenSSH public key comment
:comment: Text to update the OpenSSH public key comment
"""
validate_comment(comment)
self.__comment = comment
encoded_comment = (" %s" % self.__comment).encode(encoding=_TEXT_ENCODING) if self.__comment else b''
self.__openssh_publickey = b' '.join(self.__openssh_publickey.split(b' ', 2)[:2]) + encoded_comment
encoded_comment = (
(" %s" % self.__comment).encode(encoding=_TEXT_ENCODING)
if self.__comment
else b""
)
self.__openssh_publickey = (
b" ".join(self.__openssh_publickey.split(b" ", 2)[:2]) + encoded_comment
)
return self.__openssh_publickey
def update_passphrase(self, passphrase):
"""Updates the passphrase used to encrypt the private key of this keypair
:passphrase: Text secret used for encryption
:passphrase: Text secret used for encryption
"""
self.__asym_keypair.update_passphrase(passphrase)
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(self.__asym_keypair, 'SSH')
self.__openssh_privatekey = OpensshKeypair.encode_openssh_privatekey(
self.__asym_keypair, "SSH"
)
def load_privatekey(path, passphrase, key_format):
privatekey_loaders = {
'PEM': serialization.load_pem_private_key,
'DER': serialization.load_der_private_key,
"PEM": serialization.load_pem_private_key,
"DER": serialization.load_der_private_key,
}
# OpenSSH formatted private keys are not available in Cryptography <3.0
if hasattr(serialization, 'load_ssh_private_key'):
privatekey_loaders['SSH'] = serialization.load_ssh_private_key
if hasattr(serialization, "load_ssh_private_key"):
privatekey_loaders["SSH"] = serialization.load_ssh_private_key
else:
privatekey_loaders['SSH'] = serialization.load_pem_private_key
privatekey_loaders["SSH"] = serialization.load_pem_private_key
try:
privatekey_loader = privatekey_loaders[key_format]
except KeyError:
raise InvalidKeyFormatError(
"%s is not a valid key format (%s)" % (
key_format,
','.join(privatekey_loaders.keys())
)
"%s is not a valid key format (%s)"
% (key_format, ",".join(privatekey_loaders.keys()))
)
if not os.path.exists(path):
raise InvalidPrivateKeyFileError("No file was found at %s" % path)
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
content = f.read()
privatekey = privatekey_loader(
@@ -573,9 +602,9 @@ def load_privatekey(path, passphrase, key_format):
except ValueError as e:
# Revert to PEM if key could not be loaded in SSH format
if key_format == 'SSH':
if key_format == "SSH":
try:
privatekey = privatekey_loaders['PEM'](
privatekey = privatekey_loaders["PEM"](
data=content,
password=passphrase,
backend=backend,
@@ -598,26 +627,24 @@ def load_privatekey(path, passphrase, key_format):
def load_publickey(path, key_format):
publickey_loaders = {
'PEM': serialization.load_pem_public_key,
'DER': serialization.load_der_public_key,
'SSH': serialization.load_ssh_public_key,
"PEM": serialization.load_pem_public_key,
"DER": serialization.load_der_public_key,
"SSH": serialization.load_ssh_public_key,
}
try:
publickey_loader = publickey_loaders[key_format]
except KeyError:
raise InvalidKeyFormatError(
"%s is not a valid key format (%s)" % (
key_format,
','.join(publickey_loaders.keys())
)
"%s is not a valid key format (%s)"
% (key_format, ",".join(publickey_loaders.keys()))
)
if not os.path.exists(path):
raise InvalidPublicKeyFileError("No file was found at %s" % path)
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
content = f.read()
publickey = publickey_loader(
@@ -646,10 +673,13 @@ def compare_publickeys(pk1, pk2):
def compare_encryption_algorithms(ea1, ea2):
if isinstance(ea1, serialization.NoEncryption) and isinstance(ea2, serialization.NoEncryption):
if isinstance(ea1, serialization.NoEncryption) and isinstance(
ea2, serialization.NoEncryption
):
return True
elif (isinstance(ea1, serialization.BestAvailableEncryption) and
isinstance(ea2, serialization.BestAvailableEncryption)):
elif isinstance(ea1, serialization.BestAvailableEncryption) and isinstance(
ea2, serialization.BestAvailableEncryption
):
return ea1.password == ea2.password
else:
return False
@@ -663,7 +693,7 @@ def get_encryption_algorithm(passphrase):
def validate_comment(comment):
if not hasattr(comment, 'encode'):
if not hasattr(comment, "encode"):
raise InvalidCommentError("%s cannot be encoded to text" % comment)
@@ -673,8 +703,8 @@ def extract_comment(path):
raise InvalidPublicKeyFileError("No file was found at %s" % path)
try:
with open(path, 'rb') as f:
fields = f.read().split(b' ', 2)
with open(path, "rb") as f:
fields = f.read().split(b" ", 2)
if len(fields) == 3:
comment = fields[2].decode(_TEXT_ENCODING)
else:
@@ -687,7 +717,9 @@ def extract_comment(path):
def calculate_fingerprint(openssh_publickey):
digest = hashes.Hash(hashes.SHA256(), backend=backend)
decoded_pubkey = b64decode(openssh_publickey.split(b' ')[1])
decoded_pubkey = b64decode(openssh_publickey.split(b" ")[1])
digest.update(decoded_pubkey)
return 'SHA256:%s' % b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip('=')
return "SHA256:%s" % b64encode(digest.finalize()).decode(
encoding=_TEXT_ENCODING
).rstrip("=")

View File

@@ -34,17 +34,17 @@ if PY3:
long = int
# 0 (False) or 1 (True) encoded as a single byte
_BOOLEAN = Struct(b'?')
_BOOLEAN = Struct(b"?")
# Unsigned 8-bit integer in network-byte-order
_UBYTE = Struct(b'!B')
_UBYTE = Struct(b"!B")
_UBYTE_MAX = 0xFF
# Unsigned 32-bit integer in network-byte-order
_UINT32 = Struct(b'!I')
_UINT32 = Struct(b"!I")
# Unsigned 32-bit little endian integer
_UINT32_LE = Struct(b'<I')
_UINT32_LE = Struct(b"<I")
_UINT32_MAX = 0xFFFFFFFF
# Unsigned 64-bit integer in network-byte-order
_UINT64 = Struct(b'!Q')
_UINT64 = Struct(b"!Q")
_UINT64_MAX = 0xFFFFFFFFFFFFFFFF
@@ -89,6 +89,7 @@ def secure_write(path, mode, content):
# See https://datatracker.ietf.org/doc/html/rfc4251#section-5 for SSH data types
class OpensshParser(object):
"""Parser for OpenSSH encoded objects"""
BOOLEAN_OFFSET = 1
UINT32_OFFSET = 4
UINT64_OFFSET = 8
@@ -103,21 +104,21 @@ class OpensshParser(object):
def boolean(self):
next_pos = self._check_position(self.BOOLEAN_OFFSET)
value = _BOOLEAN.unpack(self._data[self._pos:next_pos])[0]
value = _BOOLEAN.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint32(self):
next_pos = self._check_position(self.UINT32_OFFSET)
value = _UINT32.unpack(self._data[self._pos:next_pos])[0]
value = _UINT32.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
def uint64(self):
next_pos = self._check_position(self.UINT64_OFFSET)
value = _UINT64.unpack(self._data[self._pos:next_pos])[0]
value = _UINT64.unpack(self._data[self._pos : next_pos])[0]
self._pos = next_pos
return value
@@ -126,7 +127,7 @@ class OpensshParser(object):
next_pos = self._check_position(length)
value = self._data[self._pos:next_pos]
value = self._data[self._pos : next_pos]
self._pos = next_pos
# Cast to bytes is required as a memoryview slice is itself a memoryview
return value if not PY3 else bytes(value)
@@ -136,7 +137,7 @@ class OpensshParser(object):
def name_list(self):
raw_string = self.string()
return raw_string.decode('ASCII').split(',')
return raw_string.decode("ASCII").split(",")
# Convenience function, but not an official data type from SSH
def string_list(self):
@@ -193,33 +194,39 @@ class OpensshParser(object):
signature_blob = parser.string()
blob_parser = cls(signature_blob)
if signature_type in (b'ssh-rsa', b'rsa-sha2-256', b'rsa-sha2-512'):
if signature_type in (b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"):
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
# https://datatracker.ietf.org/doc/html/rfc8332#section-3
signature_data['s'] = cls._big_int(signature_blob, "big")
elif signature_type == b'ssh-dss':
signature_data["s"] = cls._big_int(signature_blob, "big")
elif signature_type == b"ssh-dss":
# https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
signature_data['r'] = cls._big_int(signature_blob[:20], "big")
signature_data['s'] = cls._big_int(signature_blob[20:], "big")
elif signature_type in (b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521'):
signature_data["r"] = cls._big_int(signature_blob[:20], "big")
signature_data["s"] = cls._big_int(signature_blob[20:], "big")
elif signature_type in (
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
):
# https://datatracker.ietf.org/doc/html/rfc5656#section-3.1.2
signature_data['r'] = blob_parser.mpint()
signature_data['s'] = blob_parser.mpint()
elif signature_type == b'ssh-ed25519':
signature_data["r"] = blob_parser.mpint()
signature_data["s"] = blob_parser.mpint()
elif signature_type == b"ssh-ed25519":
# https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.2
signature_data['R'] = cls._big_int(signature_blob[:32], "little")
signature_data['S'] = cls._big_int(signature_blob[32:], "little")
signature_data["R"] = cls._big_int(signature_blob[:32], "little")
signature_data["S"] = cls._big_int(signature_blob[32:], "little")
else:
raise ValueError("%s is not a valid signature type" % signature_type)
signature_data['signature_type'] = signature_type
signature_data["signature_type"] = signature_type
return signature_data
@classmethod
def _big_int(cls, raw_string, byte_order, signed=False):
if byte_order not in ("big", "little"):
raise ValueError("Byte_order must be one of (big, little) not %s" % byte_order)
raise ValueError(
"Byte_order must be one of (big, little) not %s" % byte_order
)
if PY3:
return int.from_bytes(raw_string, byte_order, signed=signed)
@@ -232,21 +239,31 @@ class OpensshParser(object):
msb = raw_string[0] if byte_order == "big" else raw_string[-1]
negative = bool(ord(msb) & 0x80)
# Match pad value for two's complement
pad = b'\xFF' if signed and negative else b'\x00'
pad = b"\xff" if signed and negative else b"\x00"
# The definition of ``mpint`` enforces that unnecessary bytes are not encoded so they are added back
pad_length = (4 - byte_length % 4)
pad_length = 4 - byte_length % 4
if pad_length < 4:
raw_string = pad * pad_length + raw_string if byte_order == "big" else raw_string + pad * pad_length
raw_string = (
pad * pad_length + raw_string
if byte_order == "big"
else raw_string + pad * pad_length
)
byte_length += pad_length
# Accumulate arbitrary precision integer bytes in the appropriate order
if byte_order == "big":
for i in range(0, byte_length, cls.UINT32_OFFSET):
left_shift = result << cls.UINT32_OFFSET * 8
result = left_shift + _UINT32.unpack(raw_string[i:i + cls.UINT32_OFFSET])[0]
result = (
left_shift
+ _UINT32.unpack(raw_string[i : i + cls.UINT32_OFFSET])[0]
)
else:
for i in range(byte_length, 0, -cls.UINT32_OFFSET):
left_shift = result << cls.UINT32_OFFSET * 8
result = left_shift + _UINT32_LE.unpack(raw_string[i - cls.UINT32_OFFSET:i])[0]
result = (
left_shift
+ _UINT32_LE.unpack(raw_string[i - cls.UINT32_OFFSET : i])[0]
)
# Adjust for two's complement
if signed and negative:
result -= 1 << (8 * byte_length)
@@ -262,10 +279,13 @@ class _OpensshWriter(object):
It is not to be used to construct Openssh objects, but rather as a utility to assist
in validating parsed material.
"""
def __init__(self, buffer=None):
if buffer is not None:
if not isinstance(buffer, (bytes, bytearray)):
raise TypeError("Buffer must be a bytes-like object not %s" % type(buffer))
raise TypeError(
"Buffer must be a bytes-like object not %s" % type(buffer)
)
else:
buffer = bytearray()
@@ -283,7 +303,9 @@ class _OpensshWriter(object):
if not isinstance(value, int):
raise TypeError("Value must be of type int not %s" % type(value))
if value < 0 or value > _UINT32_MAX:
raise ValueError("Value must be a positive integer less than %s" % _UINT32_MAX)
raise ValueError(
"Value must be a positive integer less than %s" % _UINT32_MAX
)
self._buff.extend(_UINT32.pack(value))
@@ -293,7 +315,9 @@ class _OpensshWriter(object):
if not isinstance(value, (long, int)):
raise TypeError("Value must be of type (long, int) not %s" % type(value))
if value < 0 or value > _UINT64_MAX:
raise ValueError("Value must be a positive integer less than %s" % _UINT64_MAX)
raise ValueError(
"Value must be a positive integer less than %s" % _UINT64_MAX
)
self._buff.extend(_UINT64.pack(value))
@@ -320,7 +344,7 @@ class _OpensshWriter(object):
raise TypeError("Value must be a list of byte strings not %s" % type(value))
try:
self.string(','.join(value).encode('ASCII'))
self.string(",".join(value).encode("ASCII"))
except UnicodeEncodeError as e:
raise ValueError("Name-list's must consist of US-ASCII characters: %s" % e)
@@ -365,9 +389,9 @@ class _OpensshWriter(object):
result = bytes()
# 0 and -1 are treated as special cases since they are used as sentinels for all other values
if num == 0:
result += b'\x00'
result += b"\x00"
elif num == -1:
result += b'\xFF'
result += b"\xff"
elif num > 0:
while num >> 32:
result = _UINT32.pack(num & _UINT32_MAX) + result
@@ -378,7 +402,7 @@ class _OpensshWriter(object):
num = num >> 8
# Zero pad final byte if most-significant bit is 1 as per mpint definition
if ord(result[0]) & 0x80:
result = b'\x00' + result
result = b"\x00" + result
else:
while (num >> 32) < -1:
result = _UINT32.pack(num & _UINT32_MAX) + result
@@ -387,7 +411,7 @@ class _OpensshWriter(object):
result = _UBYTE.pack(num & _UBYTE_MAX) + result
num = num >> 8
if not ord(result[0]) & 0x80:
result = b'\xFF' + result
result = b"\xff" + result
return result

View File

@@ -21,12 +21,12 @@ def th(number):
mod_100 = abs_number % 100
if mod_100 not in (11, 12, 13):
if mod_10 == 1:
return 'st'
return "st"
if mod_10 == 2:
return 'nd'
return "nd"
if mod_10 == 3:
return 'rd'
return 'th'
return "rd"
return "th"
def parse_serial(value):
@@ -35,14 +35,17 @@ def parse_serial(value):
"""
value = to_native(value)
result = 0
for i, part in enumerate(value.split(':')):
for i, part in enumerate(value.split(":")):
try:
part_value = int(part, 16)
if part_value < 0 or part_value > 255:
raise ValueError('the value is not in range [0, 255]')
raise ValueError("the value is not in range [0, 255]")
except ValueError as exc:
raise ValueError("The {idx}{th} part {part!r} is not a hexadecimal number in range [0, 255]: {exc}".format(
idx=i + 1, th=th(i + 1), part=part, exc=exc))
raise ValueError(
"The {idx}{th} part {part!r} is not a hexadecimal number in range [0, 255]: {exc}".format(
idx=i + 1, th=th(i + 1), part=part, exc=exc
)
)
result = (result << 8) | part_value
return result
@@ -53,5 +56,5 @@ def to_serial(value):
"""
value = convert_int_to_hex(value).upper()
if len(value) % 2 != 0:
value = '0' + value
return ':'.join(value[i:i + 2] for i in range(0, len(value), 2))
value = "0" + value
return ":".join(value[i : i + 2] for i in range(0, len(value), 2))

View File

@@ -33,13 +33,13 @@ except AttributeError:
return _DURATION_ZERO
def tzname(self, dt):
return 'UTC'
return "UTC"
def fromutc(self, dt):
return dt
def __repr__(self):
return 'UTC'
return "UTC"
UTC = _UTCClass()
@@ -69,20 +69,29 @@ def remove_timezone(timestamp):
def add_or_remove_timezone(timestamp, with_timezone):
return ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
return (
ensure_utc_timezone(timestamp) if with_timezone else remove_timezone(timestamp)
)
if sys.version_info < (3, 3):
def get_epoch_seconds(timestamp):
epoch = datetime.datetime(1970, 1, 1, tzinfo=UTC if timestamp.tzinfo is not None else None)
epoch = datetime.datetime(
1970, 1, 1, tzinfo=UTC if timestamp.tzinfo is not None else None
)
delta = timestamp - epoch
try:
return delta.total_seconds()
except AttributeError:
# Python 2.6 and earlier: total_seconds() does not yet exist, so we use the formula from
# https://docs.python.org/2/library/datetime.html#datetime.timedelta.total_seconds
return (delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 10**6) / 10**6
return (
delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 10**6
) / 10**6
else:
def get_epoch_seconds(timestamp):
if timestamp.tzinfo is None:
# timestamp.timestamp() is offset by the local timezone if timestamp has no timezone
@@ -101,7 +110,8 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
parsed_result = re.match(
r"^(?P<prefix>[+-])((?P<weeks>\d+)[wW])?((?P<days>\d+)[dD])?((?P<hours>\d+)[hH])?((?P<minutes>\d+)[mM])?((?P<seconds>\d+)[sS]?)?$",
relative_time_string)
relative_time_string,
)
if parsed_result is None or len(relative_time_string) == 1:
# not matched or only a single "+" or "-"
@@ -115,11 +125,9 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
if parsed_result.group("hours") is not None:
offset += datetime.timedelta(hours=int(parsed_result.group("hours")))
if parsed_result.group("minutes") is not None:
offset += datetime.timedelta(
minutes=int(parsed_result.group("minutes")))
offset += datetime.timedelta(minutes=int(parsed_result.group("minutes")))
if parsed_result.group("seconds") is not None:
offset += datetime.timedelta(
seconds=int(parsed_result.group("seconds")))
offset += datetime.timedelta(seconds=int(parsed_result.group("seconds")))
if now is None:
now = get_now_datetime(with_timezone=with_timezone)
@@ -132,33 +140,43 @@ def convert_relative_to_datetime(relative_time_string, with_timezone=False, now=
return now - offset
def get_relative_time_option(input_string, input_name, backend='cryptography', with_timezone=False, now=None):
def get_relative_time_option(
input_string, input_name, backend="cryptography", with_timezone=False, now=None
):
"""Return an absolute timespec if a relative timespec or an ASN1 formatted
string is provided.
string is provided.
The return value will be a datetime object for the cryptography backend,
and a ASN1 formatted string for the pyopenssl backend."""
The return value will be a datetime object for the cryptography backend,
and a ASN1 formatted string for the pyopenssl backend."""
result = to_native(input_string)
if result is None:
raise OpenSSLObjectError(
'The timespec "%s" for %s is not valid' %
input_string, input_name)
'The timespec "%s" for %s is not valid' % input_string, input_name
)
# Relative time
if result.startswith("+") or result.startswith("-"):
result_datetime = convert_relative_to_datetime(result, with_timezone=with_timezone, now=now)
if backend == 'pyopenssl':
result_datetime = convert_relative_to_datetime(
result, with_timezone=with_timezone, now=now
)
if backend == "pyopenssl":
return result_datetime.strftime("%Y%m%d%H%M%SZ")
elif backend == 'cryptography':
elif backend == "cryptography":
return result_datetime
# Absolute time
if backend == 'pyopenssl':
if backend == "pyopenssl":
return input_string
elif backend == 'cryptography':
elif backend == "cryptography":
for date_fmt, length in [
('%Y%m%d%H%M%SZ', 15), # this also parses '202401020304Z', but as datetime(2024, 1, 2, 3, 0, 4)
('%Y%m%d%H%MZ', 13),
('%Y%m%d%H%M%S%z', 14 + 5), # this also parses '202401020304+0000', but as datetime(2024, 1, 2, 3, 0, 4, tzinfo=...)
('%Y%m%d%H%M%z', 12 + 5),
(
"%Y%m%d%H%M%SZ",
15,
), # this also parses '202401020304Z', but as datetime(2024, 1, 2, 3, 0, 4)
("%Y%m%d%H%MZ", 13),
(
"%Y%m%d%H%M%S%z",
14 + 5,
), # this also parses '202401020304+0000', but as datetime(2024, 1, 2, 3, 0, 4, tzinfo=...)
("%Y%m%d%H%M%z", 12 + 5),
]:
if len(result) != length:
continue
@@ -170,6 +188,5 @@ def get_relative_time_option(input_string, input_name, backend='cryptography', w
return add_or_remove_timezone(res, with_timezone=with_timezone)
raise OpenSSLObjectError(
'The time spec "%s" for %s is invalid' %
(input_string, input_name)
'The time spec "%s" for %s is invalid' % (input_string, input_name)
)