Make all module_utils and plugin_utils private (#887)

* Add leading underscore. Remove deprecated module utils.

* Document module and plugin utils as private. Add changelog fragment.

* Convert relative to absolute imports.

* Remove unnecessary imports.
This commit is contained in:
Felix Fontein
2025-05-11 19:17:58 +02:00
committed by GitHub
parent f758d94fba
commit a5a4e022ba
146 changed files with 678 additions and 465 deletions

View File

@@ -0,0 +1,161 @@
# Copyright (c) 2020, Jordan Borean <jborean93@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import enum
import re
from ansible.module_utils.common.text.converters import to_bytes
"""
An ASN.1 serialized as a string in the OpenSSL format:
[modifier,]type[:value]
modifier:
The modifier can be 'IMPLICIT:<tag_number><tag_class>,' or 'EXPLICIT:<tag_number><tag_class>' where IMPLICIT
changes the tag of the universal value to encode and EXPLICIT prefixes its tag to the existing universal value.
The tag_number must be set while the tag_class can be 'U', 'A', 'P', or 'C" for 'Universal', 'Application',
'Private', or 'Context Specific' with C being the default.
type:
The underlying ASN.1 type of the value specified. Currently only the following have been implemented:
UTF8: The value must be a UTF-8 encoded string.
value:
The value to encode, the format of this value depends on the <type> specified.
"""
ASN1_STRING_REGEX = re.compile(
r"^((?P<tag_type>IMPLICIT|EXPLICIT):(?P<tag_number>\d+)(?P<tag_class>U|A|P|C)?,)?"
r"(?P<value_type>[\w\d]+):(?P<value>.*)"
)
class TagClass(enum.Enum):
universal = 0
application = 1
context_specific = 2
private = 3
# Universal tag numbers that can be encoded.
class TagNumber(enum.Enum):
utf8_string = 12
def _pack_octet_integer(value: int) -> bytes:
"""Packs an integer value into 1 or multiple octets."""
# NOTE: This is *NOT* the same as packing an ASN.1 INTEGER like value.
octets = bytearray()
# Continue to shift the number by 7 bits and pack into an octet until the
# value is fully packed.
while value:
octet_value = value & 0b01111111
# First round (last octet) must have the MSB set.
if len(octets):
octet_value |= 0b10000000
octets.append(octet_value)
value >>= 7
# Reverse to ensure the higher order octets are first.
octets.reverse()
return bytes(octets)
def serialize_asn1_string_as_der(value: str) -> bytes:
"""Deserializes an ASN.1 string to a DER encoded byte string."""
asn1_match = ASN1_STRING_REGEX.match(value)
if not asn1_match:
raise ValueError(
"The ASN.1 serialized string must be in the format [modifier,]type[:value]"
)
tag_type = asn1_match.group("tag_type")
tag_number = asn1_match.group("tag_number")
tag_class = asn1_match.group("tag_class") or "C"
value_type = asn1_match.group("value_type")
asn1_value = asn1_match.group("value")
if value_type != "UTF8":
raise ValueError(
f'The ASN.1 serialized string is not a known type "{value_type}", only UTF8 types are supported'
)
b_value = to_bytes(asn1_value, encoding="utf-8", errors="surrogate_or_strict")
# We should only do a universal type tag if not IMPLICITLY tagged or the tag class is not universal.
if not tag_type or (tag_type == "EXPLICIT" and tag_class != "U"):
b_value = pack_asn1(TagClass.universal, False, TagNumber.utf8_string, b_value)
if tag_type:
tag_class_enum = {
"U": TagClass.universal,
"A": TagClass.application,
"P": TagClass.private,
"C": TagClass.context_specific,
}[tag_class]
# When adding support for more types this should be looked into further. For now it works with UTF8Strings.
constructed = tag_type == "EXPLICIT" and tag_class_enum != TagClass.universal
b_value = pack_asn1(tag_class_enum, constructed, int(tag_number), b_value)
return b_value
def pack_asn1(
tag_class: TagClass, constructed: bool, tag_number: TagNumber | int, b_data: bytes
) -> bytes:
"""Pack the value into an ASN.1 data structure.
The structure for an ASN.1 element is
| Identifier Octet(s) | Length Octet(s) | Data Octet(s) |
"""
b_asn1_data = bytearray()
# Bit 8 and 7 denotes the class.
identifier_octets = tag_class.value << 6
# Bit 6 denotes whether the value is primitive or constructed.
identifier_octets |= (1 if constructed else 0) << 5
# Bits 5-1 contain the tag number, if it cannot be encoded in these 5 bits
# then they are set and another octet(s) is used to denote the tag number.
if isinstance(tag_number, TagNumber):
tag_number = tag_number.value
if tag_number < 31:
identifier_octets |= tag_number
b_asn1_data.append(identifier_octets)
else:
identifier_octets |= 31
b_asn1_data.append(identifier_octets)
b_asn1_data.extend(_pack_octet_integer(tag_number))
length = len(b_data)
# If the length can be encoded in 7 bits only 1 octet is required.
if length < 128:
b_asn1_data.append(length)
else:
# Otherwise the length must be encoded across multiple octets
length_octets = bytearray()
while length:
length_octets.append(length & 0b11111111)
length >>= 8
length_octets.reverse() # Reverse to make the higher octets first.
# The first length octet must have the MSB set alongside the number of
# octets the length was encoded in.
b_asn1_data.append(len(length_octets) | 0b10000000)
b_asn1_data.extend(length_octets)
return bytes(b_asn1_data) + b_data

View File

@@ -0,0 +1,60 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is licensed under the
# Apache 2.0 License. Modules you write using this snippet, which is embedded
# dynamically by Ansible, still belong to the author of the module, and may assign
# their own license to the complete work.
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
# This excerpt is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file at
# https://github.com/pyca/cryptography/blob/master/LICENSE for complete details.
#
# The Apache 2.0 license has been included as LICENSES/Apache-2.0.txt in this collection.
# The BSD License license has been included as LICENSES/BSD-3-Clause.txt in this collection.
# SPDX-License-Identifier: Apache-2.0 OR BSD-3-Clause
#
# Adapted from cryptography's hazmat/backends/openssl/decode_asn1.py
#
# Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk)
# Copyright (c) 2017 Fraser Tweedale (@frasertweedale)
# Relevant commits from cryptography project (https://github.com/pyca/cryptography):
# pyca/cryptography@719d536dd691e84e208534798f2eb4f82aaa2e07
# pyca/cryptography@5ab6d6a5c05572bd1c75f05baf264a2d0001894a
# pyca/cryptography@2e776e20eb60378e0af9b7439000d0e80da7c7e3
# pyca/cryptography@fb309ed24647d1be9e319b61b1f2aa8ebb87b90b
# pyca/cryptography@2917e460993c475c72d7146c50dc3bbc2414280d
# pyca/cryptography@3057f91ea9a05fb593825006d87a391286a4d828
# pyca/cryptography@d607dd7e5bc5c08854ec0c9baff70ba4a35be36f
from __future__ import annotations
# WARNING: this function no longer works with cryptography 35.0.0 and newer!
# It must **ONLY** be used in compatibility code for older
# cryptography versions!
def obj2txt(openssl_lib, openssl_ffi, obj) -> str:
# Set to 80 on the recommendation of
# https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values
#
# But OIDs longer than this occur in real life (e.g. Active
# Directory makes some very long OIDs). So we need to detect
# and properly handle the case where the default buffer is not
# big enough.
#
buf_len = 80
buf = openssl_ffi.new("char[]", buf_len)
# 'res' is the number of bytes that *would* be written if the
# buffer is large enough. If 'res' > buf_len - 1, we need to
# alloc a big-enough buffer and go again.
res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
if res > buf_len - 1: # account for terminating null byte
buf_len = res + 1
buf = openssl_ffi.new("char[]", buf_len)
res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1)
return openssl_ffi.buffer(buf, res)[:].decode()

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2019, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
from ansible_collections.community.crypto.plugins.module_utils._crypto._objects_data import (
OID_MAP,
)
OID_LOOKUP: dict[str, str] = dict()
NORMALIZE_NAMES: dict[str, str] = dict()
NORMALIZE_NAMES_SHORT: dict[str, str] = dict()
for dotted, names in OID_MAP.items():
for name in names:
if name in NORMALIZE_NAMES and OID_LOOKUP[name] != dotted:
raise AssertionError(
f'Name collision during setup: "{name}" for OIDs {dotted} and {OID_LOOKUP[name]}'
)
NORMALIZE_NAMES[name] = names[0]
NORMALIZE_NAMES_SHORT[name] = names[-1]
OID_LOOKUP[name] = dotted
for alias, original in [("userID", "userId")]:
if alias in NORMALIZE_NAMES:
raise AssertionError(
f'Name collision during adding aliases: "{alias}" (alias for "{original}") is already mapped to OID {OID_LOOKUP[alias]}'
)
NORMALIZE_NAMES[alias] = original
NORMALIZE_NAMES_SHORT[alias] = NORMALIZE_NAMES_SHORT[original]
OID_LOOKUP[alias] = OID_LOOKUP[original]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
# Copyright (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
try:
import cryptography # noqa: F401, pylint: disable=unused-import
HAS_CRYPTOGRAPHY = True
except ImportError:
# Error handled in the calling module.
HAS_CRYPTOGRAPHY = False
class OpenSSLObjectError(Exception):
pass
class OpenSSLBadPassphraseError(OpenSSLObjectError):
pass

View File

@@ -0,0 +1,191 @@
# Copyright (c) 2019, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import typing as t
from ansible_collections.community.crypto.plugins.module_utils._version import (
LooseVersion as _LooseVersion,
)
try:
import cryptography
from cryptography import x509
except ImportError:
# Error handled in the calling module.
pass
from ansible_collections.community.crypto.plugins.module_utils._crypto._obj2txt import (
obj2txt,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
HAS_CRYPTOGRAPHY,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CRYPTOGRAPHY_TIMEZONE,
cryptography_decode_name,
)
if t.TYPE_CHECKING:
import datetime
# TODO: once cryptography has a _utc variant of InvalidityDate.invalidity_date, set this
# to True and adjust get_invalidity_date() accordingly.
# (https://github.com/pyca/cryptography/issues/10818)
CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE = False
if HAS_CRYPTOGRAPHY:
CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE = _LooseVersion(
cryptography.__version__
) >= _LooseVersion("43.0.0")
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
if HAS_CRYPTOGRAPHY:
REVOCATION_REASON_MAP = {
"unspecified": x509.ReasonFlags.unspecified,
"key_compromise": x509.ReasonFlags.key_compromise,
"ca_compromise": x509.ReasonFlags.ca_compromise,
"affiliation_changed": x509.ReasonFlags.affiliation_changed,
"superseded": x509.ReasonFlags.superseded,
"cessation_of_operation": x509.ReasonFlags.cessation_of_operation,
"certificate_hold": x509.ReasonFlags.certificate_hold,
"privilege_withdrawn": x509.ReasonFlags.privilege_withdrawn,
"aa_compromise": x509.ReasonFlags.aa_compromise,
"remove_from_crl": x509.ReasonFlags.remove_from_crl,
}
REVOCATION_REASON_MAP_INVERSE = dict()
for k, v in REVOCATION_REASON_MAP.items():
REVOCATION_REASON_MAP_INVERSE[v] = k
else:
REVOCATION_REASON_MAP = dict()
REVOCATION_REASON_MAP_INVERSE = dict()
def cryptography_decode_revoked_certificate(
cert: x509.RevokedCertificate,
) -> dict[str, t.Any]:
result = {
"serial_number": cert.serial_number,
"revocation_date": get_revocation_date(cert),
"issuer": None,
"issuer_critical": False,
"reason": None,
"reason_critical": False,
"invalidity_date": None,
"invalidity_date_critical": False,
}
try:
ext_ci = cert.extensions.get_extension_for_class(x509.CertificateIssuer)
result["issuer"] = list(ext_ci.value)
result["issuer_critical"] = ext_ci.critical
except x509.ExtensionNotFound:
pass
try:
ext_cr = cert.extensions.get_extension_for_class(x509.CRLReason)
result["reason"] = ext_cr.value.reason
result["reason_critical"] = ext_cr.critical
except x509.ExtensionNotFound:
pass
try:
ext_id = cert.extensions.get_extension_for_class(x509.InvalidityDate)
result["invalidity_date"] = get_invalidity_date(ext_id.value)
result["invalidity_date_critical"] = ext_id.critical
except x509.ExtensionNotFound:
pass
return result
def cryptography_dump_revoked(
entry: dict[str, t.Any],
idn_rewrite: t.Literal["ignore", "idna", "unicode"] = "ignore",
) -> dict[str, t.Any]:
return {
"serial_number": entry["serial_number"],
"revocation_date": entry["revocation_date"].strftime(TIMESTAMP_FORMAT),
"issuer": (
[
cryptography_decode_name(issuer, idn_rewrite=idn_rewrite)
for issuer in entry["issuer"]
]
if entry["issuer"] is not None
else None
),
"issuer_critical": entry["issuer_critical"],
"reason": (
REVOCATION_REASON_MAP_INVERSE.get(entry["reason"])
if entry["reason"] is not None
else None
),
"reason_critical": entry["reason_critical"],
"invalidity_date": (
entry["invalidity_date"].strftime(TIMESTAMP_FORMAT)
if entry["invalidity_date"] is not None
else None
),
"invalidity_date_critical": entry["invalidity_date_critical"],
}
def cryptography_get_signature_algorithm_oid_from_crl(
crl: x509.CertificateRevocationList,
) -> x509.oid.ObjectIdentifier:
try:
return crl.signature_algorithm_oid
except AttributeError:
# Older cryptography versions do not have signature_algorithm_oid yet
dotted = obj2txt(
crl._backend._lib, crl._backend._ffi, crl._x509_crl.sig_alg.algorithm # type: ignore
)
return x509.oid.ObjectIdentifier(dotted)
def get_next_update(obj: x509.CertificateRevocationList) -> datetime.datetime | None:
if CRYPTOGRAPHY_TIMEZONE:
return obj.next_update_utc
return obj.next_update
def get_last_update(obj: x509.CertificateRevocationList) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE:
return obj.last_update_utc
return obj.last_update
def get_revocation_date(obj: x509.RevokedCertificate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE:
return obj.revocation_date_utc
return obj.revocation_date
def get_invalidity_date(obj: x509.InvalidityDate) -> datetime.datetime:
if CRYPTOGRAPHY_TIMEZONE_INVALIDITY_DATE:
return obj.invalidity_date_utc
return obj.invalidity_date
def set_next_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.next_update(value)
def set_last_update(
builder: x509.CertificateRevocationListBuilder, value: datetime.datetime
) -> x509.CertificateRevocationListBuilder:
return builder.last_update(value)
def set_revocation_date(
builder: x509.RevokedCertificateBuilder, value: datetime.datetime
) -> x509.RevokedCertificateBuilder:
return builder.revocation_date(value)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,158 @@
# Copyright (c) 2019, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
def binary_exp_mod(f: int, e: int, m: int) -> int:
"""Computes f^e mod m in O(log e) multiplications modulo m."""
# Compute len_e = floor(log_2(e))
len_e = -1
x = e
while x > 0:
x >>= 1
len_e += 1
# Compute f**e mod m
result = 1
for k in range(len_e, -1, -1):
result = (result * result) % m
if ((e >> k) & 1) != 0:
result = (result * f) % m
return result
def simple_gcd(a: int, b: int) -> int:
"""Compute GCD of its two inputs."""
while b != 0:
a, b = b, a % b
return a
def quick_is_not_prime(n: int) -> bool:
"""Does some quick checks to see if we can poke a hole into the primality of n.
A result of `False` does **not** mean that the number is prime; it just means
that we could not detect quickly whether it is not prime.
"""
if n <= 2:
return n < 2
# The constant in the next line is the product of all primes < 200
prime_product = 7799922041683461553249199106329813876687996789903550945093032474868511536164700810
gcd = simple_gcd(n, prime_product)
if gcd > 1:
if n < 200 and gcd == n:
# Explicitly check for all primes < 200
return n not in (
2,
3,
5,
7,
11,
13,
17,
19,
23,
29,
31,
37,
41,
43,
47,
53,
59,
61,
67,
71,
73,
79,
83,
89,
97,
101,
103,
107,
109,
113,
127,
131,
137,
139,
149,
151,
157,
163,
167,
173,
179,
181,
191,
193,
197,
199,
)
return True
# TODO: maybe do some iterations of Miller-Rabin to increase confidence
# (https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test)
return False
def count_bytes(no: int) -> int:
"""
Given an integer, compute the number of bytes necessary to store its absolute value.
"""
no = abs(no)
if no == 0:
return 0
return (no.bit_length() + 7) // 8
def count_bits(no: int) -> int:
"""
Given an integer, compute the number of bits necessary to store its absolute value.
"""
no = abs(no)
if no == 0:
return 0
return no.bit_length()
def convert_int_to_bytes(no: int, count: int | None = None) -> bytes:
"""
Convert the absolute value of an integer to a byte string in network byte order.
If ``count`` is provided, it must be sufficiently large so that the integer's
absolute value can be represented with these number of bytes. The resulting byte
string will have length exactly ``count``.
The value zero will be converted to an empty byte string if ``count`` is provided.
"""
no = abs(no)
if count is None:
count = count_bytes(no)
return no.to_bytes(count, byteorder="big")
def convert_int_to_hex(no: int, digits: int | None = None) -> str:
"""
Convert the absolute value of an integer to a string of hexadecimal digits.
If ``digits`` is provided, the string will be padded on the left with ``0``s so
that the returned value has length ``digits``. If ``digits`` is not sufficient,
the string will be longer.
"""
no = abs(no)
value = f"{no:x}"
if digits is not None and len(value) < digits:
value = "0" * (digits - len(value)) + value
return value
def convert_bytes_to_int(data: bytes) -> int:
"""
Convert a byte string to an unsigned integer in network byte order.
"""
return int.from_bytes(data, byteorder="big", signed=False)

View File

@@ -0,0 +1,417 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLBadPassphraseError,
OpenSSLObjectError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
cryptography_compare_public_keys,
get_not_valid_after,
get_not_valid_before,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.certificate_info import (
get_certificate_info,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
load_certificate,
load_certificate_privatekey,
load_certificate_request,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CertificatePrivateKeyTypes,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
from cryptography import x509
except ImportError:
pass
class CertificateError(OpenSSLObjectError):
pass
class CertificateBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.force: bool = module.params["force"]
self.ignore_timestamps: bool = module.params["ignore_timestamps"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.csr_path: str | None = module.params["csr_path"]
csr_content = module.params["csr_content"]
if csr_content is not None:
self.csr_content: bytes | None = csr_content.encode("utf-8")
else:
self.csr_content = None
# The following are default values which make sure check() works as
# before if providers do not explicitly change these properties.
self.create_subject_key_identifier: str = "never_create"
self.create_authority_key_identifier: bool = False
self.privatekey: CertificatePrivateKeyTypes | None = None
self.csr: x509.CertificateSigningRequest | None = None
self.cert: x509.Certificate | None = None
self.existing_certificate: x509.Certificate | None = None
self.existing_certificate_bytes: bytes | None = None
self.check_csr_subject: bool = True
self.check_csr_extensions: bool = True
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
try:
result = get_certificate_info(
self.module, data, prefer_one_fingerprint=True
)
result["can_parse_certificate"] = True
return result
except Exception:
return dict(can_parse_certificate=False)
@abc.abstractmethod
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
pass
@abc.abstractmethod
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
pass
def set_existing(self, certificate_bytes: bytes | None) -> None:
"""Set existing certificate bytes. None indicates that the key does not exist."""
self.existing_certificate_bytes = certificate_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_certificate_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing certificate is/has been there."""
return self.existing_certificate_bytes is not None
def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey."""
if self.privatekey is not None:
return
if self.privatekey_path is None and self.privatekey_content is None:
return
try:
self.privatekey = load_certificate_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
raise CertificateError(exc)
def _ensure_csr_loaded(self) -> None:
"""Load the CSR into self.csr."""
if self.csr is not None:
return
if self.csr_path is None and self.csr_content is None:
return
self.csr = load_certificate_request(
path=self.csr_path,
content=self.csr_content,
)
def _ensure_existing_certificate_loaded(self) -> None:
"""Load the existing certificate into self.existing_certificate."""
if self.existing_certificate is not None:
return
if self.existing_certificate_bytes is None:
return
self.existing_certificate = load_certificate(
path=None,
content=self.existing_certificate_bytes,
)
def _check_privatekey(self) -> bool:
"""Check whether provided parameters match, assuming self.existing_certificate and self.privatekey have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.privatekey is None:
raise AssertionError(
"Contract violation: privatekey has not been populated"
)
return cryptography_compare_public_keys(
self.existing_certificate.public_key(), self.privatekey.public_key()
)
def _check_csr(self) -> bool:
"""Check whether provided parameters match, assuming self.existing_certificate and self.csr have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
# Verify that CSR is signed by certificate's private key
if not self.csr.is_signature_valid:
return False
if not cryptography_compare_public_keys(
self.csr.public_key(), self.existing_certificate.public_key()
):
return False
# Check subject
if (
self.check_csr_subject
and self.csr.subject != self.existing_certificate.subject
):
return False
# Check extensions
if not self.check_csr_extensions:
return True
cert_exts = list(self.existing_certificate.extensions)
csr_exts = list(self.csr.extensions)
if self.create_subject_key_identifier != "never_create":
# Filter out SubjectKeyIdentifier extension before comparison
cert_exts = list(
filter(
lambda x: not isinstance(x.value, x509.SubjectKeyIdentifier),
cert_exts,
)
)
csr_exts = list(
filter(
lambda x: not isinstance(x.value, x509.SubjectKeyIdentifier),
csr_exts,
)
)
if self.create_authority_key_identifier:
# Filter out AuthorityKeyIdentifier extension before comparison
cert_exts = list(
filter(
lambda x: not isinstance(x.value, x509.AuthorityKeyIdentifier),
cert_exts,
)
)
csr_exts = list(
filter(
lambda x: not isinstance(x.value, x509.AuthorityKeyIdentifier),
csr_exts,
)
)
if len(cert_exts) != len(csr_exts):
return False
for cert_ext in cert_exts:
try:
csr_ext = self.csr.extensions.get_extension_for_oid(cert_ext.oid)
if cert_ext != csr_ext:
return False
except cryptography.x509.ExtensionNotFound:
return False
return True
def _check_subject_key_identifier(self) -> bool:
"""Check whether Subject Key Identifier matches, assuming self.existing_certificate and self.csr have been populated."""
if self.existing_certificate is None:
raise AssertionError(
"Contract violation: existing_certificate has not been populated"
)
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
# Get hold of certificate's SKI
try:
ext = self.existing_certificate.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
except cryptography.x509.ExtensionNotFound:
return False
# Get hold of CSR's SKI for 'create_if_not_provided'
csr_ext = None
if self.create_subject_key_identifier == "create_if_not_provided":
try:
csr_ext = self.csr.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
except cryptography.x509.ExtensionNotFound:
pass
if csr_ext is None:
# If CSR had no SKI, or we chose to ignore it ('always_create'), compare with created SKI
if (
ext.value.digest
!= x509.SubjectKeyIdentifier.from_public_key(
self.existing_certificate.public_key()
).digest
):
return False
else:
# If CSR had SKI and we did not ignore it ('create_if_not_provided'), compare SKIs
if ext.value.digest != csr_ext.value.digest:
return False
return True
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
"""Check whether a regeneration is necessary."""
if self.force or self.existing_certificate_bytes is None:
return True
try:
self._ensure_existing_certificate_loaded()
except Exception:
return True
assert self.existing_certificate is not None
# Check whether private key matches
self._ensure_private_key_loaded()
if self.privatekey is not None and not self._check_privatekey():
return True
# Check whether CSR matches
self._ensure_csr_loaded()
if self.csr is not None and not self._check_csr():
return True
# Check SubjectKeyIdentifier
if (
self.create_subject_key_identifier != "never_create"
and not self._check_subject_key_identifier()
):
return True
# Check not before
if not_before is not None and not self.ignore_timestamps:
if get_not_valid_before(self.existing_certificate) != not_before:
return True
# Check not after
if not_after is not None and not self.ignore_timestamps:
if get_not_valid_after(self.existing_certificate) != not_after:
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"csr": self.csr_path,
}
# Get hold of certificate bytes
certificate_bytes = self.existing_certificate_bytes
if self.cert is not None:
certificate_bytes = self.get_certificate_data()
self.diff_after = self._get_info(certificate_bytes)
if include_certificate:
# Store result
result["certificate"] = (
certificate_bytes.decode("utf-8") if certificate_bytes else None
)
result["diff"] = {
"before": self.diff_before,
"after": self.diff_after,
}
return result
class CertificateProvider(metaclass=abc.ABCMeta):
@abc.abstractmethod
def validate_module_args(self, module: AnsibleModule) -> None:
"""Check module arguments"""
@abc.abstractmethod
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
"""Whether the provider needs to create a version 2 certificate."""
@abc.abstractmethod
def create_backend(self, module: AnsibleModule) -> CertificateBackend:
"""Create an implementation for a backend.
Return value must be instance of CertificateBackend.
"""
def select_backend(
module: AnsibleModule, provider: CertificateProvider
) -> CertificateBackend:
provider.validate_module_args(module)
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
if provider.needs_version_two_certs(module):
# TODO: remove
module.fail_json(
msg="The cryptography backend does not support v2 certificates"
)
return provider.create_backend(module)
def get_certificate_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
provider=dict(
type="str", choices=[]
), # choices will be filled by add_XXX_provider_to_argument_spec() in certificate_xxx.py
force=dict(
type="bool",
default=False,
),
csr_path=dict(type="path"),
csr_content=dict(type="str"),
ignore_timestamps=dict(type="bool", default=True),
select_crypto_backend=dict(
type="str", default="auto", choices=["auto", "cryptography"]
),
# General properties of a certificate
privatekey_path=dict(type="path"),
privatekey_content=dict(type="str", no_log=True),
privatekey_passphrase=dict(type="str", no_log=True),
),
mutually_exclusive=[
["csr_path", "csr_content"],
["privatekey_path", "privatekey_content"],
],
)

View File

@@ -0,0 +1,140 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
import tempfile
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.certificate import (
CertificateBackend,
CertificateError,
CertificateProvider,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
class AcmeCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(AcmeCertificateBackend, self).__init__(module)
self.accountkey_path: str = module.params["acme_accountkey_path"]
self.challenge_path: str = module.params["acme_challenge_path"]
self.use_chain: bool = module.params["acme_chain"]
self.acme_directory: str = module.params["acme_directory"]
self.cert_bytes: bytes | None = None
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if not os.path.exists(self.accountkey_path):
raise CertificateError(
f"The account key {self.accountkey_path} does not exist"
)
if not os.path.exists(self.challenge_path):
raise CertificateError(
f"The challenge path {self.challenge_path} does not exist"
)
self.acme_tiny_path = self.module.get_bin_path("acme-tiny", required=True)
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
command = [self.acme_tiny_path]
if self.use_chain:
command.append("--chain")
command.extend(["--account-key", self.accountkey_path])
if self.csr_content is not None:
# We need to temporarily write the CSR to disk
fd, tmpsrc = tempfile.mkstemp()
self.module.add_cleanup_file(tmpsrc) # Ansible will delete the file on exit
f = os.fdopen(fd, "wb")
try:
f.write(self.csr_content)
except Exception as err:
try:
f.close()
except Exception:
pass
self.module.fail_json(
msg=f"failed to create temporary CSR file: {err}",
exception=traceback.format_exc(),
)
f.close()
command.extend(["--csr", tmpsrc])
else:
command.extend(["--csr", self.csr_path])
command.extend(["--acme-dir", self.challenge_path])
command.extend(["--directory-url", self.acme_directory])
try:
self.cert_bytes = to_bytes(
self.module.run_command(command, check_rc=True)[1]
)
except OSError as exc:
raise CertificateError(exc)
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
if self.cert_bytes is None:
raise AssertionError("Contract violation: cert_bytes is None")
return self.cert_bytes
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(AcmeCertificateBackend, self).dump(include_certificate)
result["accountkey"] = self.accountkey_path
return result
class AcmeCertificateProvider(CertificateProvider):
def validate_module_args(self, module: AnsibleModule) -> None:
if module.params["acme_accountkey_path"] is None:
module.fail_json(
msg="The acme_accountkey_path option must be specified for the acme provider."
)
if module.params["acme_challenge_path"] is None:
module.fail_json(
msg="The acme_challenge_path option must be specified for the acme provider."
)
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return False
def create_backend(self, module: AnsibleModule) -> AcmeCertificateBackend:
return AcmeCertificateBackend(module)
def add_acme_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("acme")
argument_spec.argument_spec.update(
dict(
acme_accountkey_path=dict(type="path"),
acme_challenge_path=dict(type="path"),
acme_chain=dict(type="bool", default=False),
acme_directory=dict(
type="str", default="https://acme-v02.api.letsencrypt.org/directory"
),
)
)

View File

@@ -0,0 +1,283 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import datetime
import os
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CRYPTOGRAPHY_TIMEZONE,
get_not_valid_after,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.certificate import (
CertificateBackend,
CertificateError,
CertificateProvider,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
load_certificate,
)
from ansible_collections.community.crypto.plugins.module_utils._ecs.api import (
ECSClient,
RestOperationException,
SessionConfigurationException,
)
from ansible_collections.community.crypto.plugins.module_utils._time import (
get_now_datetime,
get_relative_time_option,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
try:
from cryptography.x509.oid import NameOID
except ImportError:
pass
class EntrustCertificateBackend(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(EntrustCertificateBackend, self).__init__(module)
self.trackingId = None
self.notAfter = get_relative_time_option(
module.params["entrust_not_after"],
"entrust_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for entrust provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
self._ensure_csr_loaded()
if self.csr is None:
raise CertificateError("CSR not provided")
# ECS API defaults to using the validated organization tied to the account.
# We want to always force behavior of trying to use the organization provided in the CSR.
# To that end we need to parse out the organization from the CSR.
self.csr_org = None
csr_subject_orgs = self.csr.subject.get_attributes_for_oid(
NameOID.ORGANIZATION_NAME
)
if len(csr_subject_orgs) == 1:
self.csr_org = csr_subject_orgs[0].value
elif len(csr_subject_orgs) > 1:
self.module.fail_json(
msg=(
"Entrust provider does not currently support multiple validated organizations. Multiple organizations found in "
f"Subject DN: '{self.csr.subject}'. "
)
)
# If no organization in the CSR, explicitly tell ECS that it should be blank in issued cert, not defaulted to
# organization tied to the account.
if self.csr_org is None:
self.csr_org = ""
try:
self.ecs_client = ECSClient(
entrust_api_user=self.module.params["entrust_api_user"],
entrust_api_key=self.module.params["entrust_api_key"],
entrust_api_cert=self.module.params["entrust_api_client_cert_path"],
entrust_api_cert_key=self.module.params[
"entrust_api_client_cert_key_path"
],
entrust_api_specification_path=self.module.params[
"entrust_api_specification_path"
],
)
except SessionConfigurationException as e:
module.fail_json(msg=f"Failed to initialize Entrust Provider: {e}")
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
body = {}
# Read the CSR that was generated for us
if self.csr_content is not None:
# csr_content contains bytes
body["csr"] = to_native(self.csr_content)
else:
assert self.csr_path is not None
with open(self.csr_path, "r") as csr_file:
body["csr"] = csr_file.read()
body["certType"] = self.module.params["entrust_cert_type"]
# Handle expiration (30 days if not specified)
expiry = self.notAfter
if not expiry:
gmt_now = get_now_datetime(with_timezone=CRYPTOGRAPHY_TIMEZONE)
expiry = gmt_now + datetime.timedelta(days=365)
expiry_iso3339 = expiry.strftime("%Y-%m-%dT%H:%M:%S.00Z")
body["certExpiryDate"] = expiry_iso3339
body["org"] = self.csr_org
body["tracking"] = {
"requesterName": self.module.params["entrust_requester_name"],
"requesterEmail": self.module.params["entrust_requester_email"],
"requesterPhone": self.module.params["entrust_requester_phone"],
}
try:
result = self.ecs_client.NewCertRequest(Body=body)
self.trackingId = result.get("trackingId")
except RestOperationException as e:
self.module.fail_json(
msg=f"Failed to request new certificate from Entrust Certificate Services (ECS): {e.message}"
)
self.cert_bytes = to_bytes(result.get("endEntityCert"))
self.cert = load_certificate(
path=None,
content=self.cert_bytes,
)
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
return self.cert_bytes
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
parent_check = super(EntrustCertificateBackend, self).needs_regeneration()
try:
cert_details = self._get_cert_details()
except RestOperationException as e:
self.module.fail_json(
msg=f"Failed to get status of existing certificate from Entrust Certificate Services (ECS): {e.message}."
)
# Always issue a new certificate if the certificate is expired, suspended or revoked
status = cert_details.get("status", False)
if status == "EXPIRED" or status == "SUSPENDED" or status == "REVOKED":
return True
# If the requested cert type was specified and it is for a different certificate type than the initial certificate, a new one is needed
if (
self.module.params["entrust_cert_type"]
and cert_details.get("certType")
and self.module.params["entrust_cert_type"] != cert_details.get("certType")
):
return True
return parent_check
def _get_cert_details(self) -> dict[str, t.Any]:
cert_details: dict[str, t.Any] = {}
try:
self._ensure_existing_certificate_loaded()
except Exception:
return cert_details
if self.existing_certificate:
serial_number = f"{self.existing_certificate.serial_number:X}"
expiry = get_not_valid_after(self.existing_certificate)
# get some information about the expiry of this certificate
expiry_iso3339 = expiry.strftime("%Y-%m-%dT%H:%M:%S.00Z")
cert_details["expiresAfter"] = expiry_iso3339
# If a trackingId is not already defined (from the result of a generate)
# use the serial number to identify the tracking Id
if self.trackingId is None and serial_number is not None:
cert_results = self.ecs_client.GetCertificates(
serialNumber=serial_number
).get("certificates", {})
# Finding 0 or more than 1 result is a very unlikely use case, it simply means we cannot perform additional checks
# on the 'state' as returned by Entrust Certificate Services (ECS). The general certificate validity is
# still checked as it is in the rest of the module.
if len(cert_results) == 1:
self.trackingId = cert_results[0].get("trackingId")
if self.trackingId is not None:
cert_details.update(
self.ecs_client.GetCertificate(trackingId=self.trackingId)
)
return cert_details
class EntrustCertificateProvider(CertificateProvider):
def validate_module_args(self, module: AnsibleModule) -> None:
pass
def needs_version_two_certs(self, module: AnsibleModule) -> t.Literal[False]:
return False
def create_backend(self, module: AnsibleModule) -> EntrustCertificateBackend:
return EntrustCertificateBackend(module)
def add_entrust_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("entrust")
argument_spec.argument_spec.update(
dict(
entrust_cert_type=dict(
type="str",
default="STANDARD_SSL",
choices=[
"STANDARD_SSL",
"ADVANTAGE_SSL",
"UC_SSL",
"EV_SSL",
"WILDCARD_SSL",
"PRIVATE_SSL",
"PD_SSL",
"CDS_ENT_LITE",
"CDS_ENT_PRO",
"SMIME_ENT",
],
),
entrust_requester_email=dict(type="str"),
entrust_requester_name=dict(type="str"),
entrust_requester_phone=dict(type="str"),
entrust_api_user=dict(type="str"),
entrust_api_key=dict(type="str", no_log=True),
entrust_api_client_cert_path=dict(type="path"),
entrust_api_client_cert_key_path=dict(type="path", no_log=True),
entrust_api_specification_path=dict(
type="path",
default="https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml",
),
entrust_not_after=dict(type="str", default="+365d"),
)
)
argument_spec.required_if.append(
(
"provider",
"entrust",
[
"entrust_requester_email",
"entrust_requester_name",
"entrust_requester_phone",
"entrust_api_user",
"entrust_api_key",
"entrust_api_client_cert_path",
"entrust_api_client_cert_key_path",
],
)
)

View File

@@ -0,0 +1,480 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CRYPTOGRAPHY_TIMEZONE,
cryptography_decode_name,
cryptography_get_extensions_from_cert,
cryptography_oid_to_name,
get_not_valid_after,
get_not_valid_before,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.publickey_info import (
get_publickey_info,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
get_fingerprint_of_bytes,
load_certificate,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
from ansible_collections.community.crypto.plugins.module_utils._time import (
get_now_datetime,
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.plugin_utils._action_module import (
AnsibleActionModule,
)
from ansible_collections.community.crypto.plugins.plugin_utils._filter_module import (
FilterModuleMock,
)
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives import serialization
except ImportError:
pass
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CertificateInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
# content must be a bytes string
self.module = module
self.content = content
@abc.abstractmethod
def _get_der_bytes(self) -> bytes:
pass
@abc.abstractmethod
def _get_signature_algorithm(self) -> str:
pass
@abc.abstractmethod
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_issuer_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_version(self) -> int | str:
pass
@abc.abstractmethod
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def get_not_before(self) -> datetime.datetime:
pass
@abc.abstractmethod
def get_not_after(self) -> datetime.datetime:
pass
@abc.abstractmethod
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self) -> PublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_serial_number(self) -> int:
pass
@abc.abstractmethod
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _get_ocsp_uri(self) -> str | None:
pass
@abc.abstractmethod
def _get_issuer_uri(self) -> str | None:
pass
def get_info(
self, prefer_one_fingerprint: bool = False, der_support_enabled: bool = False
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.cert = load_certificate(
None,
content=self.content,
der_support_enabled=der_support_enabled,
)
result["signature_algorithm"] = self._get_signature_algorithm()
subject = self._get_subject_ordered()
issuer = self._get_issuer_ordered()
result["subject"] = dict()
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["issuer"] = dict()
for k, v in issuer:
result["issuer"][k] = v
result["issuer_ordered"] = issuer
result["version"] = self._get_version()
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
not_before = self.get_not_before()
not_after = self.get_not_after()
result["not_before"] = not_before.strftime(TIMESTAMP_FORMAT)
result["not_after"] = not_after.strftime(TIMESTAMP_FORMAT)
result["expired"] = not_after < get_now_datetime(
with_timezone=CRYPTOGRAPHY_TIMEZONE
)
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
result["fingerprints"] = get_fingerprint_of_bytes(
self._get_der_bytes(), prefer_one=prefer_one_fingerprint
)
ski_bytes = self._get_subject_key_identifier()
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
else:
ski = None
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
else:
aki = None
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["serial_number"] = self._get_serial_number()
result["extensions_by_oid"] = self._get_all_extensions()
result["ocsp_uri"] = self._get_ocsp_uri()
result["issuer_uri"] = self._get_issuer_uri()
return result
class CertificateInfoRetrievalCryptography(CertificateInfoRetrieval):
"""Validate the supplied cert, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes) -> None:
super(CertificateInfoRetrievalCryptography, self).__init__(module, content)
self.name_encoding = module.params.get("name_encoding", "ignore")
def _get_der_bytes(self) -> bytes:
return self.cert.public_bytes(serialization.Encoding.DER)
def _get_signature_algorithm(self) -> str:
return cryptography_oid_to_name(self.cert.signature_algorithm_oid)
def _get_subject_ordered(self) -> list[list[str]]:
result: list[list[str]] = []
for attribute in self.cert.subject:
result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result
def _get_issuer_ordered(self) -> list[list[str]]:
result = []
for attribute in self.cert.issuer:
result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result
def _get_version(self) -> int | str:
if self.cert.version == x509.Version.v1:
return 1
if self.cert.version == x509.Version.v3:
return 3
return "unknown"
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try:
current_key_ext = self.cert.extensions.get_extension_for_class(
x509.KeyUsage
)
current_key_usage = current_key_ext.value
key_usage = dict(
digital_signature=current_key_usage.digital_signature,
content_commitment=current_key_usage.content_commitment,
key_encipherment=current_key_usage.key_encipherment,
data_encipherment=current_key_usage.data_encipherment,
key_agreement=current_key_usage.key_agreement,
key_cert_sign=current_key_usage.key_cert_sign,
crl_sign=current_key_usage.crl_sign,
encipher_only=False,
decipher_only=False,
)
if key_usage["key_agreement"]:
key_usage.update(
dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only,
)
)
key_usage_names = dict(
digital_signature="Digital Signature",
content_commitment="Non Repudiation",
key_encipherment="Key Encipherment",
data_encipherment="Data Encipherment",
key_agreement="Key Agreement",
key_cert_sign="Certificate Sign",
crl_sign="CRL Sign",
encipher_only="Encipher Only",
decipher_only="Decipher Only",
)
return (
sorted(
[
key_usage_names[name]
for name, value in key_usage.items()
if value
]
),
current_key_ext.critical,
)
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.ExtendedKeyUsage
)
return (
sorted(
[cryptography_oid_to_name(eku) for eku in ext_keyusage_ext.value]
),
ext_keyusage_ext.critical,
)
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.cert.extensions.get_extension_for_class(
x509.BasicConstraints
)
result = []
result.append(f"CA:{'TRUE' if ext_keyusage_ext.value.ca else 'FALSE'}")
if ext_keyusage_ext.value.path_length is not None:
result.append(f"pathlen:{ext_keyusage_ext.value.path_length}")
return sorted(result), ext_keyusage_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
try:
tlsfeature_ext = self.cert.extensions.get_extension_for_class(
x509.TLSFeature
)
value = (
cryptography.x509.TLSFeatureType.status_request in tlsfeature_ext.value
)
return value, tlsfeature_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
try:
san_ext = self.cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
result = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in san_ext.value
]
return result, san_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def get_not_before(self) -> datetime.datetime:
return get_not_valid_before(self.cert)
def get_not_after(self) -> datetime.datetime:
return get_not_valid_after(self.cert)
def _get_public_key_pem(self) -> bytes:
return self.cert.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_public_key_object(self) -> PublicKeyTypes:
return self.cert.public_key()
def _get_subject_key_identifier(self) -> bytes | None:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
return ext.value.digest
except cryptography.x509.ExtensionNotFound:
return None
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
issuer = None
if ext.value.authority_cert_issuer is not None:
issuer = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in ext.value.authority_cert_issuer
]
return (
ext.value.key_identifier,
issuer,
ext.value.authority_cert_serial_number,
)
except cryptography.x509.ExtensionNotFound:
return None, None, None
def _get_serial_number(self) -> int:
return self.cert.serial_number
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
return cryptography_get_extensions_from_cert(self.cert)
def _get_ocsp_uri(self) -> str | None:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess
)
for desc in ext.value:
if desc.access_method == x509.oid.AuthorityInformationAccessOID.OCSP:
if isinstance(desc.access_location, x509.UniformResourceIdentifier):
return desc.access_location.value
except x509.ExtensionNotFound:
pass
return None
def _get_issuer_uri(self) -> str | None:
try:
ext = self.cert.extensions.get_extension_for_class(
x509.AuthorityInformationAccess
)
for desc in ext.value:
if (
desc.access_method
== x509.oid.AuthorityInformationAccessOID.CA_ISSUERS
):
if isinstance(desc.access_location, x509.UniformResourceIdentifier):
return desc.access_location.value
except x509.ExtensionNotFound:
pass
return None
def get_certificate_info(
module: GeneralAnsibleModule, content: bytes, prefer_one_fingerprint: bool = False
) -> dict[str, t.Any]:
info = CertificateInfoRetrievalCryptography(module, content)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes
) -> CertificateInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateInfoRetrievalCryptography(module, content)

View File

@@ -0,0 +1,371 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
import typing as t
from random import randrange
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLBadPassphraseError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CRYPTOGRAPHY_TIMEZONE,
cryptography_compare_public_keys,
cryptography_key_needs_digest_for_signing,
cryptography_verify_certificate_signature,
get_not_valid_after,
get_not_valid_before,
is_potential_certificate_issuer_public_key,
set_not_valid_after,
set_not_valid_before,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.certificate import (
CertificateBackend,
CertificateError,
CertificateProvider,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
load_certificate,
load_certificate_issuer_privatekey,
select_message_digest,
)
from ansible_collections.community.crypto.plugins.module_utils._time import (
get_relative_time_option,
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives.serialization import Encoding
except ImportError:
pass
class OwnCACertificateBackendCryptography(CertificateBackend):
def __init__(self, module: AnsibleModule) -> None:
super(OwnCACertificateBackendCryptography, self).__init__(module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["ownca_create_subject_key_identifier"]
self.create_authority_key_identifier: bool = module.params[
"ownca_create_authority_key_identifier"
]
self.notBefore = get_relative_time_option(
module.params["ownca_not_before"],
"ownca_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["ownca_not_after"],
"ownca_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["ownca_digest"])
self.version: int = module.params["ownca_version"]
self.serial_number = x509.random_serial_number()
self.ca_cert_path: str | None = module.params["ownca_path"]
ca_cert_content: str | None = module.params["ownca_content"]
if ca_cert_content is not None:
self.ca_cert_content: bytes | None = ca_cert_content.encode("utf-8")
else:
self.ca_cert_content = None
self.ca_privatekey_path: str | None = module.params["ownca_privatekey_path"]
ca_privatekey_content: str | None = module.params["ownca_privatekey_content"]
if ca_privatekey_content is not None:
self.ca_privatekey_content: bytes | None = ca_privatekey_content.encode(
"utf-8"
)
else:
self.ca_privatekey_content = None
self.ca_privatekey_passphrase: str | None = module.params[
"ownca_privatekey_passphrase"
]
if self.csr_content is None:
if self.csr_path is None:
raise CertificateError(
"csr_path or csr_content is required for ownca provider"
)
if not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if self.ca_cert_path is not None and not os.path.exists(self.ca_cert_path):
raise CertificateError(
f"The CA certificate file {self.ca_cert_path} does not exist"
)
if self.ca_privatekey_path is not None and not os.path.exists(
self.ca_privatekey_path
):
raise CertificateError(
f"The CA private key file {self.ca_privatekey_path} does not exist"
)
self._ensure_csr_loaded()
self.ca_cert = load_certificate(
path=self.ca_cert_path,
content=self.ca_cert_content,
)
if not is_potential_certificate_issuer_public_key(self.ca_cert.public_key()):
raise CertificateError(
"CA certificate's public key cannot be used to sign certificates"
)
try:
self.ca_private_key = load_certificate_issuer_privatekey(
path=self.ca_privatekey_path,
content=self.ca_privatekey_content,
passphrase=self.ca_privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
module.fail_json(msg=str(exc))
if not cryptography_compare_public_keys(
self.ca_cert.public_key(), self.ca_private_key.public_key()
):
raise CertificateError(
"The CA private key does not belong to the CA certificate"
)
if cryptography_key_needs_digest_for_signing(self.ca_private_key):
if self.digest is None:
raise CertificateError(
f"The digest {module.params['ownca_digest']} is not supported with the cryptography backend"
)
else:
self.digest = None
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(self.csr.subject)
cert_builder = cert_builder.issuer_name(self.ca_cert.subject)
cert_builder = cert_builder.serial_number(self.serial_number)
cert_builder = set_not_valid_before(cert_builder, self.notBefore)
cert_builder = set_not_valid_after(cert_builder, self.notAfter)
cert_builder = cert_builder.public_key(self.csr.public_key())
has_ski = False
for extension in self.csr.extensions:
if isinstance(extension.value, x509.SubjectKeyIdentifier):
if self.create_subject_key_identifier == "always_create":
continue
has_ski = True
if self.create_authority_key_identifier and isinstance(
extension.value, x509.AuthorityKeyIdentifier
):
continue
cert_builder = cert_builder.add_extension(
extension.value, critical=extension.critical
)
if not has_ski and self.create_subject_key_identifier != "never_create":
cert_builder = cert_builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(self.csr.public_key()),
critical=False,
)
if self.create_authority_key_identifier:
try:
ext = self.ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
cert_builder = cert_builder.add_extension(
(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext.value
)
),
critical=False,
)
except cryptography.x509.ExtensionNotFound:
public_key = self.ca_cert.public_key()
assert is_potential_certificate_issuer_public_key(public_key)
cert_builder = cert_builder.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
critical=False,
)
certificate = cert_builder.sign(
private_key=self.ca_private_key,
algorithm=self.digest,
)
self.cert = certificate
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
if self.cert is None:
raise AssertionError("Contract violation: cert has not been populated")
return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
if super(OwnCACertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter
):
return True
self._ensure_existing_certificate_loaded()
assert self.existing_certificate is not None
# Check whether certificate is signed by CA certificate
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.ca_cert.public_key()
):
return True
# Check subject
if self.ca_cert.subject != self.existing_certificate.issuer:
return True
# Check AuthorityKeyIdentifier
if self.create_authority_key_identifier:
try:
ext_ski = self.ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
)
expected_ext = (
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ext_ski.value
)
)
except cryptography.x509.ExtensionNotFound:
public_key = self.ca_cert.public_key()
assert is_potential_certificate_issuer_public_key(public_key)
expected_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
public_key
)
try:
ext_aki = self.existing_certificate.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
if ext_aki.value != expected_ext:
return True
except cryptography.x509.ExtensionNotFound:
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(OwnCACertificateBackendCryptography, self).dump(
include_certificate
)
result.update(
{
"ca_cert": self.ca_cert_path,
"ca_privatekey": self.ca_privatekey_path,
}
)
if self.module.check_mode:
result.update(
{
"notBefore": self.notBefore.strftime("%Y%m%d%H%M%SZ"),
"notAfter": self.notAfter.strftime("%Y%m%d%H%M%SZ"),
"serial_number": self.serial_number,
}
)
else:
if self.cert is None:
self.cert = self.existing_certificate
assert self.cert is not None
result.update(
{
"notBefore": get_not_valid_before(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"notAfter": get_not_valid_after(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"serial_number": self.cert.serial_number,
}
)
return result
def generate_serial_number() -> int:
"""Generate a serial number for a certificate"""
while True:
result = randrange(0, 1 << 160)
if result >= 1000:
return result
class OwnCACertificateProvider(CertificateProvider):
def validate_module_args(self, module: AnsibleModule) -> None:
if (
module.params["ownca_path"] is None
and module.params["ownca_content"] is None
):
module.fail_json(
msg="One of ownca_path and ownca_content must be specified for the ownca provider."
)
if (
module.params["ownca_privatekey_path"] is None
and module.params["ownca_privatekey_content"] is None
):
module.fail_json(
msg="One of ownca_privatekey_path and ownca_privatekey_content must be specified for the ownca provider."
)
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return module.params["ownca_version"] == 2
def create_backend(
self, module: AnsibleModule
) -> OwnCACertificateBackendCryptography:
return OwnCACertificateBackendCryptography(module)
def add_ownca_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("ownca")
argument_spec.argument_spec.update(
dict(
ownca_path=dict(type="path"),
ownca_content=dict(type="str"),
ownca_privatekey_path=dict(type="path"),
ownca_privatekey_content=dict(type="str", no_log=True),
ownca_privatekey_passphrase=dict(type="str", no_log=True),
ownca_digest=dict(type="str", default="sha256"),
ownca_version=dict(type="int", default=3),
ownca_not_before=dict(type="str", default="+0s"),
ownca_not_after=dict(type="str", default="+3650d"),
ownca_create_subject_key_identifier=dict(
type="str",
default="create_if_not_provided",
choices=["create_if_not_provided", "always_create", "never_create"],
),
ownca_create_authority_key_identifier=dict(type="bool", default=True),
)
)
argument_spec.mutually_exclusive.extend(
[
["ownca_path", "ownca_content"],
["ownca_privatekey_path", "ownca_privatekey_content"],
]
)

View File

@@ -0,0 +1,263 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
import typing as t
from random import randrange
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CRYPTOGRAPHY_TIMEZONE,
cryptography_key_needs_digest_for_signing,
cryptography_verify_certificate_signature,
get_not_valid_after,
get_not_valid_before,
is_potential_certificate_issuer_private_key,
set_not_valid_after,
set_not_valid_before,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.certificate import (
CertificateBackend,
CertificateError,
CertificateProvider,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
select_message_digest,
)
from ansible_collections.community.crypto.plugins.module_utils._time import (
get_relative_time_option,
)
if t.TYPE_CHECKING:
import datetime
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
)
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives.serialization import Encoding
except ImportError:
pass
class SelfSignedCertificateBackendCryptography(CertificateBackend):
privatekey: CertificateIssuerPrivateKeyTypes
def __init__(self, module: AnsibleModule) -> None:
super(SelfSignedCertificateBackendCryptography, self).__init__(module)
self.create_subject_key_identifier: t.Literal[
"create_if_not_provided", "always_create", "never_create"
] = module.params["selfsigned_create_subject_key_identifier"]
self.notBefore = get_relative_time_option(
module.params["selfsigned_not_before"],
"selfsigned_not_before",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.notAfter = get_relative_time_option(
module.params["selfsigned_not_after"],
"selfsigned_not_after",
with_timezone=CRYPTOGRAPHY_TIMEZONE,
)
self.digest = select_message_digest(module.params["selfsigned_digest"])
self.version: int = module.params["selfsigned_version"]
self.serial_number = x509.random_serial_number()
if self.csr_path is not None and not os.path.exists(self.csr_path):
raise CertificateError(
f"The certificate signing request file {self.csr_path} does not exist"
)
if self.privatekey_path is not None and not os.path.exists(
self.privatekey_path
):
raise CertificateError(
f"The private key file {self.privatekey_path} does not exist"
)
self._module = module
self._ensure_private_key_loaded()
if self.privatekey is None:
raise CertificateError("Private key has not been provided")
if not is_potential_certificate_issuer_private_key(self.privatekey):
raise CertificateError("Private key cannot be used to sign certificates")
if cryptography_key_needs_digest_for_signing(self.privatekey):
if self.digest is None:
raise CertificateError(
f"The digest {module.params['selfsigned_digest']} is not supported with the cryptography backend"
)
else:
self.digest = None
self._ensure_csr_loaded()
if self.csr is None:
# Create empty CSR on the fly
csr = cryptography.x509.CertificateSigningRequestBuilder()
csr = csr.subject_name(cryptography.x509.Name([]))
self.csr = csr.sign(self.privatekey, self.digest)
def generate_certificate(self) -> None:
"""(Re-)Generate certificate."""
if self.csr is None:
raise AssertionError("Contract violation: csr has not been populated")
if self.privatekey is None:
raise AssertionError(
"Contract violation: privatekey has not been populated"
)
try:
cert_builder = x509.CertificateBuilder()
cert_builder = cert_builder.subject_name(self.csr.subject)
cert_builder = cert_builder.issuer_name(self.csr.subject)
cert_builder = cert_builder.serial_number(self.serial_number)
cert_builder = set_not_valid_before(cert_builder, self.notBefore)
cert_builder = set_not_valid_after(cert_builder, self.notAfter)
cert_builder = cert_builder.public_key(self.privatekey.public_key())
has_ski = False
for extension in self.csr.extensions:
if isinstance(extension.value, x509.SubjectKeyIdentifier):
if self.create_subject_key_identifier == "always_create":
continue
has_ski = True
cert_builder = cert_builder.add_extension(
extension.value, critical=extension.critical
)
if not has_ski and self.create_subject_key_identifier != "never_create":
cert_builder = cert_builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
),
critical=False,
)
except ValueError as e:
raise CertificateError(str(e))
certificate = cert_builder.sign(
private_key=self.privatekey,
algorithm=self.digest,
)
self.cert = certificate
def get_certificate_data(self) -> bytes:
"""Return bytes for self.cert."""
if self.cert is None:
raise AssertionError("Contract violation: cert has not been populated")
return self.cert.public_bytes(Encoding.PEM)
def needs_regeneration(
self,
not_before: datetime.datetime | None = None,
not_after: datetime.datetime | None = None,
) -> bool:
assert self.privatekey is not None
if super(SelfSignedCertificateBackendCryptography, self).needs_regeneration(
not_before=self.notBefore, not_after=self.notAfter
):
return True
self._ensure_existing_certificate_loaded()
assert self.existing_certificate is not None
# Check whether certificate is signed by private key
if not cryptography_verify_certificate_signature(
self.existing_certificate, self.privatekey.public_key()
):
return True
return False
def dump(self, include_certificate: bool) -> dict[str, t.Any]:
result = super(SelfSignedCertificateBackendCryptography, self).dump(
include_certificate
)
if self.module.check_mode:
result.update(
{
"notBefore": self.notBefore.strftime("%Y%m%d%H%M%SZ"),
"notAfter": self.notAfter.strftime("%Y%m%d%H%M%SZ"),
"serial_number": self.serial_number,
}
)
else:
if self.cert is None:
self.cert = self.existing_certificate
assert self.cert is not None
result.update(
{
"notBefore": get_not_valid_before(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"notAfter": get_not_valid_after(self.cert).strftime(
"%Y%m%d%H%M%SZ"
),
"serial_number": self.cert.serial_number,
}
)
return result
def generate_serial_number() -> int:
"""Generate a serial number for a certificate"""
while True:
result = randrange(0, 1 << 160)
if result >= 1000:
return result
class SelfSignedCertificateProvider(CertificateProvider):
def validate_module_args(self, module: AnsibleModule) -> None:
if (
module.params["privatekey_path"] is None
and module.params["privatekey_content"] is None
):
module.fail_json(
msg="One of privatekey_path and privatekey_content must be specified for the selfsigned provider."
)
def needs_version_two_certs(self, module: AnsibleModule) -> bool:
return module.params["selfsigned_version"] == 2
def create_backend(
self, module: AnsibleModule
) -> SelfSignedCertificateBackendCryptography:
return SelfSignedCertificateBackendCryptography(module)
def add_selfsigned_provider_to_argument_spec(argument_spec: ArgumentSpec) -> None:
argument_spec.argument_spec["provider"]["choices"].append("selfsigned")
argument_spec.argument_spec.update(
dict(
selfsigned_version=dict(type="int", default=3),
selfsigned_digest=dict(type="str", default="sha256"),
selfsigned_not_before=dict(
type="str", default="+0s", aliases=["selfsigned_notBefore"]
),
selfsigned_not_after=dict(
type="str", default="+3650d", aliases=["selfsigned_notAfter"]
),
selfsigned_create_subject_key_identifier=dict(
type="str",
default="create_if_not_provided",
choices=["create_if_not_provided", "always_create", "never_create"],
),
)
)

View File

@@ -0,0 +1,124 @@
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import typing as t
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_crl import (
TIMESTAMP_FORMAT,
cryptography_decode_revoked_certificate,
cryptography_dump_revoked,
cryptography_get_signature_algorithm_oid_from_crl,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
cryptography_oid_to_name,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.pem import (
identify_pem_format,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.plugin_utils._action_module import (
AnsibleActionModule,
)
from ansible_collections.community.crypto.plugins.plugin_utils._filter_module import (
FilterModuleMock,
)
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
# crypto_utils
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
from cryptography import x509
except ImportError:
pass
class CRLInfoRetrieval:
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes,
list_revoked_certificates: bool = True,
) -> None:
# content must be a bytes string
self.module = module
self.content = content
self.list_revoked_certificates = list_revoked_certificates
self.name_encoding = module.params.get("name_encoding", "ignore")
def get_info(self) -> dict[str, t.Any]:
self.crl_pem = identify_pem_format(self.content)
try:
if self.crl_pem:
self.crl = x509.load_pem_x509_crl(self.content)
else:
self.crl = x509.load_der_x509_crl(self.content)
except ValueError as e:
self.module.fail_json(msg=f"Error while decoding CRL: {e}")
result: dict[str, t.Any] = {
"changed": False,
"format": "pem" if self.crl_pem else "der",
"last_update": None,
"next_update": None,
"digest": None,
"issuer_ordered": None,
"issuer": None,
}
result["last_update"] = self.crl.last_update.strftime(TIMESTAMP_FORMAT)
result["next_update"] = (
self.crl.next_update.strftime(TIMESTAMP_FORMAT)
if self.crl.next_update
else None
)
result["digest"] = cryptography_oid_to_name(
cryptography_get_signature_algorithm_oid_from_crl(self.crl)
)
issuer = []
for attribute in self.crl.issuer:
issuer.append([cryptography_oid_to_name(attribute.oid), attribute.value])
result["issuer_ordered"] = issuer
result["issuer"] = {}
for k, v in issuer:
result["issuer"][k] = v
if self.list_revoked_certificates:
result["revoked_certificates"] = []
for cert in self.crl:
entry = cryptography_decode_revoked_certificate(cert)
result["revoked_certificates"].append(
cryptography_dump_revoked(entry, idn_rewrite=self.name_encoding)
)
return result
def get_crl_info(
module: GeneralAnsibleModule, content: bytes, list_revoked_certificates: bool = True
) -> dict[str, t.Any]:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
info = CRLInfoRetrieval(
module, content, list_revoked_certificates=list_revoked_certificates
)
return info.get_info()

View File

@@ -0,0 +1,905 @@
# Copyright (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLBadPassphraseError,
OpenSSLObjectError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_crl import (
REVOCATION_REASON_MAP,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
cryptography_get_basic_constraints,
cryptography_get_name,
cryptography_key_needs_digest_for_signing,
cryptography_name_to_oid,
cryptography_parse_key_usage_params,
cryptography_parse_relative_distinguished_name,
is_potential_certificate_issuer_public_key,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.csr_info import (
get_csr_info,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
load_certificate_issuer_privatekey,
load_certificate_request,
parse_name_field,
parse_ordered_name_field,
select_message_digest,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CertificatePrivateKeyTypes,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
)
_ET = t.TypeVar("_ET", bound="cryptography.x509.ExtensionType")
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
import cryptography.exceptions
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.serialization
import cryptography.x509
import cryptography.x509.oid
except ImportError:
pass
class CertificateSigningRequestError(OpenSSLObjectError):
pass
# From the object called `module`, only the following properties are used:
#
# - module.params[]
# - module.warn(msg: str)
# - module.fail_json(msg: str, **kwargs)
class CertificateSigningRequestBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.digest: str = module.params["digest"]
self.privatekey_path: str | None = module.params["privatekey_path"]
privatekey_content: str | None = module.params["privatekey_content"]
if privatekey_content is not None:
self.privatekey_content: bytes | None = privatekey_content.encode("utf-8")
else:
self.privatekey_content = None
self.privatekey_passphrase: str | None = module.params["privatekey_passphrase"]
self.version: t.Literal[1] = module.params["version"]
self.subjectAltName: list[str] | None = module.params["subject_alt_name"]
self.subjectAltName_critical: bool = module.params["subject_alt_name_critical"]
self.keyUsage: list[str] | None = module.params["key_usage"]
self.keyUsage_critical: bool = module.params["key_usage_critical"]
self.extendedKeyUsage: list[str] | None = module.params["extended_key_usage"]
self.extendedKeyUsage_critical: bool = module.params[
"extended_key_usage_critical"
]
self.basicConstraints: list[str] | None = module.params["basic_constraints"]
self.basicConstraints_critical: bool = module.params[
"basic_constraints_critical"
]
self.ocspMustStaple: bool = module.params["ocsp_must_staple"]
self.ocspMustStaple_critical: bool = module.params["ocsp_must_staple_critical"]
self.name_constraints_permitted: list[str] = (
module.params["name_constraints_permitted"] or []
)
self.name_constraints_excluded: list[str] = (
module.params["name_constraints_excluded"] or []
)
self.name_constraints_critical: bool = module.params[
"name_constraints_critical"
]
self.create_subject_key_identifier: bool = module.params[
"create_subject_key_identifier"
]
subject_key_identifier: str | None = module.params["subject_key_identifier"]
authority_key_identifier: str | None = module.params["authority_key_identifier"]
self.authority_cert_issuer: list[str] | None = module.params[
"authority_cert_issuer"
]
self.authority_cert_serial_number: int = module.params[
"authority_cert_serial_number"
]
self.crl_distribution_points: (
list[cryptography.x509.DistributionPoint] | None
) = None
self.csr: cryptography.x509.CertificateSigningRequest | None = None
self.privatekey: CertificateIssuerPrivateKeyTypes | None = None
if self.create_subject_key_identifier and subject_key_identifier is not None:
module.fail_json(
msg="subject_key_identifier cannot be specified if create_subject_key_identifier is true"
)
self.ordered_subject = False
self.subject = [
("C", module.params["country_name"]),
("ST", module.params["state_or_province_name"]),
("L", module.params["locality_name"]),
("O", module.params["organization_name"]),
("OU", module.params["organizational_unit_name"]),
("CN", module.params["common_name"]),
("emailAddress", module.params["email_address"]),
]
self.subject = [(entry[0], entry[1]) for entry in self.subject if entry[1]]
try:
if module.params["subject"]:
self.subject = self.subject + parse_name_field(
module.params["subject"], "subject"
)
if module.params["subject_ordered"]:
if self.subject:
raise CertificateSigningRequestError(
"subject_ordered cannot be combined with any other subject field"
)
self.subject = parse_ordered_name_field(
module.params["subject_ordered"], "subject_ordered"
)
self.ordered_subject = True
except ValueError as exc:
raise CertificateSigningRequestError(str(exc))
self.using_common_name_for_san = False
if not self.subjectAltName and module.params["use_common_name_for_san"]:
for sub in self.subject:
if sub[0] in ("commonName", "CN"):
self.subjectAltName = [f"DNS:{sub[1]}"]
self.using_common_name_for_san = True
break
self.subject_key_identifier: bytes | None = None
if subject_key_identifier is not None:
try:
self.subject_key_identifier = binascii.unhexlify(
subject_key_identifier.replace(":", "")
)
except Exception as e:
raise CertificateSigningRequestError(
f"Cannot parse subject_key_identifier: {e}"
)
self.authority_key_identifier: bytes | None = None
if authority_key_identifier is not None:
try:
self.authority_key_identifier = binascii.unhexlify(
authority_key_identifier.replace(":", "")
)
except Exception as e:
raise CertificateSigningRequestError(
f"Cannot parse authority_key_identifier: {e}"
)
self.existing_csr: cryptography.x509.CertificateSigningRequest | None = None
self.existing_csr_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
try:
result = get_csr_info(
self.module,
data,
validate_signature=False,
prefer_one_fingerprint=True,
)
result["can_parse_csr"] = True
return result
except Exception:
return dict(can_parse_csr=False)
@abc.abstractmethod
def generate_csr(self) -> None:
"""(Re-)Generate CSR."""
@abc.abstractmethod
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
def set_existing(self, csr_bytes: bytes | None) -> None:
"""Set existing CSR bytes. None indicates that the CSR does not exist."""
self.existing_csr_bytes = csr_bytes
self.diff_after = self.diff_before = self._get_info(self.existing_csr_bytes)
def has_existing(self) -> bool:
"""Query whether an existing CSR is/has been there."""
return self.existing_csr_bytes is not None
def _ensure_private_key_loaded(self) -> None:
"""Load the provided private key into self.privatekey."""
if self.privatekey is not None:
return
try:
self.privatekey = load_certificate_issuer_privatekey(
path=self.privatekey_path,
content=self.privatekey_content,
passphrase=self.privatekey_passphrase,
)
except OpenSSLBadPassphraseError as exc:
raise CertificateSigningRequestError(exc)
@abc.abstractmethod
def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
if self.existing_csr_bytes is None:
return True
try:
self.existing_csr = load_certificate_request(
None,
content=self.existing_csr_bytes,
)
except Exception:
return True
self._ensure_private_key_loaded()
return not self._check_csr()
def dump(self, include_csr: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
result: dict[str, t.Any] = {
"privatekey": self.privatekey_path,
"subject": self.subject,
"subjectAltName": self.subjectAltName,
"keyUsage": self.keyUsage,
"extendedKeyUsage": self.extendedKeyUsage,
"basicConstraints": self.basicConstraints,
"ocspMustStaple": self.ocspMustStaple,
"name_constraints_permitted": self.name_constraints_permitted,
"name_constraints_excluded": self.name_constraints_excluded,
}
# Get hold of CSR bytes
csr_bytes = self.existing_csr_bytes
if self.csr is not None:
csr_bytes = self.get_csr_data()
self.diff_after = self._get_info(csr_bytes)
if include_csr:
# Store result
result["csr"] = csr_bytes.decode("utf-8") if csr_bytes else None
result["diff"] = dict(
before=self.diff_before,
after=self.diff_after,
)
return result
def parse_crl_distribution_points(
module: AnsibleModule, crl_distribution_points: list[dict[str, t.Any]]
) -> list[cryptography.x509.DistributionPoint]:
result = []
for index, parse_crl_distribution_point in enumerate(crl_distribution_points):
try:
full_name = None
relative_name = None
crl_issuer = None
reasons = None
if parse_crl_distribution_point["full_name"] is not None:
if not parse_crl_distribution_point["full_name"]:
raise OpenSSLObjectError("full_name must not be empty")
full_name = [
cryptography_get_name(name, "full name")
for name in parse_crl_distribution_point["full_name"]
]
if parse_crl_distribution_point["relative_name"] is not None:
if not parse_crl_distribution_point["relative_name"]:
raise OpenSSLObjectError("relative_name must not be empty")
relative_name = cryptography_parse_relative_distinguished_name(
parse_crl_distribution_point["relative_name"]
)
if parse_crl_distribution_point["crl_issuer"] is not None:
if not parse_crl_distribution_point["crl_issuer"]:
raise OpenSSLObjectError("crl_issuer must not be empty")
crl_issuer = [
cryptography_get_name(name, "CRL issuer")
for name in parse_crl_distribution_point["crl_issuer"]
]
if parse_crl_distribution_point["reasons"] is not None:
reasons_list = []
for reason in parse_crl_distribution_point["reasons"]:
reasons_list.append(REVOCATION_REASON_MAP[reason])
reasons = frozenset(reasons_list)
result.append(
cryptography.x509.DistributionPoint(
full_name=full_name,
relative_name=relative_name,
crl_issuer=crl_issuer,
reasons=reasons,
)
)
except (OpenSSLObjectError, ValueError) as e:
raise OpenSSLObjectError(
f"Error while parsing CRL distribution point #{index}: {e}"
)
return result
# Implementation with using cryptography
class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
def __init__(self, module: AnsibleModule) -> None:
super(CertificateSigningRequestCryptographyBackend, self).__init__(module)
if self.version != 1:
module.warn(
"The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)"
)
crl_distribution_points: list[dict[str, t.Any]] | None = module.params[
"crl_distribution_points"
]
if crl_distribution_points:
self.crl_distribution_points = parse_crl_distribution_points(
module, crl_distribution_points
)
def generate_csr(self) -> None:
"""(Re-)Generate CSR."""
self._ensure_private_key_loaded()
assert self.privatekey is not None
csr = cryptography.x509.CertificateSigningRequestBuilder()
try:
csr = csr.subject_name(
cryptography.x509.Name(
[
cryptography.x509.NameAttribute(
cryptography_name_to_oid(entry[0]), to_text(entry[1])
)
for entry in self.subject
]
)
)
except ValueError as e:
raise CertificateSigningRequestError(e)
if self.subjectAltName:
csr = csr.add_extension(
cryptography.x509.SubjectAlternativeName(
[cryptography_get_name(name) for name in self.subjectAltName]
),
critical=self.subjectAltName_critical,
)
if self.keyUsage:
params = cryptography_parse_key_usage_params(self.keyUsage)
csr = csr.add_extension(
cryptography.x509.KeyUsage(**params), critical=self.keyUsage_critical
)
if self.extendedKeyUsage:
usages = [
cryptography_name_to_oid(usage) for usage in self.extendedKeyUsage
]
csr = csr.add_extension(
cryptography.x509.ExtendedKeyUsage(usages),
critical=self.extendedKeyUsage_critical,
)
if self.basicConstraints:
params = {}
ca, path_length = cryptography_get_basic_constraints(self.basicConstraints)
csr = csr.add_extension(
cryptography.x509.BasicConstraints(ca, path_length),
critical=self.basicConstraints_critical,
)
if self.ocspMustStaple:
csr = csr.add_extension(
cryptography.x509.TLSFeature(
[cryptography.x509.TLSFeatureType.status_request]
),
critical=self.ocspMustStaple_critical,
)
if self.name_constraints_permitted or self.name_constraints_excluded:
try:
csr = csr.add_extension(
cryptography.x509.NameConstraints(
[
cryptography_get_name(name, "name constraints permitted")
for name in self.name_constraints_permitted
]
or None,
[
cryptography_get_name(name, "name constraints excluded")
for name in self.name_constraints_excluded
]
or None,
),
critical=self.name_constraints_critical,
)
except TypeError as e:
raise OpenSSLObjectError(f"Error while parsing name constraint: {e}")
if self.create_subject_key_identifier:
if not is_potential_certificate_issuer_public_key(
self.privatekey.public_key()
):
raise OpenSSLObjectError(
"Private key can not be used to create subject key identifier"
)
csr = csr.add_extension(
cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
),
critical=False,
)
elif self.subject_key_identifier is not None:
csr = csr.add_extension(
cryptography.x509.SubjectKeyIdentifier(self.subject_key_identifier),
critical=False,
)
if (
self.authority_key_identifier is not None
or self.authority_cert_issuer is not None
or self.authority_cert_serial_number is not None
):
issuers = None
if self.authority_cert_issuer is not None:
issuers = [
cryptography_get_name(n, "authority cert issuer")
for n in self.authority_cert_issuer
]
csr = csr.add_extension(
cryptography.x509.AuthorityKeyIdentifier(
self.authority_key_identifier,
issuers,
self.authority_cert_serial_number,
),
critical=False,
)
if self.crl_distribution_points:
csr = csr.add_extension(
cryptography.x509.CRLDistributionPoints(self.crl_distribution_points),
critical=False,
)
# csr.sign() does not accept some digests we theoretically could have in digest.
# For that reason we use type t.Any here. csr.sign() will complain if
# the digest is not acceptable.
digest: t.Any | None = None
if cryptography_key_needs_digest_for_signing(self.privatekey):
digest = select_message_digest(self.digest)
if digest is None:
raise CertificateSigningRequestError(
f'Unsupported digest "{self.digest}"'
)
try:
self.csr = csr.sign(self.privatekey, digest)
except UnicodeError as e:
# This catches IDNAErrors, which happens when a bad name is passed as a SAN
# (https://github.com/ansible-collections/community.crypto/issues/105).
# For older cryptography versions, this is handled by idna, which raises
# an idna.core.IDNAError. Later versions of cryptography deprecated and stopped
# requiring idna, whence we cannot easily handle this error. Fortunately, in
# most versions of idna, IDNAError extends UnicodeError. There is only version
# 2.3 where it extends Exception instead (see
# https://github.com/kjd/idna/commit/ebefacd3134d0f5da4745878620a6a1cba86d130
# and then
# https://github.com/kjd/idna/commit/ea03c7b5db7d2a99af082e0239da2b68aeea702a).
msg = f"Error while creating CSR: {e}\n"
if self.using_common_name_for_san:
self.module.fail_json(
msg=msg
+ "This is probably caused because the Common Name is used as a SAN."
" Specifying use_common_name_for_san=false might fix this."
)
self.module.fail_json(
msg=msg
+ "This is probably caused by an invalid Subject Alternative DNS Name."
)
def get_csr_data(self) -> bytes:
"""Return bytes for self.csr."""
if self.csr is None:
raise AssertionError("Violated contract: csr is not populated")
return self.csr.public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM
)
def _check_csr(self) -> bool:
"""Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
if self.existing_csr is None:
raise AssertionError("Violated contract: existing_csr is not populated")
if self.privatekey is None:
raise AssertionError("Violated contract: privatekey is not populated")
def _check_subject(csr: cryptography.x509.CertificateSigningRequest) -> bool:
subject = [
(cryptography_name_to_oid(entry[0]), to_text(entry[1]))
for entry in self.subject
]
current_subject = [(sub.oid, sub.value) for sub in csr.subject]
if self.ordered_subject:
return subject == current_subject
else:
return set(subject) == set(current_subject)
def _find_extension(
extensions: cryptography.x509.Extensions, exttype: type[_ET]
) -> cryptography.x509.Extension[_ET] | None:
return next(
(ext for ext in extensions if isinstance(ext.value, exttype)), None
)
def _check_subjectAltName(extensions: cryptography.x509.Extensions) -> bool:
current_altnames_ext = _find_extension(
extensions, cryptography.x509.SubjectAlternativeName
)
current_altnames = (
[to_text(altname) for altname in current_altnames_ext.value]
if current_altnames_ext
else []
)
altnames = (
[
to_text(cryptography_get_name(altname))
for altname in self.subjectAltName
]
if self.subjectAltName
else []
)
if set(altnames) != set(current_altnames):
return False
if altnames and current_altnames_ext:
if current_altnames_ext.critical != self.subjectAltName_critical:
return False
return True
def _check_keyUsage(extensions: cryptography.x509.Extensions) -> bool:
current_keyusage_ext = _find_extension(
extensions, cryptography.x509.KeyUsage
)
if not self.keyUsage:
return current_keyusage_ext is None
elif current_keyusage_ext is None:
return False
params = cryptography_parse_key_usage_params(self.keyUsage)
for param in params:
if getattr(current_keyusage_ext.value, "_" + param) != params[param]:
return False
if current_keyusage_ext.critical != self.keyUsage_critical:
return False
return True
def _check_extenededKeyUsage(extensions: cryptography.x509.Extensions) -> bool:
current_usages_ext = _find_extension(
extensions, cryptography.x509.ExtendedKeyUsage
)
current_usages = (
[str(usage) for usage in current_usages_ext.value]
if current_usages_ext
else []
)
usages = (
[
str(cryptography_name_to_oid(usage))
for usage in self.extendedKeyUsage
]
if self.extendedKeyUsage
else []
)
if set(current_usages) != set(usages):
return False
if usages and current_usages_ext:
if current_usages_ext.critical != self.extendedKeyUsage_critical:
return False
return True
def _check_basicConstraints(extensions: cryptography.x509.Extensions) -> bool:
bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints)
current_ca = bc_ext.value.ca if bc_ext else False
current_path_length = bc_ext.value.path_length if bc_ext else None
ca, path_length = cryptography_get_basic_constraints(self.basicConstraints)
# Check CA flag
if ca != current_ca:
return False
# Check path length
if path_length != current_path_length:
return False
# Check criticality
if self.basicConstraints:
return (
bc_ext is not None
and bc_ext.critical == self.basicConstraints_critical
)
else:
return bc_ext is None
def _check_ocspMustStaple(extensions: cryptography.x509.Extensions) -> bool:
tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
if self.ocspMustStaple:
if (
not tlsfeature_ext
or tlsfeature_ext.critical != self.ocspMustStaple_critical
):
return False
return (
cryptography.x509.TLSFeatureType.status_request
in tlsfeature_ext.value
)
else:
return tlsfeature_ext is None
def _check_nameConstraints(extensions: cryptography.x509.Extensions) -> bool:
current_nc_ext = _find_extension(
extensions, cryptography.x509.NameConstraints
)
current_nc_perm = (
[
to_text(altname)
for altname in current_nc_ext.value.permitted_subtrees or []
]
if current_nc_ext
else []
)
current_nc_excl = (
[
to_text(altname)
for altname in current_nc_ext.value.excluded_subtrees or []
]
if current_nc_ext
else []
)
nc_perm = [
to_text(cryptography_get_name(altname, "name constraints permitted"))
for altname in self.name_constraints_permitted
]
nc_excl = [
to_text(cryptography_get_name(altname, "name constraints excluded"))
for altname in self.name_constraints_excluded
]
if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(
current_nc_excl
):
return False
if (nc_perm or nc_excl) and current_nc_ext:
if current_nc_ext.critical != self.name_constraints_critical:
return False
return True
def _check_subject_key_identifier(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
if (
self.create_subject_key_identifier
or self.subject_key_identifier is not None
):
if not ext or ext.critical:
return False
if self.create_subject_key_identifier:
assert self.privatekey is not None
digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(
self.privatekey.public_key()
).digest
return ext.value.digest == digest
else:
return ext.value.digest == self.subject_key_identifier
else:
return ext is None
def _check_authority_key_identifier(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
if (
self.authority_key_identifier is not None
or self.authority_cert_issuer is not None
or self.authority_cert_serial_number is not None
):
if not ext or ext.critical:
return False
aci = None
csr_aci = None
if self.authority_cert_issuer is not None:
aci = [
to_text(cryptography_get_name(n, "authority cert issuer"))
for n in self.authority_cert_issuer
]
if ext.value.authority_cert_issuer is not None:
csr_aci = [to_text(n) for n in ext.value.authority_cert_issuer]
return (
ext.value.key_identifier == self.authority_key_identifier
and csr_aci == aci
and ext.value.authority_cert_serial_number
== self.authority_cert_serial_number
)
else:
return ext is None
def _check_crl_distribution_points(
extensions: cryptography.x509.Extensions,
) -> bool:
ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints)
if self.crl_distribution_points is None:
return ext is None
if not ext:
return False
return list(ext.value) == self.crl_distribution_points
def _check_extensions(csr: cryptography.x509.CertificateSigningRequest) -> bool:
extensions = csr.extensions
return (
_check_subjectAltName(extensions)
and _check_keyUsage(extensions)
and _check_extenededKeyUsage(extensions)
and _check_basicConstraints(extensions)
and _check_ocspMustStaple(extensions)
and _check_subject_key_identifier(extensions)
and _check_authority_key_identifier(extensions)
and _check_nameConstraints(extensions)
and _check_crl_distribution_points(extensions)
)
def _check_signature(csr: cryptography.x509.CertificateSigningRequest) -> bool:
if not csr.is_signature_valid:
return False
# To check whether public key of CSR belongs to private key,
# encode both public keys and compare PEMs.
key_a = csr.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
)
assert self.privatekey is not None
key_b = self.privatekey.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo,
)
return key_a == key_b
return (
_check_subject(self.existing_csr)
and _check_extensions(self.existing_csr)
and _check_signature(self.existing_csr)
)
def select_backend(
module: AnsibleModule,
) -> CertificateSigningRequestCryptographyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CertificateSigningRequestCryptographyBackend(module)
def get_csr_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
digest=dict(type="str", default="sha256"),
privatekey_path=dict(type="path"),
privatekey_content=dict(type="str", no_log=True),
privatekey_passphrase=dict(type="str", no_log=True),
version=dict(type="int", default=1, choices=[1]),
subject=dict(type="dict"),
subject_ordered=dict(type="list", elements="dict"),
country_name=dict(type="str", aliases=["C", "countryName"]),
state_or_province_name=dict(
type="str", aliases=["ST", "stateOrProvinceName"]
),
locality_name=dict(type="str", aliases=["L", "localityName"]),
organization_name=dict(type="str", aliases=["O", "organizationName"]),
organizational_unit_name=dict(
type="str", aliases=["OU", "organizationalUnitName"]
),
common_name=dict(type="str", aliases=["CN", "commonName"]),
email_address=dict(type="str", aliases=["E", "emailAddress"]),
subject_alt_name=dict(
type="list", elements="str", aliases=["subjectAltName"]
),
subject_alt_name_critical=dict(
type="bool", default=False, aliases=["subjectAltName_critical"]
),
use_common_name_for_san=dict(
type="bool", default=True, aliases=["useCommonNameForSAN"]
),
key_usage=dict(type="list", elements="str", aliases=["keyUsage"]),
key_usage_critical=dict(
type="bool", default=False, aliases=["keyUsage_critical"]
),
extended_key_usage=dict(
type="list", elements="str", aliases=["extKeyUsage", "extendedKeyUsage"]
),
extended_key_usage_critical=dict(
type="bool",
default=False,
aliases=["extKeyUsage_critical", "extendedKeyUsage_critical"],
),
basic_constraints=dict(
type="list", elements="str", aliases=["basicConstraints"]
),
basic_constraints_critical=dict(
type="bool", default=False, aliases=["basicConstraints_critical"]
),
ocsp_must_staple=dict(
type="bool", default=False, aliases=["ocspMustStaple"]
),
ocsp_must_staple_critical=dict(
type="bool", default=False, aliases=["ocspMustStaple_critical"]
),
name_constraints_permitted=dict(type="list", elements="str"),
name_constraints_excluded=dict(type="list", elements="str"),
name_constraints_critical=dict(type="bool", default=False),
create_subject_key_identifier=dict(type="bool", default=False),
subject_key_identifier=dict(type="str"),
authority_key_identifier=dict(type="str"),
authority_cert_issuer=dict(type="list", elements="str"),
authority_cert_serial_number=dict(type="int"),
crl_distribution_points=dict(
type="list",
elements="dict",
options=dict(
full_name=dict(type="list", elements="str"),
relative_name=dict(type="list", elements="str"),
crl_issuer=dict(type="list", elements="str"),
reasons=dict(
type="list",
elements="str",
choices=[
"key_compromise",
"ca_compromise",
"affiliation_changed",
"superseded",
"cessation_of_operation",
"certificate_hold",
"privilege_withdrawn",
"aa_compromise",
],
),
),
mutually_exclusive=[("full_name", "relative_name")],
required_one_of=[("full_name", "relative_name", "crl_issuer")],
),
select_crypto_backend=dict(
type="str", default="auto", choices=["auto", "cryptography"]
),
),
required_together=[
["authority_cert_issuer", "authority_cert_serial_number"],
],
mutually_exclusive=[
["privatekey_path", "privatekey_content"],
["subject", "subject_ordered"],
],
required_one_of=[
["privatekey_path", "privatekey_content"],
],
)

View File

@@ -0,0 +1,394 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import binascii
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
cryptography_decode_name,
cryptography_get_extensions_from_csr,
cryptography_oid_to_name,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.publickey_info import (
get_publickey_info,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
load_certificate_request,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.plugin_utils._action_module import (
AnsibleActionModule,
)
from ansible_collections.community.crypto.plugins.plugin_utils._filter_module import (
FilterModuleMock,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificatePublicKeyTypes,
PrivateKeyTypes,
)
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
from cryptography import x509
from cryptography.hazmat.primitives import serialization
except ImportError:
pass
TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ"
class CSRInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
self.module = module
self.content = content
self.validate_signature = validate_signature
@abc.abstractmethod
def _get_subject_ordered(self) -> list[list[str]]:
pass
@abc.abstractmethod
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
pass
@abc.abstractmethod
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
pass
@abc.abstractmethod
def _get_public_key_pem(self) -> bytes:
pass
@abc.abstractmethod
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
pass
@abc.abstractmethod
def _get_subject_key_identifier(self) -> bytes | None:
pass
@abc.abstractmethod
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
pass
@abc.abstractmethod
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
pass
@abc.abstractmethod
def _is_signature_valid(self) -> bool:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
self.csr = load_certificate_request(
None,
content=self.content,
)
subject = self._get_subject_ordered()
result["subject"] = dict()
for k, v in subject:
result["subject"][k] = v
result["subject_ordered"] = subject
result["key_usage"], result["key_usage_critical"] = self._get_key_usage()
result["extended_key_usage"], result["extended_key_usage_critical"] = (
self._get_extended_key_usage()
)
result["basic_constraints"], result["basic_constraints_critical"] = (
self._get_basic_constraints()
)
result["ocsp_must_staple"], result["ocsp_must_staple_critical"] = (
self._get_ocsp_must_staple()
)
result["subject_alt_name"], result["subject_alt_name_critical"] = (
self._get_subject_alt_name()
)
(
result["name_constraints_permitted"],
result["name_constraints_excluded"],
result["name_constraints_critical"],
) = self._get_name_constraints()
result["public_key"] = to_native(self._get_public_key_pem())
public_key_info = get_publickey_info(
self.module,
key=self._get_public_key_object(),
prefer_one_fingerprint=prefer_one_fingerprint,
)
result.update(
{
"public_key_type": public_key_info["type"],
"public_key_data": public_key_info["public_data"],
"public_key_fingerprints": public_key_info["fingerprints"],
}
)
ski_bytes = self._get_subject_key_identifier()
ski = None
if ski_bytes is not None:
ski = binascii.hexlify(ski_bytes).decode("ascii")
ski = ":".join([ski[i : i + 2] for i in range(0, len(ski), 2)])
result["subject_key_identifier"] = ski
aki_bytes, aci, acsn = self._get_authority_key_identifier()
aki = None
if aki_bytes is not None:
aki = binascii.hexlify(aki_bytes).decode("ascii")
aki = ":".join([aki[i : i + 2] for i in range(0, len(aki), 2)])
result["authority_key_identifier"] = aki
result["authority_cert_issuer"] = aci
result["authority_cert_serial_number"] = acsn
result["extensions_by_oid"] = self._get_all_extensions()
result["signature_valid"] = self._is_signature_valid()
if self.validate_signature and not result["signature_valid"]:
self.module.fail_json(msg="CSR signature is invalid!", **result)
return result
class CSRInfoRetrievalCryptography(CSRInfoRetrieval):
"""Validate the supplied CSR, using the cryptography backend"""
def __init__(
self, module: GeneralAnsibleModule, content: bytes, validate_signature: bool
) -> None:
super(CSRInfoRetrievalCryptography, self).__init__(
module, content, validate_signature
)
self.name_encoding: t.Literal["ignore", "idna", "unicode"] = module.params.get(
"name_encoding", "ignore"
)
def _get_subject_ordered(self) -> list[list[str]]:
result: list[list[str]] = []
for attribute in self.csr.subject:
result.append(
[cryptography_oid_to_name(attribute.oid), to_native(attribute.value)]
)
return result
def _get_key_usage(self) -> tuple[list[str] | None, bool]:
try:
current_key_ext = self.csr.extensions.get_extension_for_class(x509.KeyUsage)
current_key_usage = current_key_ext.value
key_usage = dict(
digital_signature=current_key_usage.digital_signature,
content_commitment=current_key_usage.content_commitment,
key_encipherment=current_key_usage.key_encipherment,
data_encipherment=current_key_usage.data_encipherment,
key_agreement=current_key_usage.key_agreement,
key_cert_sign=current_key_usage.key_cert_sign,
crl_sign=current_key_usage.crl_sign,
encipher_only=False,
decipher_only=False,
)
if key_usage["key_agreement"]:
key_usage.update(
dict(
encipher_only=current_key_usage.encipher_only,
decipher_only=current_key_usage.decipher_only,
)
)
key_usage_names = dict(
digital_signature="Digital Signature",
content_commitment="Non Repudiation",
key_encipherment="Key Encipherment",
data_encipherment="Data Encipherment",
key_agreement="Key Agreement",
key_cert_sign="Certificate Sign",
crl_sign="CRL Sign",
encipher_only="Encipher Only",
decipher_only="Decipher Only",
)
return (
sorted(
[
key_usage_names[name]
for name, value in key_usage.items()
if value
]
),
current_key_ext.critical,
)
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_extended_key_usage(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.ExtendedKeyUsage
)
return (
sorted(
[cryptography_oid_to_name(eku) for eku in ext_keyusage_ext.value]
),
ext_keyusage_ext.critical,
)
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_basic_constraints(self) -> tuple[list[str] | None, bool]:
try:
ext_keyusage_ext = self.csr.extensions.get_extension_for_class(
x509.BasicConstraints
)
result = [f"CA:{'TRUE' if ext_keyusage_ext.value.ca else 'FALSE'}"]
if ext_keyusage_ext.value.path_length is not None:
result.append(f"pathlen:{ext_keyusage_ext.value.path_length}")
return sorted(result), ext_keyusage_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_ocsp_must_staple(self) -> tuple[bool | None, bool]:
try:
# This only works with cryptography >= 2.1
tlsfeature_ext = self.csr.extensions.get_extension_for_class(
x509.TLSFeature
)
value = (
cryptography.x509.TLSFeatureType.status_request in tlsfeature_ext.value
)
return value, tlsfeature_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_subject_alt_name(self) -> tuple[list[str] | None, bool]:
try:
san_ext = self.csr.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
result = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in san_ext.value
]
return result, san_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, False
def _get_name_constraints(self) -> tuple[list[str] | None, list[str] | None, bool]:
try:
nc_ext = self.csr.extensions.get_extension_for_class(x509.NameConstraints)
permitted = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in nc_ext.value.permitted_subtrees or []
]
excluded = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in nc_ext.value.excluded_subtrees or []
]
return permitted, excluded, nc_ext.critical
except cryptography.x509.ExtensionNotFound:
return None, None, False
def _get_public_key_pem(self) -> bytes:
return self.csr.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_public_key_object(self) -> CertificatePublicKeyTypes:
return self.csr.public_key()
def _get_subject_key_identifier(self) -> bytes | None:
try:
ext = self.csr.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
return ext.value.digest
except cryptography.x509.ExtensionNotFound:
return None
def _get_authority_key_identifier(
self,
) -> tuple[bytes | None, list[str] | None, int | None]:
try:
ext = self.csr.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier
)
issuer = None
if ext.value.authority_cert_issuer is not None:
issuer = [
cryptography_decode_name(san, idn_rewrite=self.name_encoding)
for san in ext.value.authority_cert_issuer
]
return (
ext.value.key_identifier,
issuer,
ext.value.authority_cert_serial_number,
)
except cryptography.x509.ExtensionNotFound:
return None, None, None
def _get_all_extensions(self) -> dict[str, dict[str, bool | str]]:
return cryptography_get_extensions_from_csr(self.csr)
def _is_signature_valid(self) -> bool:
return self.csr.is_signature_valid
def get_csr_info(
module: GeneralAnsibleModule,
content: bytes,
validate_signature: bool = True,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule, content: bytes, validate_signature: bool = True
) -> CSRInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return CSRInfoRetrievalCryptography(
module, content, validate_signature=validate_signature
)

View File

@@ -0,0 +1,663 @@
# Copyright (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import base64
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLObjectError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.privatekey_info import (
PrivateKeyConsistencyError,
PrivateKeyParseError,
get_privatekey_info,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.pem import (
identify_private_key_format,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
get_fingerprint_of_privatekey,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.plugin_utils._action_module import (
AnsibleActionModule,
)
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
import cryptography.exceptions
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.dsa
import cryptography.hazmat.primitives.asymmetric.ec
import cryptography.hazmat.primitives.asymmetric.ed448
import cryptography.hazmat.primitives.asymmetric.ed25519
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.asymmetric.x448
import cryptography.hazmat.primitives.asymmetric.x25519
import cryptography.hazmat.primitives.serialization
except ImportError:
pass
class PrivateKeyError(OpenSSLObjectError):
pass
# From the object called `module`, only the following properties are used:
#
# - module.params[]
# - module.warn(msg: str)
# - module.fail_json(msg: str, **kwargs)
class PrivateKeyBackend(metaclass=abc.ABCMeta):
def __init__(self, module: GeneralAnsibleModule) -> None:
self.module = module
self.type: t.Literal[
"DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"
] = module.params["type"]
self.size: int = module.params["size"]
self.curve: str | None = module.params["curve"]
self.passphrase: str | None = module.params["passphrase"]
self.cipher: str = module.params["cipher"]
self.format: t.Literal["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"] = (
module.params["format"]
)
self.format_mismatch: t.Literal["regenerate", "convert"] = module.params.get(
"format_mismatch", "regenerate"
)
self.regenerate: t.Literal[
"never", "fail", "partial_idempotence", "full_idempotence", "always"
] = module.params.get("regenerate", "full_idempotence")
self.private_key: PrivateKeyTypes | None = None
self.existing_private_key: PrivateKeyTypes | None = None
self.existing_private_key_bytes: bytes | None = None
self.diff_before = self._get_info(None)
self.diff_after = self._get_info(None)
def _get_info(self, data: bytes | None) -> dict[str, t.Any]:
if data is None:
return {}
result: dict[str, t.Any] = {"can_parse_key": False}
try:
result.update(
get_privatekey_info(
self.module,
data,
passphrase=self.passphrase,
return_private_key_data=False,
prefer_one_fingerprint=True,
)
)
except PrivateKeyConsistencyError as exc:
result.update(exc.result)
except PrivateKeyParseError as exc:
result.update(exc.result)
except Exception:
pass
return result
@abc.abstractmethod
def generate_private_key(self) -> None:
"""(Re-)Generate private key."""
pass
def convert_private_key(self) -> None:
"""Convert existing private key (self.existing_private_key) to new private key (self.private_key).
This is effectively a copy without active conversion. The conversion is done
during load and store; get_private_key_data() uses the destination format to
serialize the key.
"""
self._ensure_existing_private_key_loaded()
self.private_key = self.existing_private_key
@abc.abstractmethod
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key."""
def set_existing(self, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.existing_private_key_bytes = privatekey_bytes
self.diff_after = self.diff_before = self._get_info(
self.existing_private_key_bytes
)
def has_existing(self) -> bool:
"""Query whether an existing private key is/has been there."""
return self.existing_private_key_bytes is not None
@abc.abstractmethod
def _check_passphrase(self) -> bool:
"""Check whether provided passphrase matches, assuming self.existing_private_key_bytes has been populated."""
@abc.abstractmethod
def _ensure_existing_private_key_loaded(self) -> None:
"""Make sure that self.existing_private_key is populated from self.existing_private_key_bytes."""
@abc.abstractmethod
def _check_size_and_type(self) -> bool:
"""Check whether provided size and type matches, assuming self.existing_private_key has been populated."""
@abc.abstractmethod
def _check_format(self) -> bool:
"""Check whether the key file format, assuming self.existing_private_key and self.existing_private_key_bytes has been populated."""
def needs_regeneration(self) -> bool:
"""Check whether a regeneration is necessary."""
if self.regenerate == "always":
return True
if not self.has_existing():
# key does not exist
return True
if not self._check_passphrase():
if self.regenerate == "full_idempotence":
return True
self.module.fail_json(
msg="Unable to read the key. The key is protected with a another passphrase / no passphrase or broken."
" Will not proceed. To force regeneration, call the module with `generate`"
" set to `full_idempotence` or `always`, or with `force=true`."
)
self._ensure_existing_private_key_loaded()
if self.regenerate != "never":
if not self._check_size_and_type():
if self.regenerate in ("partial_idempotence", "full_idempotence"):
return True
self.module.fail_json(
msg="Key has wrong type and/or size."
" Will not proceed. To force regeneration, call the module with `generate`"
" set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
)
# During generation step, regenerate if format does not match and format_mismatch == 'regenerate'
if self.format_mismatch == "regenerate" and self.regenerate != "never":
if not self._check_format():
if self.regenerate in ("partial_idempotence", "full_idempotence"):
return True
self.module.fail_json(
msg="Key has wrong format."
" Will not proceed. To force regeneration, call the module with `generate`"
" set to `partial_idempotence`, `full_idempotence` or `always`, or with `force=true`."
" To convert the key, set `format_mismatch` to `convert`."
)
return False
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
# During conversion step, convert if format does not match and format_mismatch == 'convert'
self._ensure_existing_private_key_loaded()
return (
self.has_existing()
and self.format_mismatch == "convert"
and not self._check_format()
)
def _get_fingerprint(self) -> dict[str, str] | None:
if self.private_key:
return get_fingerprint_of_privatekey(self.private_key)
try:
self._ensure_existing_private_key_loaded()
except Exception:
# Ignore errors
pass
if self.existing_private_key:
return get_fingerprint_of_privatekey(self.existing_private_key)
return None
def dump(self, include_key: bool) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
if not self.private_key:
try:
self._ensure_existing_private_key_loaded()
except Exception:
# Ignore errors
pass
result: dict[str, t.Any] = {
"type": self.type,
"size": self.size,
"fingerprint": self._get_fingerprint(),
}
if self.type == "ECC":
result["curve"] = self.curve
# Get hold of private key bytes
pk_bytes = self.existing_private_key_bytes
if self.private_key is not None:
pk_bytes = self.get_private_key_data()
self.diff_after = self._get_info(pk_bytes)
if include_key:
# Store result
if pk_bytes:
if identify_private_key_format(pk_bytes) == "raw":
result["privatekey"] = base64.b64encode(pk_bytes)
else:
result["privatekey"] = pk_bytes.decode("utf-8")
else:
result["privatekey"] = None
result["diff"] = dict(
before=self.diff_before,
after=self.diff_after,
)
return result
class _Curve:
def __init__(
self,
name: str,
ectype: str,
deprecated: bool,
) -> None:
self.name = name
self.ectype = ectype
self.deprecated = deprecated
def _get_ec_class(
self, module: GeneralAnsibleModule
) -> type[cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve]:
ecclass = cryptography.hazmat.primitives.asymmetric.ec.__dict__.get(self.ectype) # type: ignore
if ecclass is None:
module.fail_json(
msg=f"Your cryptography version does not support {self.ectype}"
)
return ecclass
def create(
self, size: int, module: GeneralAnsibleModule
) -> cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve:
ecclass = self._get_ec_class(module)
return ecclass()
def verify(
self,
privatekey: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
module: GeneralAnsibleModule,
) -> bool:
ecclass = self._get_ec_class(module)
return isinstance(privatekey.private_numbers().public_numbers.curve, ecclass)
# Implementation with using cryptography
class PrivateKeyCryptographyBackend(PrivateKeyBackend):
def _add_curve(
self,
name: str,
ectype: str,
deprecated: bool = False,
) -> None:
self.curves[name] = _Curve(name=name, ectype=ectype, deprecated=deprecated)
def __init__(self, module: GeneralAnsibleModule) -> None:
super(PrivateKeyCryptographyBackend, self).__init__(module=module)
self.curves: dict[str, _Curve] = {}
self._add_curve("secp224r1", "SECP224R1")
self._add_curve("secp256k1", "SECP256K1")
self._add_curve("secp256r1", "SECP256R1")
self._add_curve("secp384r1", "SECP384R1")
self._add_curve("secp521r1", "SECP521R1")
self._add_curve("secp192r1", "SECP192R1", deprecated=True)
self._add_curve("sect163k1", "SECT163K1", deprecated=True)
self._add_curve("sect163r2", "SECT163R2", deprecated=True)
self._add_curve("sect233k1", "SECT233K1", deprecated=True)
self._add_curve("sect233r1", "SECT233R1", deprecated=True)
self._add_curve("sect283k1", "SECT283K1", deprecated=True)
self._add_curve("sect283r1", "SECT283R1", deprecated=True)
self._add_curve("sect409k1", "SECT409K1", deprecated=True)
self._add_curve("sect409r1", "SECT409R1", deprecated=True)
self._add_curve("sect571k1", "SECT571K1", deprecated=True)
self._add_curve("sect571r1", "SECT571R1", deprecated=True)
self._add_curve("brainpoolP256r1", "BrainpoolP256R1", deprecated=True)
self._add_curve("brainpoolP384r1", "BrainpoolP384R1", deprecated=True)
self._add_curve("brainpoolP512r1", "BrainpoolP512R1", deprecated=True)
def _get_wanted_format(self) -> t.Literal["pkcs1", "pkcs8", "raw"]:
if self.format not in ("auto", "auto_ignore"):
return self.format # type: ignore
if self.type in ("X25519", "X448", "Ed25519", "Ed448"):
return "pkcs8"
else:
return "pkcs1"
def generate_private_key(self) -> None:
"""(Re-)Generate private key."""
try:
if self.type == "RSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, # OpenSSL always uses this
key_size=self.size,
)
)
if self.type == "DSA":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.dsa.generate_private_key(
key_size=self.size
)
)
if self.type == "X25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.generate()
)
if self.type == "X448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.generate()
)
if self.type == "Ed25519":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.generate()
)
if self.type == "Ed448":
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.generate()
)
if self.type == "ECC" and self.curve in self.curves:
if self.curves[self.curve].deprecated:
self.module.warn(
f"Elliptic curves of type {self.curve} should not be used for new keys!"
)
self.private_key = (
cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
curve=self.curves[self.curve].create(
size=self.size, module=self.module
),
)
)
except cryptography.exceptions.UnsupportedAlgorithm:
self.module.fail_json(
msg=f"Cryptography backend does not support the algorithm required for {self.type}"
)
def get_private_key_data(self) -> bytes:
"""Return bytes for self.private_key"""
if self.private_key is None:
raise AssertionError("private_key not set")
# Select export format and encoding
try:
export_format_txt = self._get_wanted_format()
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if export_format_txt == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif export_format_txt == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif export_format_txt == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
export_encoding = (
cryptography.hazmat.primitives.serialization.Encoding.Raw
)
except AttributeError:
self.module.fail_json(
msg=f'Cryptography backend does not support the selected output format "{self.format}"'
)
# Select key encryption
encryption_algorithm: (
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.cipher and self.passphrase:
if self.cipher == "auto":
encryption_algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
to_bytes(self.passphrase)
)
else:
self.module.fail_json(
msg='Cryptography backend can only use "auto" for cipher option.'
)
# Serialize key
try:
return self.private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm,
)
except ValueError:
self.module.fail_json(
msg=f'Cryptography backend cannot serialize the private key in the required format "{self.format}"'
)
except Exception:
self.module.fail_json(
msg=f'Error while serializing the private key in the required format "{self.format}"',
exception=traceback.format_exc(),
)
def _load_privatekey(self) -> PrivateKeyTypes:
data = self.existing_private_key_bytes
if data is None:
raise AssertionError("existing_private_key_bytes not set")
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
if format == "raw":
if len(data) == 56:
return cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(
data
)
if len(data) == 57:
return cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(
data
)
if len(data) == 32:
if self.type == "X25519":
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
if self.type == "Ed25519":
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
try:
return cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
)
except Exception:
return cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
)
raise PrivateKeyError("Cannot load raw key")
else:
return (
cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if self.passphrase is None else to_bytes(self.passphrase),
)
)
except Exception as e:
raise PrivateKeyError(e)
def _ensure_existing_private_key_loaded(self) -> None:
if self.existing_private_key is None and self.has_existing():
self.existing_private_key = self._load_privatekey()
def _check_passphrase(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
try:
format = identify_private_key_format(self.existing_private_key_bytes)
if format == "raw":
# Raw keys cannot be encrypted. To avoid incompatibilities, we try to
# actually load the key (and return False when this fails).
self._load_privatekey()
# Loading the key succeeded. Only return True when no passphrase was
# provided.
return self.passphrase is None
else:
return bool(
cryptography.hazmat.primitives.serialization.load_pem_private_key(
self.existing_private_key_bytes,
None if self.passphrase is None else to_bytes(self.passphrase),
)
)
except Exception:
return False
def _check_size_and_type(self) -> bool:
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey,
):
return (
self.type == "RSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey,
):
return (
self.type == "DSA" and self.size == self.existing_private_key.key_size
)
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
):
return self.type == "X25519"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey,
):
return self.type == "X448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey,
):
return self.type == "Ed25519"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey,
):
return self.type == "Ed448"
if isinstance(
self.existing_private_key,
cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey,
):
if self.type != "ECC":
return False
if self.curve not in self.curves:
return False
return self.curves[self.curve].verify(
self.existing_private_key, module=self.module
)
return False
def _check_format(self) -> bool:
if self.existing_private_key_bytes is None:
raise AssertionError("existing_private_key_bytes not set")
if self.format == "auto_ignore":
return True
try:
format = identify_private_key_format(self.existing_private_key_bytes)
return format == self._get_wanted_format()
except Exception:
return False
def select_backend(module: GeneralAnsibleModule) -> PrivateKeyBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyCryptographyBackend(module)
def get_privatekey_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
size=dict(type="int", default=4096),
type=dict(
type="str",
default="RSA",
choices=["DSA", "ECC", "Ed25519", "Ed448", "RSA", "X25519", "X448"],
),
curve=dict(
type="str",
choices=[
"secp224r1",
"secp256k1",
"secp256r1",
"secp384r1",
"secp521r1",
"secp192r1",
"brainpoolP256r1",
"brainpoolP384r1",
"brainpoolP512r1",
"sect163k1",
"sect163r2",
"sect233k1",
"sect233r1",
"sect283k1",
"sect283r1",
"sect409k1",
"sect409r1",
"sect571k1",
"sect571r1",
],
),
passphrase=dict(type="str", no_log=True),
cipher=dict(type="str", default="auto"),
format=dict(
type="str",
default="auto_ignore",
choices=["pkcs1", "pkcs8", "raw", "auto", "auto_ignore"],
),
format_mismatch=dict(
type="str", default="regenerate", choices=["regenerate", "convert"]
),
select_crypto_backend=dict(
type="str", choices=["auto", "cryptography"], default="auto"
),
regenerate=dict(
type="str",
default="full_idempotence",
choices=[
"never",
"fail",
"partial_idempotence",
"full_idempotence",
"always",
],
),
),
required_if=[
("type", "ECC", ["curve"]),
],
)

View File

@@ -0,0 +1,297 @@
# Copyright (c) 2022, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils._argspec import (
ArgumentSpec,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLObjectError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
cryptography_compare_private_keys,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.pem import (
identify_private_key_format,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
from ansible_collections.community.crypto.plugins.module_utils._io import load_file
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
import cryptography.exceptions
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.dsa
import cryptography.hazmat.primitives.asymmetric.ec
import cryptography.hazmat.primitives.asymmetric.ed448
import cryptography.hazmat.primitives.asymmetric.ed25519
import cryptography.hazmat.primitives.asymmetric.rsa
import cryptography.hazmat.primitives.asymmetric.utils
import cryptography.hazmat.primitives.asymmetric.x448
import cryptography.hazmat.primitives.asymmetric.x25519
import cryptography.hazmat.primitives.serialization
except ImportError:
pass
class PrivateKeyError(OpenSSLObjectError):
pass
# From the object called `module`, only the following properties are used:
#
# - module.params[]
# - module.warn(msg: str)
# - module.fail_json(msg: str, **kwargs)
class PrivateKeyConvertBackend(metaclass=abc.ABCMeta):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.src_path: str | None = module.params["src_path"]
self.src_content: str | None = module.params["src_content"]
self.src_passphrase: str | None = module.params["src_passphrase"]
self.format: t.Literal["pkcs1", "pkcs8", "raw"] = module.params["format"]
self.dest_passphrase: str | None = module.params["dest_passphrase"]
self.src_private_key: PrivateKeyTypes | None = None
if self.src_path is not None:
self.src_private_key_bytes = load_file(self.src_path, module)
else:
if self.src_content is None:
raise AssertionError("src_content is None")
self.src_private_key_bytes = self.src_content.encode("utf-8")
self.dest_private_key: PrivateKeyTypes | None = None
self.dest_private_key_bytes: bytes | None = None
@abc.abstractmethod
def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format."""
pass
def set_existing_destination(self, privatekey_bytes: bytes | None) -> None:
"""Set existing private key bytes. None indicates that the key does not exist."""
self.dest_private_key_bytes = privatekey_bytes
def has_existing_destination(self) -> bool:
"""Query whether an existing private key is/has been there."""
return self.dest_private_key_bytes is not None
@abc.abstractmethod
def _load_private_key(
self,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
"""Check whether data can be loaded as a private key with the provided passphrase. Return tuple (type, private_key)."""
def needs_conversion(self) -> bool:
"""Check whether a conversion is necessary. Must only be called if needs_regeneration() returned False."""
dummy, self.src_private_key = self._load_private_key(
self.src_private_key_bytes, self.src_passphrase
)
if not self.has_existing_destination():
return True
assert self.dest_private_key_bytes is not None
try:
format, self.dest_private_key = self._load_private_key(
self.dest_private_key_bytes,
self.dest_passphrase,
current_hint=self.src_private_key,
)
except Exception:
return True
return format != self.format or not cryptography_compare_private_keys(
self.dest_private_key, self.src_private_key
)
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
return {}
# Implementation with using cryptography
class PrivateKeyConvertCryptographyBackend(PrivateKeyConvertBackend):
def __init__(self, module: AnsibleModule) -> None:
super(PrivateKeyConvertCryptographyBackend, self).__init__(module=module)
def get_private_key_data(self) -> bytes:
"""Return bytes for self.src_private_key in output format"""
if self.src_private_key is None:
raise AssertionError("src_private_key not set")
# Select export format and encoding
try:
export_encoding = cryptography.hazmat.primitives.serialization.Encoding.PEM
if self.format == "pkcs1":
# "TraditionalOpenSSL" format is PKCS1
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL
)
elif self.format == "pkcs8":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8
)
elif self.format == "raw":
export_format = (
cryptography.hazmat.primitives.serialization.PrivateFormat.Raw
)
export_encoding = (
cryptography.hazmat.primitives.serialization.Encoding.Raw
)
except AttributeError:
self.module.fail_json(
msg=f'Cryptography backend does not support the selected output format "{self.format}"'
)
# Select key encryption
encryption_algorithm: (
cryptography.hazmat.primitives.serialization.KeySerializationEncryption
) = cryptography.hazmat.primitives.serialization.NoEncryption()
if self.dest_passphrase:
encryption_algorithm = (
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(
to_bytes(self.dest_passphrase)
)
)
# Serialize key
try:
return self.src_private_key.private_bytes(
encoding=export_encoding,
format=export_format,
encryption_algorithm=encryption_algorithm,
)
except ValueError:
self.module.fail_json(
msg=f'Cryptography backend cannot serialize the private key in the required format "{self.format}"'
)
except Exception:
self.module.fail_json(
msg=f'Error while serializing the private key in the required format "{self.format}"',
exception=traceback.format_exc(),
)
def _load_private_key(
self,
data: bytes,
passphrase: str | None,
current_hint: PrivateKeyTypes | None = None,
) -> tuple[str, PrivateKeyTypes]:
try:
# Interpret bytes depending on format.
format = identify_private_key_format(data)
if format == "raw":
if passphrase is not None:
raise PrivateKeyError("Cannot load raw key with passphrase")
if len(data) == 56:
return (
format,
cryptography.hazmat.primitives.asymmetric.x448.X448PrivateKey.from_private_bytes(
data
),
)
if len(data) == 57:
return (
format,
cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey.from_private_bytes(
data
),
)
if len(data) == 32:
if isinstance(
current_hint,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey,
):
try:
return (
format,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
),
)
except Exception:
return (
format,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
),
)
else:
try:
return (
format,
cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey.from_private_bytes(
data
),
)
except Exception:
return (
format,
cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.from_private_bytes(
data
),
)
raise PrivateKeyError("Cannot load raw key")
else:
return (
format,
cryptography.hazmat.primitives.serialization.load_pem_private_key(
data,
None if passphrase is None else to_bytes(passphrase),
),
)
except Exception as e:
raise PrivateKeyError(e)
def select_backend(module: AnsibleModule) -> PrivateKeyConvertBackend:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyConvertCryptographyBackend(module)
def get_privatekey_argument_spec() -> ArgumentSpec:
return ArgumentSpec(
argument_spec=dict(
src_path=dict(type="path"),
src_content=dict(type="str"),
src_passphrase=dict(type="str", no_log=True),
dest_passphrase=dict(type="str", no_log=True),
format=dict(type="str", required=True, choices=["pkcs1", "pkcs8", "raw"]),
),
mutually_exclusive=[
["src_path", "src_content"],
],
required_one_of=[
["src_path", "src_content"],
],
)

View File

@@ -0,0 +1,349 @@
# Copyright (c) 2016-2017, Yanis Guenane <yanis+ansible@guenane.org>
# Copyright (c) 2017, Markus Teufelberger <mteufelberger+ansible@mgit.at>
# Copyright (c) 2020, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLObjectError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.math import (
binary_exp_mod,
quick_is_not_prime,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.module_backends.publickey_info import (
_get_cryptography_public_key_info,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
get_fingerprint_of_bytes,
load_privatekey,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.plugin_utils._action_module import (
AnsibleActionModule,
)
from ansible_collections.community.crypto.plugins.plugin_utils._filter_module import (
FilterModuleMock,
)
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
from cryptography.hazmat.primitives import serialization
except ImportError:
pass
SIGNATURE_TEST_DATA = b"1234"
def _get_cryptography_private_key_info(
key: PrivateKeyTypes, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
key_type, key_public_data = _get_cryptography_public_key_info(key.public_key())
key_private_data: dict[str, t.Any] = {}
if need_private_key_data:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
rsa_private_numbers = key.private_numbers()
key_private_data["p"] = rsa_private_numbers.p
key_private_data["q"] = rsa_private_numbers.q
key_private_data["exponent"] = rsa_private_numbers.d
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey
):
dsa_private_numbers = key.private_numbers()
key_private_data["x"] = dsa_private_numbers.x
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
ecc_private_numbers = key.private_numbers()
key_private_data["multiplier"] = ecc_private_numbers.private_value
return key_type, key_public_data, key_private_data
def _check_dsa_consistency(
key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
# Get parameters
p: int | None = key_public_data.get("p")
if p is None:
return None
q: int | None = key_public_data.get("q")
if q is None:
return None
g: int | None = key_public_data.get("g")
if g is None:
return None
y: int | None = key_public_data.get("y")
if y is None:
return None
x: int | None = key_private_data.get("x")
if x is None:
return None
# Make sure that g is not 0, 1 or -1 in Z/pZ
if g < 2 or g >= p - 1:
return False
# Make sure that x is in range
if x < 1 or x >= q:
return False
# Check whether q divides p-1
if (p - 1) % q != 0:
return False
# Check that g**q mod p == 1
if binary_exp_mod(g, q, p) != 1:
return False
# Check whether g**x mod p == y
if binary_exp_mod(g, x, p) != y:
return False
# Check (quickly) whether p or q are not primes
if quick_is_not_prime(q) or quick_is_not_prime(p):
return False
return True
def _is_cryptography_key_consistent(
key: PrivateKeyTypes,
key_public_data: dict[str, t.Any],
key_private_data: dict[str, t.Any],
warn_func: t.Callable[[str], None] | None = None,
) -> bool | None:
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey):
# key._backend was removed in cryptography 42.0.0
backend = getattr(key, "_backend", None)
if backend is not None:
return bool(backend._lib.RSA_check_key(key._rsa_cdata)) # type: ignore
if isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPrivateKey):
result = _check_dsa_consistency(key_public_data, key_private_data)
if result is not None:
return result
signature = key.sign(
SIGNATURE_TEST_DATA, cryptography.hazmat.primitives.hashes.SHA256()
)
try:
key.public_key().verify(
signature,
SIGNATURE_TEST_DATA,
cryptography.hazmat.primitives.hashes.SHA256(),
)
return True
except cryptography.exceptions.InvalidSignature:
return False
if isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey
):
signature = key.sign(
SIGNATURE_TEST_DATA,
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(
cryptography.hazmat.primitives.hashes.SHA256()
),
)
try:
key.public_key().verify(
signature,
SIGNATURE_TEST_DATA,
cryptography.hazmat.primitives.asymmetric.ec.ECDSA(
cryptography.hazmat.primitives.hashes.SHA256()
),
)
return True
except cryptography.exceptions.InvalidSignature:
return False
has_simple_sign_function = False
if isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey
):
has_simple_sign_function = True
if isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey):
has_simple_sign_function = True
if has_simple_sign_function:
signature = key.sign(SIGNATURE_TEST_DATA) # type: ignore
try:
key.public_key().verify(signature, SIGNATURE_TEST_DATA) # type: ignore
return True
except cryptography.exceptions.InvalidSignature:
return False
# For X25519 and X448, there's no test yet.
if warn_func is not None:
warn_func(f"Cannot determine consistency for key of type {type(key)}")
return None
class PrivateKeyConsistencyError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PrivateKeyConsistencyError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PrivateKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
class PrivateKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
return_private_key_data: bool = False,
check_consistency: bool = False,
):
self.module = module
self.content = content
self.passphrase = passphrase
self.return_private_key_data = return_private_key_data
self.check_consistency = check_consistency
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(
self, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
pass
@abc.abstractmethod
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {
"can_parse_key": False,
"key_is_consistent": None,
}
priv_key_detail = self.content
try:
self.key = load_privatekey(
path=None,
content=priv_key_detail,
passphrase=(
to_bytes(self.passphrase)
if self.passphrase is not None
else self.passphrase
),
)
result["can_parse_key"] = True
except OpenSSLObjectError as exc:
raise PrivateKeyParseError(str(exc), result)
result["public_key"] = to_native(self._get_public_key(binary=False))
pk = self._get_public_key(binary=True)
result["public_key_fingerprints"] = (
get_fingerprint_of_bytes(pk, prefer_one=prefer_one_fingerprint)
if pk is not None
else dict()
)
key_type, key_public_data, key_private_data = self._get_key_info(
need_private_key_data=self.return_private_key_data or self.check_consistency
)
result["type"] = key_type
result["public_data"] = key_public_data
if self.return_private_key_data:
result["private_data"] = key_private_data
if self.check_consistency:
result["key_is_consistent"] = self._is_key_consistent(
key_public_data, key_private_data
)
if result["key_is_consistent"] is False:
# Only fail when it is False, to avoid to fail on None (which means "we do not know")
msg = (
"Private key is not consistent! (See "
"https://blog.hboeck.de/archives/888-How-I-tricked-Symantec-with-a-Fake-Private-Key.html)"
)
raise PrivateKeyConsistencyError(msg, result)
return result
class PrivateKeyInfoRetrievalCryptography(PrivateKeyInfoRetrieval):
"""Validate the supplied private key, using the cryptography backend"""
def __init__(self, module: GeneralAnsibleModule, content: bytes, **kwargs) -> None:
super(PrivateKeyInfoRetrievalCryptography, self).__init__(
module, content, **kwargs
)
def _get_public_key(self, binary: bool) -> bytes:
return self.key.public_key().public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(
self, need_private_key_data: bool = False
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return _get_cryptography_private_key_info(
self.key, need_private_key_data=need_private_key_data
)
def _is_key_consistent(
self, key_public_data: dict[str, t.Any], key_private_data: dict[str, t.Any]
) -> bool | None:
return _is_cryptography_key_consistent(
self.key, key_public_data, key_private_data, warn_func=self.module.warn
)
def get_privatekey_info(
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
return_private_key_data: bool = False,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PrivateKeyInfoRetrievalCryptography(
module,
content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule,
content: bytes,
passphrase: str | None = None,
return_private_key_data: bool = False,
check_consistency: bool = False,
) -> PrivateKeyInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PrivateKeyInfoRetrievalCryptography(
module,
content,
passphrase=passphrase,
return_private_key_data=return_private_key_data,
check_consistency=check_consistency,
)

View File

@@ -0,0 +1,194 @@
# Copyright (c) 2020-2021, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import typing as t
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLObjectError,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.support import (
get_fingerprint_of_bytes,
load_publickey,
)
from ansible_collections.community.crypto.plugins.module_utils._cryptography_dep import (
COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION,
assert_required_cryptography_version,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.plugin_utils._action_module import (
AnsibleActionModule,
)
from ansible_collections.community.crypto.plugins.plugin_utils._filter_module import (
FilterModuleMock,
)
from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes,
)
GeneralAnsibleModule = t.Union[AnsibleModule, AnsibleActionModule, FilterModuleMock]
MINIMAL_CRYPTOGRAPHY_VERSION = COLLECTION_MINIMUM_CRYPTOGRAPHY_VERSION
try:
import cryptography
import cryptography.hazmat.primitives.asymmetric.ed448
import cryptography.hazmat.primitives.asymmetric.ed25519
import cryptography.hazmat.primitives.asymmetric.x448
import cryptography.hazmat.primitives.asymmetric.x25519
from cryptography.hazmat.primitives import serialization
except ImportError:
pass
def _get_cryptography_public_key_info(
key: PublicKeyTypes,
) -> tuple[str, dict[str, t.Any]]:
key_public_data: dict[str, t.Any] = {}
if isinstance(key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey):
key_type = "RSA"
rsa_public_numbers = key.public_numbers()
key_public_data["size"] = key.key_size
key_public_data["modulus"] = rsa_public_numbers.n
key_public_data["exponent"] = rsa_public_numbers.e
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.dsa.DSAPublicKey):
key_type = "DSA"
dsa_parameter_numbers = key.parameters().parameter_numbers()
dsa_public_numbers = key.public_numbers()
key_public_data["size"] = key.key_size
key_public_data["p"] = dsa_parameter_numbers.p
key_public_data["q"] = dsa_parameter_numbers.q
key_public_data["g"] = dsa_parameter_numbers.g
key_public_data["y"] = dsa_public_numbers.y
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey
):
key_type = "X25519"
elif isinstance(key, cryptography.hazmat.primitives.asymmetric.x448.X448PublicKey):
key_type = "X448"
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey
):
key_type = "Ed25519"
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey
):
key_type = "Ed448"
elif isinstance(
key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey
):
key_type = "ECC"
ecc_public_numbers = key.public_numbers()
key_public_data["curve"] = key.curve.name
key_public_data["x"] = ecc_public_numbers.x
key_public_data["y"] = ecc_public_numbers.y
key_public_data["exponent_size"] = key.curve.key_size
else:
key_type = f"unknown ({type(key)})"
return key_type, key_public_data
class PublicKeyParseError(OpenSSLObjectError):
def __init__(self, msg: str, result: dict[str, t.Any]) -> None:
super(PublicKeyParseError, self).__init__(msg)
self.error_message = msg
self.result = result
class PublicKeyInfoRetrieval(metaclass=abc.ABCMeta):
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
# content must be a bytes string
self.module = module
self.content = content
self.key = key
@abc.abstractmethod
def _get_public_key(self, binary: bool) -> bytes:
pass
@abc.abstractmethod
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
pass
def get_info(self, prefer_one_fingerprint: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if self.key is None:
try:
self.key = load_publickey(content=self.content)
except OpenSSLObjectError as e:
raise PublicKeyParseError(str(e), {})
pk = self._get_public_key(binary=True)
result["fingerprints"] = (
get_fingerprint_of_bytes(pk, prefer_one=prefer_one_fingerprint)
if pk is not None
else dict()
)
key_type, key_public_data = self._get_key_info()
result["type"] = key_type
result["public_data"] = key_public_data
return result
class PublicKeyInfoRetrievalCryptography(PublicKeyInfoRetrieval):
"""Validate the supplied public key, using the cryptography backend"""
def __init__(
self,
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> None:
super(PublicKeyInfoRetrievalCryptography, self).__init__(
module, content=content, key=key
)
def _get_public_key(self, binary: bool) -> bytes:
if self.key is None:
raise AssertionError("key must be set")
return self.key.public_bytes(
serialization.Encoding.DER if binary else serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
def _get_key_info(self) -> tuple[str, dict[str, t.Any]]:
if self.key is None:
raise AssertionError("key must be set")
return _get_cryptography_public_key_info(self.key)
def get_publickey_info(
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
prefer_one_fingerprint: bool = False,
) -> dict[str, t.Any]:
info = PublicKeyInfoRetrievalCryptography(module, content=content, key=key)
return info.get_info(prefer_one_fingerprint=prefer_one_fingerprint)
def select_backend(
module: GeneralAnsibleModule,
content: bytes | None = None,
key: PublicKeyTypes | None = None,
) -> PublicKeyInfoRetrieval:
assert_required_cryptography_version(
module, minimum_cryptography_version=MINIMAL_CRYPTOGRAPHY_VERSION
)
return PublicKeyInfoRetrievalCryptography(module, content=content, key=key)

View File

@@ -0,0 +1,127 @@
# Copyright (c) 2019, Felix Fontein <felix@fontein.de>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import typing as t
PEM_START = "-----BEGIN "
PEM_END_START = "-----END "
PEM_END = "-----"
PKCS8_PRIVATEKEY_NAMES = ("PRIVATE KEY", "ENCRYPTED PRIVATE KEY")
PKCS1_PRIVATEKEY_SUFFIX = " PRIVATE KEY"
def identify_pem_format(content: bytes, encoding: str = "utf-8") -> bool:
"""Given the contents of a binary file, tests whether this could be a PEM file."""
try:
first_pem = extract_first_pem(content.decode(encoding))
if first_pem is None:
return False
lines = first_pem.splitlines(False)
if (
lines[0].startswith(PEM_START)
and lines[0].endswith(PEM_END)
and len(lines[0]) > len(PEM_START) + len(PEM_END)
):
return True
except UnicodeDecodeError:
pass
return False
def identify_private_key_format(
content: bytes, encoding: str = "utf-8"
) -> t.Literal["raw", "pkcs1", "pkcs8", "unknown-pem"]:
"""Given the contents of a private key file, identifies its format."""
# See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85
# (PEM_read_bio_PrivateKey)
# and https://github.com/openssl/openssl/blob/master/include/openssl/pem.h#L46-L47
# (PEM_STRING_PKCS8, PEM_STRING_PKCS8INF)
try:
first_pem = extract_first_pem(content.decode(encoding))
if first_pem is None:
return "raw"
lines = first_pem.splitlines(False)
if (
lines[0].startswith(PEM_START)
and lines[0].endswith(PEM_END)
and len(lines[0]) > len(PEM_START) + len(PEM_END)
):
name = lines[0][len(PEM_START) : -len(PEM_END)]
if name in PKCS8_PRIVATEKEY_NAMES:
return "pkcs8"
if len(name) > len(PKCS1_PRIVATEKEY_SUFFIX) and name.endswith(
PKCS1_PRIVATEKEY_SUFFIX
):
return "pkcs1"
return "unknown-pem"
except UnicodeDecodeError:
pass
return "raw"
def split_pem_list(text: str, keep_inbetween: bool = False) -> list[str]:
"""
Split concatenated PEM objects into a list of strings, where each is one PEM object.
"""
result = []
current: list[str] | None = [] if keep_inbetween else None
for line in text.splitlines(True):
if line.strip():
if not keep_inbetween and line.startswith("-----BEGIN "):
current = []
if current is not None:
current.append(line)
if line.startswith("-----END "):
result.append("".join(current))
current = [] if keep_inbetween else None
return result
def extract_first_pem(text: str) -> str | None:
"""
Given one PEM or multiple concatenated PEM objects, return only the first one, or None if there is none.
"""
all_pems = split_pem_list(text)
if not all_pems:
return None
return all_pems[0]
def _extract_type(line: str, start: str = PEM_START) -> str | None:
if not line.startswith(start):
return None
if not line.endswith(PEM_END):
return None
return line[len(start) : -len(PEM_END)]
def extract_pem(content: str, strict: bool = False) -> tuple[str, str]:
lines = content.splitlines()
if len(lines) < 3:
raise ValueError(f"PEM must have at least 3 lines, have only {len(lines)}")
header_type = _extract_type(lines[0])
if header_type is None:
raise ValueError(
f"First line is not of format {PEM_START}...{PEM_END}: {lines[0]!r}"
)
footer_type = _extract_type(lines[-1], start=PEM_END_START)
if strict:
if header_type != footer_type:
raise ValueError(
f"Header type ({header_type}) is different from footer type ({footer_type})"
)
for idx, line in enumerate(lines[1:-2]):
if len(line) != 64:
raise ValueError(f"Line {idx} has length {len(line)} instead of 64")
if not (0 < len(lines[-2]) <= 64):
raise ValueError(
f"Last line has length {len(lines[-2])}, should be in (0, 64]"
)
return header_type, "".join(lines[1:-1])

View File

@@ -0,0 +1,422 @@
# Copyright (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import abc
import errno
import hashlib
import os
import typing as t
from ansible.module_utils.common.text.converters import to_bytes
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
is_potential_certificate_issuer_private_key,
is_potential_certificate_private_key,
)
from ansible_collections.community.crypto.plugins.module_utils._crypto.pem import (
identify_pem_format,
)
try:
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.serialization import load_pem_private_key
except ImportError:
# Error handled in the calling module.
pass
from ansible_collections.community.crypto.plugins.module_utils._crypto.basic import (
OpenSSLBadPassphraseError,
OpenSSLObjectError,
)
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.crypto.plugins.module_utils._crypto.cryptography_support import (
CertificatePrivateKeyTypes,
)
from cryptography.hazmat.primitives.asymmetric.types import (
CertificateIssuerPrivateKeyTypes,
PrivateKeyTypes,
PublicKeyTypes,
)
# This list of preferred fingerprints is used when prefer_one=True is supplied to the
# fingerprinting methods.
PREFERRED_FINGERPRINTS = (
"sha256",
"sha3_256",
"sha512",
"sha3_512",
"sha384",
"sha3_384",
"sha1",
"md5",
)
def get_fingerprint_of_bytes(source: bytes, prefer_one: bool = False) -> dict[str, str]:
"""Generate the fingerprint of the given bytes."""
fingerprint = {}
algorithms: t.Iterable[str] = hashlib.algorithms_guaranteed
if prefer_one:
# Sort algorithms to have the ones in PREFERRED_FINGERPRINTS at the beginning
prefered_algorithms = [
algorithm for algorithm in PREFERRED_FINGERPRINTS if algorithm in algorithms
]
prefered_algorithms += sorted(
[
algorithm
for algorithm in algorithms
if algorithm not in PREFERRED_FINGERPRINTS
]
)
algorithms = prefered_algorithms
for algo in algorithms:
f = getattr(hashlib, algo)
try:
h = f(source)
except ValueError:
# This can happen for hash algorithms not supported in FIPS mode
# (https://github.com/ansible/ansible/issues/67213)
continue
try:
# Certain hash functions have a hexdigest() which expects a length parameter
pubkey_digest = h.hexdigest()
except TypeError:
pubkey_digest = h.hexdigest(32)
fingerprint[algo] = ":".join(
pubkey_digest[i : i + 2] for i in range(0, len(pubkey_digest), 2)
)
if prefer_one:
break
return fingerprint
def get_fingerprint_of_privatekey(
privatekey: PrivateKeyTypes, prefer_one: bool = False
) -> dict[str, str]:
"""Generate the fingerprint of the public key."""
publickey = privatekey.public_key().public_bytes(
serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
)
return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
def get_fingerprint(
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
content: bytes | None = None,
prefer_one: bool = False,
) -> dict[str, str]:
"""Generate the fingerprint of the public key."""
privatekey = load_privatekey(
path=path,
passphrase=passphrase,
content=content,
check_passphrase=False,
)
return get_fingerprint_of_privatekey(privatekey, prefer_one=prefer_one)
def load_privatekey(
path: os.PathLike | str | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
content: bytes | None = None,
) -> PrivateKeyTypes:
"""Load the specified OpenSSL private key.
The content can also be specified via content; in that case,
this function will not load the key from disk.
"""
try:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as b_priv_key_fh:
priv_key_detail = b_priv_key_fh.read()
else:
priv_key_detail = content
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
try:
return load_pem_private_key(
priv_key_detail,
None if passphrase is None else to_bytes(passphrase),
)
except TypeError:
raise OpenSSLBadPassphraseError(
"Wrong or empty passphrase provided for private key"
)
except ValueError:
raise OpenSSLBadPassphraseError("Wrong passphrase provided for private key")
def load_certificate_privatekey(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
) -> CertificatePrivateKeyTypes:
"""
Load the specified OpenSSL private key that can be used as a private key for certificates.
"""
private_key = load_privatekey(
path=path,
passphrase=passphrase,
check_passphrase=check_passphrase,
content=content,
)
if not is_potential_certificate_private_key(private_key):
raise OpenSSLObjectError(
f"Key of type {type(private_key)} not supported for certificates"
)
return private_key
def load_certificate_issuer_privatekey(
*,
path: os.PathLike | str | None = None,
content: bytes | None = None,
passphrase: str | bytes | None = None,
check_passphrase: bool = True,
) -> CertificateIssuerPrivateKeyTypes:
"""
Load the specified OpenSSL private key that can be used for issuing certificates.
"""
private_key = load_privatekey(
path=path,
passphrase=passphrase,
check_passphrase=check_passphrase,
content=content,
)
if not is_potential_certificate_issuer_private_key(private_key):
raise OpenSSLObjectError(
f"Key of type {type(private_key)} not supported for issuing certificates"
)
return private_key
def load_publickey(
path: os.PathLike | str | None = None, content: bytes | None = None
) -> PublicKeyTypes:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
try:
with open(path, "rb") as b_priv_key_fh:
content = b_priv_key_fh.read()
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
try:
return serialization.load_pem_public_key(content)
except Exception as e:
raise OpenSSLObjectError(f"Error while deserializing key: {e}")
def load_certificate(
path: os.PathLike | str | None = None,
content: bytes | None = None,
der_support_enabled: bool = False,
) -> x509.Certificate:
"""Load the specified certificate."""
try:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as cert_fh:
cert_content = cert_fh.read()
else:
cert_content = content
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
if der_support_enabled is False or identify_pem_format(cert_content):
try:
return x509.load_pem_x509_certificate(cert_content)
except ValueError as exc:
raise OpenSSLObjectError(exc)
elif der_support_enabled:
try:
return x509.load_der_x509_certificate(cert_content)
except ValueError as exc:
raise OpenSSLObjectError(f"Cannot parse DER certificate: {exc}")
def load_certificate_request(
path: os.PathLike | str | None = None, content: bytes | None = None
) -> x509.CertificateSigningRequest:
"""Load the specified certificate signing request."""
try:
if content is None:
if path is None:
raise OpenSSLObjectError("Must provide either path or content")
with open(path, "rb") as csr_fh:
csr_content = csr_fh.read()
else:
csr_content = content
except (IOError, OSError) as exc:
raise OpenSSLObjectError(exc)
try:
return x509.load_pem_x509_csr(csr_content)
except ValueError as exc:
raise OpenSSLObjectError(exc)
def parse_name_field(
input_dict: dict[str, list[str | bytes] | str | bytes],
name_field_name: str | None = None,
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
def error_str(key: str) -> str:
if name_field_name is None:
return f"{key}"
return f"{key} in {name_field_name}"
result = []
for key, value in input_dict.items():
if isinstance(value, list):
for entry in value:
if not isinstance(entry, (str, bytes)):
raise TypeError(f"Values {error_str(key)} must be strings")
if not entry:
raise ValueError(
f"Values for {error_str(key)} must not be empty strings"
)
result.append((key, entry))
elif isinstance(value, (str, bytes)):
if not value:
raise ValueError(
f"Value for {error_str(key)} must not be an empty string"
)
result.append((key, value))
else:
raise TypeError(
f"Value for {error_str(key)} must be either a string or a list of strings"
)
return result
def parse_ordered_name_field(
input_list: list[dict[str, list[str | bytes] | str | bytes]], name_field_name: str
) -> list[tuple[str, str | bytes]]:
"""Take a dict with key: value or key: list_of_values mappings and return a list of tuples"""
result = []
for index, entry in enumerate(input_list):
if len(entry) != 1:
raise ValueError(
f"Entry #{index + 1} in {name_field_name} must be a dictionary with exactly one key-value pair"
)
try:
result.extend(parse_name_field(entry, name_field_name=name_field_name))
except (TypeError, ValueError) as exc:
raise ValueError(
f"Error while processing entry #{index + 1} in {name_field_name}: {exc}"
)
return result
@t.overload
def select_message_digest(
digest_string: t.Literal["sha256", "sha384", "sha512", "sha1", "md5"],
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5: ...
@t.overload
def select_message_digest(
digest_string: str,
) -> (
hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None
): ...
def select_message_digest(
digest_string: str,
) -> hashes.SHA256 | hashes.SHA384 | hashes.SHA512 | hashes.SHA1 | hashes.MD5 | None:
if digest_string == "sha256":
return hashes.SHA256()
if digest_string == "sha384":
return hashes.SHA384()
if digest_string == "sha512":
return hashes.SHA512()
if digest_string == "sha1":
return hashes.SHA1()
if digest_string == "md5":
return hashes.MD5()
return None
class OpenSSLObject(metaclass=abc.ABCMeta):
def __init__(self, path: str, state: str, force: bool, check_mode: bool) -> None:
self.path = path
self.state = state
self.force = force
self.name = os.path.basename(path)
self.changed = False
self.check_mode = check_mode
def check(self, module: AnsibleModule, perms_required: bool = True) -> bool:
"""Ensure the resource is in its desired state."""
def _check_state() -> bool:
return os.path.exists(self.path)
def _check_perms(module: AnsibleModule) -> bool:
file_args = module.load_file_common_arguments(module.params)
if module.check_file_absent_if_check_mode(file_args["path"]):
return False
return not module.set_fs_attributes_if_different(file_args, False)
if not perms_required:
return _check_state()
return _check_state() and _check_perms(module)
@abc.abstractmethod
def dump(self) -> dict[str, t.Any]:
"""Serialize the object into a dictionary."""
@abc.abstractmethod
def generate(self, module: AnsibleModule) -> None:
"""Generate the resource."""
def remove(self, module: AnsibleModule) -> None:
"""Remove the resource from the filesystem."""
if self.check_mode:
if os.path.exists(self.path):
self.changed = True
return
try:
os.remove(self.path)
self.changed = True
except OSError as exc:
if exc.errno != errno.ENOENT:
raise OpenSSLObjectError(exc)
else:
pass